멀티헤드 어텐션
여러 개의 "눈"으로 문장을 다각도로 바라보기
왜 멀티헤드가 필요할까?
하나의 관점으로는 부족하다
하나의 어텐션 헤드는 하나의 관점으로만 문장을 봅니다. 하지만 언어는 다양한 측면(문법, 의미, 위치 관계 등)을 동시에 가지고 있습니다. 멀티헤드 어텐션은 여러 관점을 동시에 학습하여 더 풍부한 표현을 만듭니다.
문장의 다양한 관계들
"The cat sat on the mat because it was tired." 문장에서:
인접한 단어들 간의 관계
"sat" → "on" (전치사와 동사)
주어-동사, 동사-목적어 관계
"sat" → "mat" (동사-목적어)
대명사가 가리키는 대상
문장 끝에서 시작까지 연결!
중요한 키워드에 집중
"mat" → 장소 특정
BERT와 GPT 모델을 분석한 연구에 따르면, 각 헤드는 실제로 다른 유형의 관계를 학습합니다:
- 위치 헤드 (Positional): 인접 토큰에 주목
- 문법 헤드 (Syntactic): 문법적 관계에 주목
- 희귀 단어 헤드: 빈도가 낮은 중요 단어에 주목
멀티헤드 어텐션의 동작 원리
분할하고 병합하기
멀티헤드 어텐션 아키텍처
차원 분할 시각화
멀티헤드 어텐션 수식
헤드별 어텐션 패턴
각 헤드가 학습하는 것들
각 헤드를 클릭하면 해당 헤드가 학습한 어텐션 패턴을 볼 수 있습니다:
Head 1: 위치 관계 패턴
인접한 단어들 사이의 관계에 집중합니다. 대부분의 어텐션이 바로 옆 단어에 분배됩니다.
실제 모델들의 멀티헤드 설정
모델별 헤드 수 비교
| 모델 | 헤드 수 | d_model | d_k (헤드당) | 레이어 수 |
|---|---|---|---|---|
| Transformer (원본) | 8 | 512 | 64 | 6 |
| BERT-base | 12 | 768 | 64 | 12 |
| BERT-large | 16 | 1024 | 64 | 24 |
| GPT-2 | 12 | 768 | 64 | 12 |
| GPT-3 (175B) | 96 | 12288 | 128 | 96 |
| Llama 2 (7B) | 32 | 4096 | 128 | 32 |
| Llama 2 (70B) | 64 | 8192 | 128 | 80 |
- 헤드당 차원(d_k)은 대체로 64~128로 유지
- 모델이 커지면 헤드 수를 늘려서 표현력 증가
- d_model = 헤드 수 × d_k 관계 유지
멀티헤드 어텐션의 장점
다양성과 효율성
다양한 표현 학습
각 헤드가 다른 측면(문법, 의미, 위치)을 학습하여 풍부한 표현 생성
병렬 처리
모든 헤드가 동시에 계산되어 효율적인 GPU 활용
강건성
하나의 패턴에 의존하지 않아 과적합 방지
유연한 설계
헤드 수 조절로 복잡도와 성능 균형 조절 가능
멀티헤드 어텐션은 CNN의 다중 필터와 유사한 역할을 합니다. CNN에서 여러 필터가 가장자리, 질감, 패턴 등 다양한 특징을 감지하듯이, 멀티헤드 어텐션의 각 헤드도 문법, 의미, 위치 등 다양한 관계를 포착합니다.
코드로 이해하기
PyTorch 구현
# PyTorch로 구현한 간소화된 Multi-Head Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
self.num_heads = num_heads
self.d_k = d_model // num_heads # 헤드당 차원 (64)
# Q, K, V를 위한 Linear 변환
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 출력을 위한 Linear 변환
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, d_model = x.shape
# 1. Linear 변환
Q = self.W_q(x) # (batch, seq, d_model)
K = self.W_k(x)
V = self.W_v(x)
# 2. 여러 헤드로 분할
# (batch, seq, d_model) → (batch, num_heads, seq, d_k)
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 3. Scaled Dot-Product Attention (각 헤드에서 병렬 계산)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
attention_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attention_weights, V)
# 4. 헤드 결합 (Concat)
# (batch, num_heads, seq, d_k) → (batch, seq, d_model)
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
# 5. 출력 Linear 변환
output = self.W_o(context)
return output, attention_weights
# 사용 예시
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512) # batch=2, seq_len=10, d_model=512
output, weights = mha(x)
print(f"Output shape: {output.shape}") # (2, 10, 512)
print(f"Attention weights shape: {weights.shape}") # (2, 8, 10, 10)
- view + transpose: 텐서를 헤드별로 분할
- 병렬 계산: 모든 헤드가 한 번에 계산됨
- contiguous: 메모리 레이아웃 최적화
- W_o: 헤드들의 정보를 통합하는 학습 가능한 투영
핵심 요약
👁️ 멀티헤드의 필요성
- 하나의 헤드로는 부족
- 언어의 다양한 측면 학습
- 문법, 의미, 위치 등 다각도 분석
🔀 동작 원리
- 입력을 h개 헤드로 분할
- 각 헤드에서 독립적 어텐션
- 결과 Concat 후 Linear 변환
📊 실제 설계
- 헤드당 64~128 차원 유지
- 모델 크기 ↑ → 헤드 수 ↑
- d_model = h × d_k
다음 강의에서는 포지셔널 인코딩을 배웁니다. 트랜스포머는 순서 정보가 없는데 어떻게 단어의 위치를 알 수 있을까요? 사인/코사인 함수를 사용한 위치 인코딩의 원리를 알아봅니다.