多头注意力机制类的使用方法
MultiHeadAttention(num_heads=num_heads, key_dim=key_dim),
类代码如下:
class MultiHeadAttention(Layer):
    """Multi-head self-attention layer (Vaswani et al., "Attention Is All You Need").

    Projects the input into queries, keys and values, splits each projection
    into `num_heads` parallel heads, applies scaled dot-product attention per
    head, then concatenates the heads and projects back to the input's last
    (feature) dimension, so output shape equals input shape.
    """

    def __init__(self, num_heads, key_dim, **kwargs):
        """Create the layer.

        Args:
            num_heads: Number of parallel attention heads.
            key_dim: Total Q/K/V projection size. Must be divisible by
                `num_heads`; each head receives key_dim // num_heads features.

        Raises:
            ValueError: If `key_dim` is not divisible by `num_heads`.
        """
        super(MultiHeadAttention, self).__init__(**kwargs)
        # Guard against silent truncation: depth = key_dim // num_heads must
        # be exact, otherwise the reshape in split_heads fails at call time.
        if key_dim % num_heads != 0:
            raise ValueError(
                "key_dim (%d) must be divisible by num_heads (%d)"
                % (key_dim, num_heads)
            )
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.depth = self.key_dim // self.num_heads

    def build(self, input_shape):
        """Create the Q/K/V projection weights and the output projection."""
        feature_dim = input_shape[-1]
        # One (feature_dim, key_dim) projection matrix each for Q, K and V.
        self.Wq = self.add_weight(
            shape=(feature_dim, self.key_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='Wq',
        )
        self.Wk = self.add_weight(
            shape=(feature_dim, self.key_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='Wk',
        )
        self.Wv = self.add_weight(
            shape=(feature_dim, self.key_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='Wv',
        )
        # Final linear map from the concatenated heads back to feature_dim.
        self.dense = Dense(feature_dim, use_bias=False)
        super(MultiHeadAttention, self).build(input_shape)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, key_dim) -> (batch, num_heads, seq, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        # Move the heads axis ahead of the sequence axis so attention is
        # computed independently per head.
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def scaled_dot_product_attention(self, q, k, v):
        """Return (attention output, attention weights) for one batch of heads.

        Logits are scaled by 1/sqrt(d_k) to keep softmax gradients stable.
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
        return output, attention_weights

    def call(self, inputs, **kwargs):
        """Apply multi-head self-attention to `inputs`.

        Args:
            inputs: Tensor of shape (batch, seq_len, features).

        Returns:
            Tensor of shape (batch, seq_len, features).
        """
        batch_size = tf.shape(inputs)[0]
        # Linear projections into query/key/value space.
        q = tf.matmul(inputs, self.Wq)
        k = tf.matmul(inputs, self.Wk)
        v = tf.matmul(inputs, self.Wv)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        scaled_attention, _ = self.scaled_dot_product_attention(q, k, v)
        # Back to (batch, seq, heads, depth), then merge the head axes:
        # (batch, seq, num_heads * depth) == (batch, seq, key_dim).
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(
            scaled_attention, (batch_size, -1, self.num_heads * self.depth)
        )
        # Project back to the input feature dimension.
        return self.dense(concat_attention)

    def get_config(self):
        """Return the layer config so the layer can be serialized/reloaded."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim,
        })
        return config