PART 4 · 강의 2/4

멀티헤드 어텐션

여러 개의 "눈"으로 문장을 다각도로 바라보기

01

왜 멀티헤드가 필요할까?

하나의 관점으로는 부족하다

💡 핵심 통찰

하나의 어텐션 헤드는 하나의 관점으로만 문장을 봅니다. 하지만 언어는 다양한 측면(문법, 의미, 위치 관계 등)을 동시에 가지고 있습니다. 멀티헤드 어텐션은 여러 관점을 동시에 학습하여 더 풍부한 표현을 만듭니다.

문장의 다양한 관계들

"The cat sat on the mat because it was tired." 문장에서:

📍
위치 관계

인접한 단어들 간의 관계

"The" → "cat" (바로 다음 단어)
"sat" → "on" (전치사와 동사)
🔗
문법 관계

주어-동사, 동사-목적어 관계

"cat" → "sat" (주어-동사)
"sat" → "mat" (동사-목적어)
🎯
대명사 참조

대명사가 가리키는 대상

"it" → "cat" (대명사 해결)
문장 끝에서 시작까지 연결!
희귀 단어

중요한 키워드에 집중

"tired" → 핵심 의미
"mat" → 장소 특정
🔬 연구 결과

BERT와 GPT 모델을 분석한 연구에 따르면, 각 헤드는 실제로 다른 유형의 관계를 학습합니다:

  • 위치 헤드 (Positional): 인접 토큰에 주목
  • 문법 헤드 (Syntactic): 문법적 관계에 주목
  • 희귀 단어 헤드: 빈도가 낮은 중요 단어에 주목
02

멀티헤드 어텐션의 동작 원리

분할하고 병합하기

멀티헤드 어텐션 아키텍처

입력 (d_model = 512)
Linear 변환 후 8개 헤드로 분할
Head 1 (64d)
Head 2 (64d)
Head 3 (64d)
Head 4 (64d)
Head 5 (64d)
Head 6 (64d)
Head 7 (64d)
Head 8 (64d)
각 헤드에서 독립적으로 Attention 계산
Concat (8 × 64 = 512)
Linear → 출력 (512)

차원 분할 시각화

입력 차원
d_model = 512
512
÷ 8 헤드
헤드별 차원
64
64
64
64
64
64
64
64
64 × 8

멀티헤드 어텐션 수식

MultiHead(Q, K, V) = Concat(head₁, ..., headₕ) · WO
where headᵢ = Attention(Q·WiQ, K·WiK, V·WiV)
h: 헤드 수 WO: 출력 투영 행렬 Wi: 각 헤드의 가중치
03

헤드별 어텐션 패턴

각 헤드가 학습하는 것들

각 헤드를 클릭하면 해당 헤드가 학습한 어텐션 패턴을 볼 수 있습니다:

📍
Head 1
위치 관계
🔗
Head 2
주어-동사
🎯
Head 3
대명사 참조
📝
Head 4
구문 구조
Head 5
핵심 단어
🔄
Head 6
자기 참조
➡️
Head 7
다음 단어
⬅️
Head 8
이전 단어

Head 1: 위치 관계 패턴

인접한 단어들 사이의 관계에 집중합니다. 대부분의 어텐션이 바로 옆 단어에 분배됩니다.

The cat sat on the mat
어텐션 패턴 설명
Head 1은 각 단어가 바로 다음 단어에 높은 어텐션을 부여합니다. "The" → "cat", "cat" → "sat" 처럼 순차적 관계를 학습합니다.
04

실제 모델들의 멀티헤드 설정

모델별 헤드 수 비교

모델 헤드 수 d_model d_k (헤드당) 레이어 수
Transformer (원본) 8 512 64 6
BERT-base 12 768 64 12
BERT-large 16 1024 64 24
GPT-2 12 768 64 12
GPT-3 (175B) 96 12288 128 96
Llama 2 (7B) 32 4096 128 32
Llama 2 (70B) 64 8192 128 80
📌 패턴 분석
  • 헤드당 차원(d_k)은 대체로 64~128로 유지
  • 모델이 커지면 헤드 수를 늘려서 표현력 증가
  • d_model = 헤드 수 × d_k 관계 유지
05

멀티헤드 어텐션의 장점

다양성과 효율성

🎨

다양한 표현 학습

각 헤드가 다른 측면(문법, 의미, 위치)을 학습하여 풍부한 표현 생성

병렬 처리

모든 헤드가 동시에 계산되어 효율적인 GPU 활용

🛡️

강건성

하나의 패턴에 의존하지 않아 과적합 방지

🔧

유연한 설계

헤드 수 조절로 복잡도와 성능 균형 조절 가능

🔬 CNN과의 유사성

멀티헤드 어텐션은 CNN의 다중 필터와 유사한 역할을 합니다. CNN에서 여러 필터가 가장자리, 질감, 패턴 등 다양한 특징을 감지하듯이, 멀티헤드 어텐션의 각 헤드도 문법, 의미, 위치 등 다양한 관계를 포착합니다.

06

코드로 이해하기

PyTorch 구현

# PyTorch로 구현한 간소화된 Multi-Head Attention
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 헤드당 차원 (64)

        # Q, K, V를 위한 Linear 변환
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # 출력을 위한 Linear 변환
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # 1. Linear 변환
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # 2. 여러 헤드로 분할
        # (batch, seq, d_model) → (batch, num_heads, seq, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Scaled Dot-Product Attention (각 헤드에서 병렬 계산)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        # 4. 헤드 결합 (Concat)
        # (batch, num_heads, seq, d_k) → (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        # 5. 출력 Linear 변환
        output = self.W_o(context)

        return output, attention_weights

# 사용 예시
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # batch=2, seq_len=10, d_model=512
output, weights = mha(x)
print(f"Output shape: {output.shape}")  # (2, 10, 512)
print(f"Attention weights shape: {weights.shape}")  # (2, 8, 10, 10)
📝 코드 핵심 포인트
  • view + transpose: 텐서를 헤드별로 분할
  • 병렬 계산: 모든 헤드가 한 번에 계산됨
  • contiguous: 메모리 레이아웃 최적화
  • W_o: 헤드들의 정보를 통합하는 학습 가능한 투영
SUMMARY

핵심 요약

👁️ 멀티헤드의 필요성

  • 하나의 헤드로는 부족
  • 언어의 다양한 측면 학습
  • 문법, 의미, 위치 등 다각도 분석

🔀 동작 원리

  • 입력을 h개 헤드로 분할
  • 각 헤드에서 독립적 어텐션
  • 결과 Concat 후 Linear 변환

📊 실제 설계

  • 헤드당 64~128 차원 유지
  • 모델 크기 ↑ → 헤드 수 ↑
  • d_model = h × d_k
🎓 다음 강의 예고

다음 강의에서는 포지셔널 인코딩을 배웁니다. 트랜스포머는 순서 정보가 없는데 어떻게 단어의 위치를 알 수 있을까요? 사인/코사인 함수를 사용한 위치 인코딩의 원리를 알아봅니다.

REF

참고 자료