Understanding Self-Attention and Multi-Head Attention in Deep Learning

Naresh Nishad - Sep 28 - Dev Community

Introduction

Self-attention and multi-head attention are fundamental concepts in modern deep learning, especially in Natural Language Processing (NLP) and Transformer-based models like BERT and GPT. These mechanisms enable models to focus on different parts of the input data effectively, improving their ability to handle complex tasks such as translation, summarization, and question-answering. In this article, we will explore self-attention and multi-head attention, their importance, and how they work.

1. What is Self-Attention?

Self-attention, also known as intra-attention, is a mechanism where different positions of a single sequence are related to each other to compute a representation for that sequence. In simpler terms, self-attention allows a model to focus on relevant parts of the input while processing a specific token, word, or element in the sequence.

Self-attention plays a crucial role in capturing dependencies between words that are far apart in a sentence. Instead of processing the sequence in order, self-attention enables the model to "attend" to the entire sequence at once and prioritize important elements for the task at hand.

The Self-Attention Formula

Self-attention computes a weighted sum of the values (V) where the weights are calculated by comparing the query (Q) with the corresponding keys (K). The formula for self-attention is:

Attention(Q, K, V) = softmax((Q * K^T) / sqrt(d_k)) * V

Where:

  • Q (Query): The vector for the token currently being processed.
  • K (Key): The vectors for the other words in the sequence that the current word is compared against.
  • V (Value): The content vectors that are combined in the weighted sum.
  • d_k: The dimensionality of the keys, used for scaling so that the dot products do not grow too large.
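To make the formula concrete, here is a minimal sketch of scaled dot-product attention in PyTorch. The function name, toy dimensions, and random tensors are illustrative assumptions, not part of any library API:

import torch
import torch.nn.functional as F

def attention(q, k, v):
    # q, k, v: (seq_length, d_k) for a single sequence
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / (d_k ** 0.5)  # (seq_length, seq_length)
    weights = F.softmax(scores, dim=-1)              # each row sums to 1
    return weights @ v                               # weighted sum of the values

# Toy example: a "sentence" of 4 tokens with d_k = 8
q, k, v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
out = attention(q, k, v)  # (4, 8)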

2. How Self-Attention Works

To understand how self-attention operates, let’s break down the steps:

  1. Input Embedding: The input sequence (e.g., a sentence) is transformed into embeddings, which are dense vector representations of the words.

  2. Query, Key, Value Vectors: For each word in the sequence, we create three vectors: Query, Key, and Value. These vectors are derived from the input embeddings using learned weight matrices.

  3. Dot Product Attention: The query vector of the current word is compared with the key vectors of all words in the sequence to compute attention scores. The higher the score, the more relevant that word is to the current word.

  4. Softmax Normalization: The attention scores are normalized using the softmax function, resulting in weights between 0 and 1 that sum to 1 across the sequence.

  5. Weighted Sum: The weighted sum of the value vectors is computed based on the attention weights, providing the output for that particular word.

This process is repeated for every word in the sequence, allowing the model to generate context-aware representations.
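The five steps above can be traced end to end in a few lines of PyTorch. This is a sketch under assumed toy dimensions; the random embeddings and weight matrices stand in for the learned parameters of a real model:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_length, embed_dim = 5, 16                     # e.g. a 5-word sentence

# 1. Input embedding (random here; a real model uses a learned embedding layer)
x = torch.randn(seq_length, embed_dim)

# 2. Query, key, value vectors via learned weight matrices
W_q, W_k, W_v = (torch.randn(embed_dim, embed_dim) for _ in range(3))
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# 3. Dot-product attention scores, scaled by sqrt(d_k)
scores = Q @ K.T / (embed_dim ** 0.5)             # (seq_length, seq_length)

# 4. Softmax normalization over each row
weights = F.softmax(scores, dim=-1)

# 5. Weighted sum of the value vectors
output = weights @ V                              # (seq_length, embed_dim)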

3. Multi-Head Attention

Motivation for Multi-Head Attention

While self-attention is powerful, relying on a single set of attention scores can limit the model's capacity to focus on different aspects of the input. This is where multi-head attention comes in. Instead of computing a single attention function, multi-head attention runs multiple attention operations (or heads) in parallel, allowing the model to attend to different parts of the input sequence simultaneously.

Multi-Head Attention Overview

Multi-head attention splits the query, key, and value vectors into multiple smaller vectors and performs self-attention independently on each of them. The results from these heads are then concatenated and linearly transformed to produce the final output.

MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) * W^O

Where each attention head is computed as:

head_i = Attention(Q * W_i^Q, K * W_i^K, V * W_i^V)

  • W_i^Q, W_i^K, W_i^V: Learned weight matrices for each head.
  • W^O: Output weight matrix.

The key idea is that each attention head can focus on different parts of the sequence, allowing the model to capture various relationships between words.
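Read literally, the formulas above can be sketched as follows. The separate per-head weight matrices mirror W_i^Q, W_i^K, W_i^V and W^O, while the toy dimensions are assumptions for illustration; the PyTorch module in section 5 fuses these projections into a single linear layer for efficiency:

import torch
import torch.nn.functional as F

def attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return F.softmax(scores, dim=-1) @ v

seq_length, embed_dim, num_heads = 5, 16, 4
head_dim = embed_dim // num_heads
x = torch.randn(seq_length, embed_dim)            # here Q = K = V = x

# One learned projection per head for queries, keys, and values
W_q = [torch.randn(embed_dim, head_dim) for _ in range(num_heads)]
W_k = [torch.randn(embed_dim, head_dim) for _ in range(num_heads)]
W_v = [torch.randn(embed_dim, head_dim) for _ in range(num_heads)]
W_o = torch.randn(embed_dim, embed_dim)           # output weight matrix W^O

# head_i = Attention(Q * W_i^Q, K * W_i^K, V * W_i^V)
heads = [attention(x @ W_q[i], x @ W_k[i], x @ W_v[i]) for i in range(num_heads)]

# Concatenate the heads and apply the output projection
out = torch.cat(heads, dim=-1) @ W_o              # (seq_length, embed_dim)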

4. Why Multi-Head Attention Matters

The use of multiple heads allows the model to focus on different semantic aspects of the input. For example, one head might capture syntactic dependencies like subject-verb agreement, while another head might focus on the relationships between named entities in the text. This diversity in attention helps the model better understand the input and improves performance on downstream tasks.

5. Code Example: Multi-Head Attention in PyTorch

Here is a simple implementation of multi-head attention in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads

        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "Embedding dimension must be divisible by the number of heads"

        self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.shape

        # Linear transformation for query, key, value
        qkv = self.qkv_linear(x)  # (batch_size, seq_length, 3 * embed_dim)
        qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each: (batch_size, seq_length, num_heads, head_dim)

        # Scaled dot-product attention
        attn_scores = torch.einsum("bqhd,bkhd->bhqk", q, k)
        attn_scores = attn_scores / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum of values
        out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        out = out.reshape(batch_size, seq_length, self.embed_dim)

        # Final linear layer
        out = self.fc_out(out)
        return out

This code defines a MultiHeadAttention class where the input is split into multiple heads, self-attention is performed on each head, and the results are combined to generate the final output.
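A quick sanity check might look like this; the batch size, sequence length, and embedding dimension are arbitrary values chosen for illustration:

mha = MultiHeadAttention(embed_dim=64, num_heads=8)
x = torch.randn(2, 10, 64)   # (batch_size, seq_length, embed_dim)
out = mha(x)
print(out.shape)             # torch.Size([2, 10, 64])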

6. Applications of Self-Attention and Multi-Head Attention

Self-attention and multi-head attention are used in a wide variety of NLP and computer vision applications:

  • Machine Translation: These mechanisms allow models to focus on relevant parts of the input sentence while generating a translation.
  • Text Summarization: Self-attention helps identify the most important sentences or phrases in a document.
  • Question Answering: The model can "attend" to the part of the context that contains the answer to the question.
  • Vision Transformers (ViT): Multi-head attention is used to model relationships between different patches of an image.

7. Conclusion

Self-attention and multi-head attention are at the heart of modern deep learning models, especially transformers. They enable models to focus on the most relevant parts of the input data, improving performance across a wide range of tasks. Understanding these mechanisms is key to working with state-of-the-art NLP models and computer vision architectures.
