多头注意力机制Python实现


多头注意力机制类的使用方法

MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)

类代码如下:

class MultiHeadAttention(Layer):
    """Multi-head scaled dot-product self-attention layer (Keras).

    Projects the input into query/key/value spaces of total width
    ``key_dim``, splits them into ``num_heads`` heads of width
    ``key_dim // num_heads``, applies scaled dot-product attention per
    head, concatenates the heads, and projects back to the input width.

    Args:
        num_heads: Number of attention heads; must evenly divide key_dim.
        key_dim: Total dimensionality of the Q/K/V projections.
        **kwargs: Forwarded to the Keras ``Layer`` base class.
    """

    def __init__(self, num_heads, key_dim, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        # Fail fast: a non-divisible pair would otherwise break (or
        # silently mis-shape) the head-splitting reshape in call().
        if key_dim % num_heads != 0:
            raise ValueError(
                "key_dim (%d) must be divisible by num_heads (%d)"
                % (key_dim, num_heads)
            )
        self.num_heads = num_heads
        self.key_dim = key_dim
        # Per-head feature width.
        self.depth = self.key_dim // self.num_heads

    def build(self, input_shape):
        """Create the Q/K/V projection weights and the output projection."""
        self.Wq = self.add_weight(
            shape=(input_shape[-1], self.key_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='Wq'
        )
        self.Wk = self.add_weight(
            shape=(input_shape[-1], self.key_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='Wk'
        )
        self.Wv = self.add_weight(
            shape=(input_shape[-1], self.key_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='Wv'
        )
        # Final linear projection back to the input feature width.
        self.dense = Dense(input_shape[-1], use_bias=False)
        super(MultiHeadAttention, self).build(input_shape)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, key_dim) -> (batch, num_heads, seq, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        # Move the heads axis ahead of the sequence axis so attention
        # is computed independently per head.
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def scaled_dot_product_attention(self, q, k, v):
        """Compute softmax(q·kᵀ / sqrt(dk))·v.

        Returns:
            Tuple of (attention output, attention weights).
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        # Scale by sqrt of the key depth to stabilize softmax gradients.
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
        return output, attention_weights

    def call(self, inputs, **kwargs):
        """Apply multi-head self-attention to ``inputs``.

        Args:
            inputs: Tensor of shape (batch, seq_len, features).

        Returns:
            Tensor of shape (batch, seq_len, features).
        """
        batch_size = tf.shape(inputs)[0]

        # Linear projections into query/key/value space.
        q = tf.matmul(inputs, self.Wq)
        k = tf.matmul(inputs, self.Wk)
        v = tf.matmul(inputs, self.Wv)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v)

        # (batch, num_heads, seq, depth) -> (batch, seq, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # Merge heads: (batch, seq, num_heads * depth) == (batch, seq, key_dim).
        concat_attention = tf.reshape(
            scaled_attention, (batch_size, -1, self.num_heads * self.depth)
        )

        # Project back to the input feature width.
        output = self.dense(concat_attention)
        return output

    def get_config(self):
        """Return the layer config so the layer can be serialized/reloaded."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim,
        })
        return config