Pytorch 로 Attention 구현하기

2023. 7. 12. 19:46

 

Attention 을 Pytorch 를 통해 구현해보겠습니다.

Transformer 는 크게 Encoder 와 Decoder 로 나뉘고, 각각 내부에는 Attention 및 Feedforward Network 구조로 이루어져 있습니다. 그리고 가장 핵심은 Attention 구조입니다. Scaled Dot Product Attention 을 구현하는 것이 핵심입니다.

 


1. 필요 패키지 import

Pytorch 로 구현을 하기 때문에 다음과 같이 패키지를 import 합니다.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

2. Scaled Dot-Product Attention (SDPA)

 

$$\text{Attention}(Q,K,V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_K}} \right)V \in \mathbb{R}^{n \times d_V} $$

 

유사도 계산을 하여 Attention value 를 계산하는 식인 위 식을 다음과 같이 구현합니다.

class ScaleDotProductAttention(nn.Module):
    
    def forward(self, Q, K, V, mask=None):
        d_K = K.shape[-1] # key dimesion
        # score 계산 : (Q @ K) / d_K
        scores = Q @ K.swapaxes(-1, -2) / d_K ** 0.5

        if mask is not None:
            scores = scores.masked_fill(mask==0, -1e9)
        
        weight = F.softmax(scores, dim=-1)
        out = weight @ V
        return out, weight

3. Attention

 

Scaled Dot-Product Attention Block 을 이용해 Attention Block 을 만들어보겠습니다.

 

지금 만들 Attention Block 으로 Self-Attention 과 Cross-Attention 을 만들때 활용할 수 있게 구현을 해보겠습니다.

context_dim = None 이면 Self Attention, context_dim 이 존재하면 Cross Attention 으로 작동하게 코드를 작성했습니다.

class Attention(nn.Module):
    def __init__(self, 
                 query_dim: int, 
                 context_dim: int = None, 
                 inner_dim: int = 128,
                 use_bias: bool = True,
                 dropout: float = 0.):
        super().__init__()
        self.context = context_dim

        if not context_dim:
            context_dim = query_dim
        
        self.to_q = nn.Linear(query_dim, inner_dim, bias=use_bias)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=use_bias)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=use_bias)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

        self.attention = ScaleDotProductAttention()
    
    def forward(self, x, context=None):
        q = self.to_q(x)
        if not self.context:
            context = x
        k = self.to_k(context)
        v = self.to_v(context)

        out = self.attention(q, k, v)
        out = self.to_out(x)
        return out

4. Mulit-Head Attention

멀티헤드인 경우를 구현합니다. inner dimmension 을 head_number 와 dimension of head 차원으로 나누고 attention 연산을 한 후 다시 합치는 과정이 추가됩니다.

 

inner dim 을 num_head 와 dim_head 로 나누는 것은 pytorch Tensor 의 차원을 조작할 수 있는 메소드(함수) 인 view, permute 를 사용하면 됩니다. 조작 도중에 텐서 차원과 메모리 차원이 일치하지 않기 떄문에 contiguous 를 사용합니다.

 

혹은 einops 패키지의 Rearrange 함수를 사용하여 차원을 쉽게 조작할 수도 있습니다.

여기서는 view, permute 를 사용하여 Tensor의 차원을 조작하겠습니다. map 함수를 사용하면 query, key, value 를 한번에 처리할 수 있습니다.

class MultiHeadAttention(nn.Module):
    def __init__(self, 
                 query_dim: int, 
                 context_dim: int = None, 
                 num_heads: int = 8,
                 dim_head: int = 64,
                 use_bias: bool = True,
                 dropout: float = 0.):
        super().__init__()
        
        if dim_head is None:
            dim_head = query_dim // num_heads

        inner_dim = dim_head * num_heads
        self.heads = num_heads
        self.dim_head = dim_head
        self.context = context_dim

        if not context_dim:
            context_dim = query_dim
        
        self.to_q = nn.Linear(query_dim, inner_dim, bias=use_bias)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=use_bias)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=use_bias)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

        self.attention = ScaleDotProductAttention()

    def forward(self, x, context=None):
        b = x.size()[0]
        h = self.heads
        d = self.dim_head
        
        q = self.to_q(x)
        if not self.context:
            context = x
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: t.view(b, -1, h, d).permute(0, 2, 1, 3).contiguous().view(b*h, -1, d), (q, k, v))

        out = self.attention(q, k, v)
        out = out.view(b, h, -1, d).permute(0, 2, 1, 3).contiguous().view(b, -1, h*d)
        out = self.to_out(x)       
        return out

 

einops 패키지를 사용한다면 좀 더 간단하게 코드를 짤 수 있습니다.

from einops import rearrange

class Attention(nn.Module):
    ...

    def forward(self, x, context=None)
        ...
        
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        
        ...
        
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        
        ...

'DeepLeaning > 구현' 카테고리의 다른 글

[논문구현]Pytorch로 ConvNeXt 구현  (0) 2023.08.10
[논문구현]Pytorch로 ResNet 구현  (0) 2023.07.27

BELATED ARTICLES

more