import numpy as np
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Multi-head attention as proposed by Vaswani et al., 2017 ("Attention Is
    All You Need", https://arxiv.org/abs/1706.03762). Since query, key and value
    are passed in separately, the module covers both self-attention and
    cross-attention.
    """
    def __init__(self, d_model, n_heads, dropout_rate=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "`d_model` should be a multiple of `n_heads`"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = self.d_v = d_model // n_heads  # head_dim
        self.dropout_rate = dropout_rate
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(np.sqrt(self.d_k), dropout_rate)
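        # The sqrt(d_k) passed above is the scaling factor from Vaswani et al.,
        # 2017: the QK^T scores are divided by sqrt(d_k) before the softmax to
        # keep the dot products from growing with the head dimension.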

    def split_heads(self, x):
        """ x: (batch_size, seq_len, d_model)
        """
        batch_size = x.size(0)
        x = x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # x: (batch_size, n_heads, seq_len, head_dim)
        return x

    def group_heads(self, x):
        """ x: (batch_size, n_heads, seq_len, head_dim)
        """
        batch_size = x.size(0)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)
        # x: (batch_size, seq_len, d_model)
        return x
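
    # Note: group_heads inverts split_heads. For example, with d_model=512 and
    # n_heads=8, split_heads maps (batch, seq_len, 512) to (batch, 8, seq_len, 64),
    # and group_heads maps that back to (batch, seq_len, 512).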

    def forward(self, query, key, value, mask=None):
        """ query: (batch_size, query_len, d_model)
            key: (batch_size, key_len, d_model)
            value: (batch_size, value_len, d_model)
            mask: (batch_size, 1, source_seq_len) for source mask
                  (batch_size, target_seq_len, target_seq_len) for target mask
        """
        # apply the linear projections to query, key and value, then split into heads
        Q = self.split_heads(self.W_q(query))  # (batch_size, n_heads, query_len, head_dim)
        K = self.split_heads(self.W_k(key))    # (batch_size, n_heads, key_len, head_dim)
        V = self.split_heads(self.W_v(value))  # (batch_size, n_heads, value_len, head_dim)

        if mask is not None:
            # add a head dimension so the same mask broadcasts over all the heads
            mask = mask.unsqueeze(1)
            # mask: (batch_size, 1, 1, source_seq_len) for source mask
            #       (batch_size, 1, target_seq_len, target_seq_len) for target mask

        # calculate the attention weights and context vector for each of the heads
        x, attn = self.attention(Q, K, V, mask)
        # x: (batch_size, n_heads, query_len, head_dim)
        # attn: (batch_size, n_heads, query_len, value_len)

        # concatenate the context vectors of all the heads
        x = self.group_heads(x)  # (batch_size, query_len, d_model)

        # apply the output linear projection to the concatenated context vector
        x = self.W_o(x)  # (batch_size, query_len, d_model)

        # x: (batch_size, query_len, d_model)
        # attn: (batch_size, n_heads, query_len, value_len)
        return x, attn
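

# A minimal usage sketch (illustrative, not part of the module). It assumes the
# ScaledDotProductAttention class used above is already in scope and that a mask
# entry of 1/True means "attend" while 0/False means "masked out". The sizes below
# (batch_size=2, seq_len=5, d_model=512, n_heads=8) are arbitrary examples.
if __name__ == "__main__":
    batch_size, seq_len, d_model, n_heads = 2, 5, 512, 8
    mha = MultiHeadAttention(d_model, n_heads)

    x = torch.randn(batch_size, seq_len, d_model)

    # source (padding) mask, shape (batch_size, 1, source_seq_len); all ones here,
    # i.e. no positions are treated as padding
    src_mask = torch.ones(batch_size, 1, seq_len, dtype=torch.bool)

    # target (causal) mask, shape (batch_size, target_seq_len, target_seq_len)
    tgt_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().expand(batch_size, -1, -1)

    out, attn = mha(x, x, x, mask=src_mask)      # encoder-style self-attention
    out_causal, _ = mha(x, x, x, mask=tgt_mask)  # decoder-style masked self-attention

    print(out.shape)   # expected: torch.Size([2, 5, 512])
    print(attn.shape)  # expected: torch.Size([2, 8, 5, 5])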