MultiHeadAttention with attention_mask
Understanding attention_mask
In multi-head attention, the attention mechanism allows the model to focus on specific parts of the input sequence. The attention_mask is a crucial component that tells the model which positions in the sequence are valid and which should be ignored.
- Padding: When input sequences have varying lengths, they are padded (typically with zeros) to a uniform length. A padding mask prevents the model from attending to these padded positions.
- Causality: In autoregressive tasks such as language modeling, the model must not attend to future tokens. A causal mask enforces this by masking out every position after the current one. Both kinds of mask are sketched just below.
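As a rough sketch of both ideas (the token values, the 0-as-padding convention, and the sequence length here are illustrative assumptions), the two masks can be built directly in TensorFlow:

```python
import tensorflow as tf

# Hypothetical batch of token ids, padded with 0 to length 6
tokens = tf.constant([[5, 2, 7, 3, 0, 0]])
seq_len = tokens.shape[1]

# Padding mask: 1 for real tokens, 0 for padding.
# Shape (batch, 1, key_len) so it broadcasts over query positions.
padding_mask = tf.cast(tokens != 0, tf.int32)[:, tf.newaxis, :]

# Causal mask: lower-triangular ones, so position i attends only to positions <= i.
causal_mask = tf.linalg.band_part(tf.ones((1, seq_len, seq_len), tf.int32), -1, 0)

print(padding_mask.numpy())  # [[[1 1 1 1 0 0]]]
print(causal_mask.numpy())   # lower-triangular 6x6 matrix of ones
```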
Implementing Attention Mask in Keras
Let’s illustrate how to use the attention_mask in Keras with a simple example:
Code:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Define the input sequence length and embedding dimension
sequence_length = 10
embed_dim = 8

# Generate a random sequence of token ids
input_data = tf.random.uniform((1, sequence_length), maxval=10, dtype=tf.int32)

# Embed the token ids into dense vectors (attention operates on vectors, not raw ids)
embedded = layers.Embedding(input_dim=10, output_dim=embed_dim)(input_data)

# Create the attention mask: 1 means "attend", 0 means "ignore".
# Shape (batch, query_len, key_len); here every position is valid.
attention_mask = tf.ones((1, sequence_length, sequence_length), dtype=tf.int32)

# Create a MultiHeadAttention layer
attention_layer = layers.MultiHeadAttention(num_heads=2, key_dim=8)

# Apply self-attention; return_attention_scores=True also returns the attention weights
attention_output, attention_weights = attention_layer(
    embedded, embedded,
    attention_mask=attention_mask,
    return_attention_scores=True,
)

print(f"Input data:\n{input_data.numpy()}")
print(f"Attention weights:\n{attention_weights.numpy()}")
```

Output (the weights tensor has shape (batch, num_heads, query_len, key_len) = (1, 2, 10, 10); exact values vary with the random input and are truncated here):

```
Input data:
[[4 3 6 7 1 6 1 8 7 7]]
Attention weights:
[[[[0.36596107 0.27032757 0.14118775 0.12767427 0.0379417  0.0305599
    0.01137569 0.00932305 0.00465975 0.00099059]
   [0.24415219 0.33017747 0.18627228 0.11818279 0.04263779 0.03834654
    0.01577126 0.01116366 0.00572356 0.00797244]
   ...]]]]
```
Explanation
- We define the input sequence length and the embedding dimension.
- A random sequence of token ids is generated and embedded into dense vectors, and an attention mask of all ones is created (indicating that every position may attend to every other position).
- We create a MultiHeadAttention layer with two heads and a key dimension of 8.
- The attention layer is applied to the embedded sequence with return_attention_scores=True, so it returns both the attention output and the attention weights.
- The attention weights reflect how much each position in the sequence contributes to each output position (see the quick check below).
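Because the weights for each query position come from a softmax over the key positions, every row sums to 1. Continuing the example above (so attention_weights is assumed to still be in scope), this can be checked directly:

```python
# attention_weights has shape (batch, num_heads, query_len, key_len).
# Summing over the key axis gives ~1.0 for every query position of every head.
row_sums = tf.reduce_sum(attention_weights, axis=-1)
print(row_sums.numpy())  # all entries close to 1.0
```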
Important Points
- The attention_mask passed to MultiHeadAttention has shape (batch_size, query_length, key_length), with 1 marking positions that may be attended to and 0 marking positions to ignore.
- To mask out padding, set the entries for padded key positions to zero.
- For causal masking, keep the lower triangular part of the mask as ones and set the entries above the diagonal to zero, so each position attends only to itself and earlier positions; see the sketch below.
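Putting these points together, here is a minimal, self-contained sketch (the token ids, the 0-as-padding convention, and the layer sizes are illustrative assumptions) of passing a combined padding-plus-causal mask to MultiHeadAttention:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical batch of token ids, padded with 0 to length 6
tokens = tf.constant([[5, 2, 7, 3, 0, 0]])
seq_len = tokens.shape[1]

# Padding mask (batch, 1, key_len) and causal mask (1, query_len, key_len)
padding_mask = tf.cast(tokens != 0, tf.int32)[:, tf.newaxis, :]
causal_mask = tf.linalg.band_part(tf.ones((1, seq_len, seq_len), tf.int32), -1, 0)

# A position may be attended to only if both masks allow it
combined_mask = padding_mask * causal_mask  # shape (batch, query_len, key_len)

# Embed the tokens and apply masked self-attention
embedded = layers.Embedding(input_dim=10, output_dim=8)(tokens)
mha = layers.MultiHeadAttention(num_heads=2, key_dim=8)
output, weights = mha(
    embedded, embedded,
    attention_mask=combined_mask,
    return_attention_scores=True,
)

# Weights for future and padded key positions come out (numerically) zero
print(weights.numpy().round(3))
```

Recent Keras versions also accept a use_causal_mask=True argument in the layer call, which builds the causal part of the mask for you; the explicit construction above just makes the mechanics visible.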