brunch

Differenctial Transformer

논문 톺아보기 14

by Qscar

| INTRO |

오늘 톺아볼 논문은 바로 2024년 10월에 발표된 Differential Transformer[1]입니다. 본 논문은 기존 Transformer 구조의 모델이 attention을 이용한 contextual learning을 하는 과정에서 정답과 무관하거나, 불필요한 토큰에도 높은 주의력을 할당하는 주의력 분산 문제(본 논문에선 attention noise라고 정의합니다)를 해결하기 위한 방법으로 Differential Transformer를 제시합니다.


구체적으로 우리가 사용하는 노이즈 캔슬링 헤드셋 등에서 사용되는 차등 소음 제거 방식에서 영감을 받아, attention weight를 구현하는 과정에서 두 개의 attention weight를 구성하고 이를 차등함으로써 정답과 관련성이 높은 토큰에 주의력을 집중시킴으로써 효율적인 모델 구현이 가능해지게 됩니다.


하지만 엄밀히 말해 본 논문의 내용은 우리가 이전에 리뷰했던 Transformer 혹은 Bert 정도의 구조에 단순히 Diff Attention만을 적용하는 것은 아니고, 최근에 나온 LLaMA[2]의 구조에 Diff Attention을 적용한 형태입니다. 때문에 본 포스팅에선 Differential Transformer을 구현하기 위해 Transformer의 Decoder를 독립적으로 사용한 형태에 가까운 GPT2 구조에서부터 LLaMA를 구현하고, 추가적으로 Differential Transformer를 구현한 뒤 각 모델들의 성능을 확인해보도록 하겠습니다.


| MODELING |

INTRO에서 소개했다시피 Diff Transformer는 Transformer 구조가 가진 주의력 분산(Attention Noise)를 해결하기 위한 것입니다. 이는 본 논문[1]에서 제시된 아래 그림으로 요약할 수 있습니다.

Transformer vs Diff Transformer (from [1])

위 그림은 document에 정답이 포함된 문제를 해결하는 과제를 수행할 때, 각 방식에 따른 attention score가 어떻게 할당되는지 모든 head들의 attention score를 평균낸 것입니다. 위 그림에서 볼 수 있듯, 일반적인 transformer는 정답에 0.03의 attention score를 할당하면서 동시에 과반에 이르는 attention score를 무의미한 토큰들에 할당하고 있는 것을 확인할 수 있습니다. 토큰별 스코어 비교에서 미약한 차이로 정답에 해당하는 토큰을 맞추는 식으로 진행되던 것이 기존 우리가 사용하던 transformer 방식이라는 것입니다.


반면 위 그림의 오른쪽에 해당하는 Differential Transformer의 경우, 무의미한 맥락(토큰들)에 아주 적은 attention score만을 할당함으로써 무의미한 집중력 분산을 방지하고 있습니다. 이러한 특징으로 인해 단일 쿼리로부터 여러 개의 항목을 동시에 검색하는 Multi-Needle Retrieval과 같은 task에선 기존 transformer 기반 모델이 55%의 정확도를 달성한 것에 비해, Differential Transformer는 85%의 정확도를 달성하고 있는 것을 확인할 수 있습니다.


또한 이러한 복잡한 문제를 푸는 것 외에도 데이터 자체를 효율적으로 활용하고, 불필요한 주의력을 최소화하기 때문에 65% 수준의 파라미터 혹은 더 적은 데이터셋만으로도 유사한 성능을 보인다고 합니다. 이를 살펴보기 위해선 우선 기존 Transformer의 수도 코드를 먼저 살펴본 뒤, Diff Transformer를 확인해보도록 하겠습니다.



|STEP 01| Pseudo Code

먼저 기존에 우리가 알고 있던 Transformer의 수도 코드입니다. 이 부분의 코드는 여기에서 확인할 수 있습니다.

transformer 수도 코드 (from [1])

Transformer에서는 Query, Key를 행렬곱한 뒤, root(dimension) 값으로 scaling 및 Softmax를 수행한 뒤, Value와의 행렬곱을 통해 attention score를 계산했었습니다. 이후 multi-head attention 및 feed forward network 등과 결합되며 비로소 하나의 transformer block이 완성되는 식이었습니다. 그렇다면 Diff Transformer는 무엇이 바뀌었을까요?


Diff Transformer의 수도 코드입니다.

Diff Transformer 수도 코드 (from [1])

Diff Transformer에서는 Query와 Key를 두 개의 attention map으로 분리하고, 첫 번째 값에 조정된(learnable parameter λ에 의해) 두 번째 값을 뺀 값을 Value와 행렬 곱함으로써 차등 주의 방법(Differential Transformer)를 구현합니다.


이를 실제 코드로 구현하면 다음과 같습니다. 먼저 일반적인 attention(이하 'classic attention'이라 임의 명명하겠습니다)입니다.

Standard Attention Code

다음으로는 diff attention입니다.

download.png Diff Attention Code

위 코드에 등장하는 하이퍼 파라미터에 대해 간략이 설명하면, b는 batch size, n은 토큰의 길이(이미지의 경우 1차원으로 flatten한 패치의 수), d는 모델의 dim을 의미합니다. [B, N, D₁] 차원으로 동일한 Query와 Key 간의 행렬곱을 통해 [B, N, N] 형태의 attention weight를 도출하고, 이후 [B, N, D₁] 차원의 Value와 행렬곱을 통해 다시 [B, N, D₁]의 형태가 되는 식입니다.


Diff Transformer에서는 [B, N, 2D₂]의 Query, Key, Value를 이용합니다. Query와 Key를 두 개로 나눠(Split) [B, N, D₂] 형상의 Q₁, Q₂ 및 K₁, K₂로 만들고, Q₁과 K₁, Q₂와 K₂로 만들어 두 개의 Attention Score를 만들어냅니다. 이를 각각 A₁, A₂라고 하며 첫 번째 score에서 두 번째 score를 뺌으로써 차등적 주의력을 구현하는 방식입니다.


이러한 구현 과정에서 learnable parameter인 λ를 이용하고, 이후 Transformer Block을 구현하는 과정에서 LLaMA[2]에서 제시한 pre-RMSNorm과 SwiGLU, 그리고 RoPE 등을 채택하기도 합니다. 이러한 내용은 이후의 내용에서 보다 상세하게 다루도록 하겠습니다.


|STEP 02| Transformer vs LLaMA

download.png
download.png
Transformer vs LLaMA (from [6])

위 구조도를 통해 알 수 있듯, Transformer에서 LLaMA로 현대화되는 과정에서 적용된 몇 가지 개선사항들이 있습니다. 우선 기존에는 embedding된 입력 x에 대해 position encoding을 적용했던 것에 비해 LLaMA에서는 Query, Key에 대해서만 RoPE라는 위치 인코딩 방식을 적용합니다. 이를 적용하는 것만으로도 모델의 성능이 좋아진다는 내용의 Roformer 모델 및 논문[3]에 대해서도 간단히 다루도록 하겠습니다.


다음으로 Attention 이후 진행되던 정규화(normalize) 과정이 attention 이전으로 옮겨져 pre-normalization이 되는 식으로 정규화가 적용되는 순서가 바뀌었고, 적용되는 정규화 방식도 LayerNorm에서 RMSNorm으로 바뀌었습니다. (이러한 변경을 통해 특히 입력 토큰의 수가 커질수록, 모델의 규모가 커질수록 늘어나는 연산량을 약간 줄일 수 있다는 소소한 이점을 챙길 수 있다고 합니다.)


Attention 방법은 일반적인 내적 연산에서 KV cache를 이용한 Grouped Attention으로 변경되었고, Feed Forward Network 내의 활성층도 SwiGLU를 이용한 방식으로 변경되었습니다.


우선 이렇게 LLaMA로 변경된 차이에 대해서 직접 구현한 후, 여기에 Diff Attention을 적용함으로써 본 포스팅의 모델링을 마치는 식으로 진행하도록 하겠습니다.


|STEP 03| 6 Differences

기본적으로 Diff Transformer의 전체적인 구조(layout)은 기존 Transformer의 Decoder 구조만을 따온 GPT-2와 6가지를 제외하곤 동일합니다. 6가지 중 3가지 다른 점은 아래의 본 논문의 색칠된 부분을 통해 확인할 수 있습니다.

Diff Transformer layout (from [1])

첫 번째로 다른 점은 우리가 위의 수도 코드를 통해 살펴본 바와 같이 attention noise cancling을 위한 부분으로, 본 논문[1]에선 'differential attention을 conventional softmax attention으로 대체한다'라고 하고 있습니다. 참고로 여기서 conventional softmax attention이란 우리가 기존 transformer에서 구현했던 Query와 Key의 행렬곱 연산 후 d차원의 루트값으로 나눠줌으로써 scaling하고, 다시 Value와의 합성곱을 통해 attention score를 반환하는 과정을 지칭합니다.


두 번째와 세 번째로 다른 점은 'pre-RMSNorm''SwiGLU'를 적용한다는 점인데요, 이는 LLaMA 논문에 서술된 현대적 모델링 기법을 그대로 적용한 방식입니다. 또한 위 내용에선 언급하고 있지 않지만, 일반 Transformer와의 또 다른 네 번째 다른 점은 토큰의 위치에 대한 정보를 주는 방식에 있습니다. 기존 Transformer 논문에서는 ablation study를 통해 위치 정보를 주지 않는 경우와 position encoding을 적용하는 경우, 그리고 position embedding을 적용하는 경우로 나눠 성능 측정을 함으로써 최종적으로 position encoding을 채택하는 식으로 이뤄졌습니다.


하지만 본 논문에서는 RoPE(Rotary Position Encoding) 기법[3]을 적용하며, 이는 점차 긴 토큰을 효율적으로 다루기 위한 현대적 position embedding 기법 중 하나입니다. 이들에 대해 하나씩 구현하며 진행해보도록 하겠습니다.


| Difference #1 - PE(Position encoding/embedding) |

RoPE에 대해 다루기 전에 간단히 Position Encoding/Embedding의 변천사에 대해 알아보면 좋습니다. 본래 이전의 Transformer 논문에서 제시한 방식은 주기함수를 이용해 절대적 위치를 인코딩해 attention score에 더해주는 방식이었습니다. 이 부분의 코드는 여기에서 확인할 수 있습니다.


1. Absolute Position Encoding

아주 간단하게 모델의 dimension이 1일 때로 살펴보면, "I am a robot"은 다음과 같이 token으로 인코딩할 수 있습니다.

absolute position encoding

이전에 transformer를 리뷰할 때에는 absolute position encoding 대신 embedding 방식을 사용했지만 적용 원리는 위와 같습니다. 입력 시퀀스 x에 포지셔널 인코딩(혹은 임베딩) 값을 더해준 후 멀티헤드어텐션 레이어를 통과시키는 식이었습니다. 또한 그 과정에서 모델의 dimension이 늘어날수록 행렬곱 연산 과정에서 일부 토큰들에게 지나치게 적은 내적 연산 결과가 도출되고, 이러한 값들에 대해 PE 값이 오히려 더 큰 영향을 끼치는 문제 등을 방지하기 위해 모델 차원의 루트값으로 나눠주는 scaling을 처리하기도 했었습니다.


이에 대한 수식을 정리하면 아래와 같습니다. 우선 일반적인 attention의 개념입니다.

download.png
download.png
일반적인 attention 과정 (from [3])

위 왼쪽 수식에서 q, k, v를 정의하는데 사용되는 fucntion은 각 입력 토큰에 위치정보를 추가해주는 함수라고 이해하면 쉽습니다. 즉 이미 위치정보가 포함된 input x를 가지고 있으며, 이를 위 오른쪽 그림과 같이 scaled dot attention 및 softmax를 적용해 최종적으로 Om을 출력하게 되는 식입니다.


이때 absolute position encoding을 적용할 경우 수식은 다음과 같이 나타낼 수 있습니다.

absolute position encoding equation (from [3])

위 수식을 간단히 정리하면, 굳이 번거롭게 q, k, v에 각각 인코딩을 적용하지 말고 q, k, v의 공통적인 입력에 들어가는 입력 시퀀스 x에 pe 값을 더해서 주자는 식으로 해석할 수 있습니다.


2. Relative Position Embedding

각 토큰의 절대적 위치를 주는 방법 다음으로 제시된 것은 상대적 위치 정보를 기준으로 모델을 학습시키는 방법으로 대표적인 모델로 T5가 존재합니다. 이 접근법은 query와 key의 내적 연산식을 전개하는 것으로 부터 시작합니다. 우선 query @ key 연산을 전개하면 다음과 같습니다.

query key 내적 연산 (from [3])

이렇게 전개된 식에서 상대적 위치 임베딩을 적용하기 위해 두 가지의 변환이 적용됩니다. 우선 첫 번째로 두 번째와 네 번째 수식의 pn값을 상대적 위치 인코딩 값으로 대체해줍니다.

그리고 세 번째와 네 번째 식에 존재하는 pm 값을 독립적으로 학습가능한 u와 v로 대체함으로써 상대적 위치 정보를 줌과 동시에 embedding 값으로 신경망 자체의 보정을 적용하는 식입니다. 이때 Pm-n은 m 위치의 토큰 xm과 n 위치의 토큰 xn 간의 상대적 위치 차이를 의미하는 것으로, 이를 통해 너무 먼 위치에 있는 토큰의 영향력을 최소화하기 위함입니다.


물론 단순히 이러한 방법뿐 아니라 위 수식을 간소화한 아래 수식도 있으며,

간소화된 (6) 수식, 위의 b는 학습가능한 bias

별도의 학습가능한 u, v 대신 절대적 위치 인코딩값을 주거나, 애초에 이런 것도 필요없이 상대적 위치정보만으로 위치 임베딩을 주는 방식도 제시하고 있습니다. 하지만 이러한 상대적 위치 정보만을 이용하는 방식은 각 토큰 간의 상대적 위치 정보를 연산해야 하는 연산 복잡도가 추가되는 문제로 인해, 보다 긴 문장의 문맥을 파악하기 위해 도입된 방법임에도 불구하고 긴 문장일수록 정확도와 같은 성능 측면에서는 나아질지 몰라도 모델의 속도 측면에서는 불리해지는 문제를 야기했습니다. 이로 인해 상대적 위치 정보가 가진 장점을 유지하면서 동시에 단점을 보완하는 방법이 필요해지게 되었는데, 그것이 바로 RoPE입니다.


3. RoPE(Rotary Position Embedding)

이전까지의 내용을 다시 정리해봅시다. Transformer가 제안된 이후 그 구성요소들도 다양한 실험과 이론을 기반으로 현대화되어 가며 성능을 높여갔습니다. 절대적 위치 인코딩이 일정 길이 이상의 시퀀스를 처리하는데 효과가 떨어지거나, 아예 없기도 하며, 종종 오히려 문제를 일으키는 경우가 발생했습니다. 이에 대한 대안으로 초기에는 신경망에 위치정보를 맡기는 단순 임베딩 방식이 사용되기도 했지만, 이는 토큰 간 위치에 대한 명확한 정보를 주기 어려웠습니다.


이로 인해 다음으로 제시된 것이 토큰 간의 상대적 위치를 처리하는 방법이었고, 이는 모델의 성능 자체를 높였으나 별도로 상대적 거리를 모델링해야 해 attention 계산에 추가적인 복잡성을 야기했습니다. 이를 그림으로 확인하면 다음과 같습니다.

위치 임베딩 방식에 따른 학습/추론 효율 (from [4])

위 논문은 위치 임베딩을 Linear Bias로 제시함으로써 효율을 증대시키는 방식에 대한 논문에서 제시된 그림입니다. 위 그림에서 확인할 수 있듯이 절대적 위치와 선형 편향을 적용하는 방식이 속도와 메모리 측면을 모두 고려했을 때 가장 효율적입니다. 그리고 가장 최악인 것은 상대적 위치를 적용하는 것인데요, 속도 외의 성능적인 측면 때문에 상대적 위치 임베딩이 적용됐음을 고려해도 지나치게 낮은 성능이며, 상대적 위치 임베딩이 제시된 이유가 보다 더 긴 시퀀스를 다루기 위함임을 고려하면 본래의 목적성을 상실한 수준이라 봐도 무방합니다.


때문에 이러한 문제를 해결하고자 제시된 것이 바로 RoPE(Rotary Position Embedding)입니다. RoPE는 1)회전을 통해 위치 정보를 인코딩함으로써 절대적 위치 정보와 상대적 관계를 동시에 반영할 수 있으며, 2)회전 변환이라는 특성이 유사도 측정을 위한 내적(dot-product) 연산 과정에서 자연스럽게 위치 정보를 학습하기 용이하고, 3)상대적 위치 정보를 일일히 계산해야 했던 그 전에 비해 각 토큰의 위치에 대해 단일 회전 변환만 적용하면 되기 때문에 연산이 효율적이라는 장점이 있습니다.


그렇다면 RoPE는 구체적으로 어떻게 구현될 수 있을까요? RoPE가 적용되는 query, key 내적 연산을 다음과 같이 표현할 수 있습니다.

download.png RoPE가 적용된 query@key (from [3])

위 수식에서 위치 임베딩을 의미하는 수식 R은 다음과 같은 회전 행렬로 정의됩니다.

download.png RoPE matrix (from [3])

그리고 위 행렬을 도식화해 아래와 같이 직관적으로 나타낼 수 있습니다.

download.png RoPE equation (from [3])

이를 좀 더 직관적으로 나타내면 원본 벡터 x, 위치 m, 위치에 따른 회전 각도를 (θm, k)라고 할 때, ROPE 수식을 다음과 같이 삼각함수를 통해 나타낼 수 있습니다.

download.png ROPE를 이용한 x의 위치 인코딩 수식

하지만 이러한 수식을 보면 대충 무슨 뜻인지는 알 거 같은데, 구체적으로 뭘 의미하는 지, 어떻게 이런 수식이 도출되었는지 이해하기는 쉽지 않습니다. 때문에 RoPE가 고려된 query, key 내적 연산을 실제로 구현해보고, 그 결과를 먼저 살펴보겠습니다. 우선 RoPE가 적용된 내적 연산 수식은 다음과 같습니다.

query, key 내적 연산 과정에서 적용되는 위치 임베딩이 고려된 수식 (from [3])

위 수식에서 왼쪽의 cos과 sin으로 이뤄진 부분이 m번째 토큰에 대한 위치를 단순하게 인코딩하는 함수입니다. 해당 인코딩 수식에서 m은 토큰의 위치(몇 번째 토큰인지)이고, θ는 RoPE를 구현하기 위해 입력되는 시퀀스의 최대 길이와 모델의 dimension 등을 고려한 값입니다. 우선 이러한 인코딩 수식을 구현해 어떤 식으로 토큰의 순서가 인코딩되는지 확인해보겠습니다. 참고로 여기서 θ는 임의로 15도로 구현해 가시화해보겠습니다. 코드입니다.

encoding function

이러한 함수를 통해 단순히 입력된 시퀀스 내 특정 토큰의 순서를 단순히 0, 1, 2, ..., n으로 표현하는 것이 아닌 다음과 같이 2차원 벡터로 표현할 수 있습니다.

1dc6cc96-a724-4c75-bc3f-8515b14424f1.png

위 인코딩 함수에 입력 시퀀스의 임베딩 벡터를 고려하게 되면 다음과 같이 rope 인코딩의 형태를 확인할 수 있습니다. 이러한 형태를 통해 각 토큰 간의 관계를 두 위치 임베딩 값의 사이각으로 확인할 수 있으며, 이러한 벡터들 간의 내적 연산을 통해 자연스럽게 상대적, 절대적 위치 정보를 이해할 수 있도록 유도하는 형태인 것입니다.


0a38a9e6-97e0-4c6a-9900-be03905c9874.png

위와 같은 위치 인코딩이 고려된 임베딩 벡터는 모델의 차원 수 만큼이 됩니다. 보다 구체적인 단어와 예시를 통해 살펴보겠습니다. 각 토큰의 임베딩을 고려한 위치 인코딩을 적용한 후, 그 중 하나의 dimension을 살펴보면 다음과 같이 각 단어 간의 상대적 위치와 절대적 위치를 파악할 수 있습니다.

download.png RoPE의 원리 (from [5])

위 그림의 두 문장에서 pig와 dog의 절대적 위치는 다르기에, 오른쪽의 그림에서 표시된 회전의 정도는 다르지만, 두 단어의 거리는 3으로 동일하기에 두 단어 간의 사이각(세타, θ)는 동일함을 알 수 있습니다. 위에서 우리가 임의로 정한 단어 순서 간의 각도인 15도로 생각하면 위 예시에선 pig와 dog의 거리가 3이기 때문에 모델 dimension에 따라 임베딩된 값이 달라지더라도 두 단어 간의 차이는 45도로 일정할 것입니다. 즉, 각 단어들의 위치를 두 개의 벡터를 기반으로 회전시키는 일정한 규칙(간격, θ)이 있다면 이를 통해 각 단어들의 절대적 위치(vector)와 상대적 위치(θ의 차이)를 동시에 표현할 수 있음을 의미합니다.


여기까지의 내용을 통해 우리는 도대체 무슨 원리인지는 모르겠지만 저 삼각함수를 이용한 위치 임베딩 방식이 각 토큰의 임베딩 벡터를 절대 위치 정보와 상대 위치 정보를 포함한 벡터로 변환할 수 있음을 확인하였고, 제시된 RoPE의 수식에 의해 임베딩 벡터가 회전한다는 것을 확인할 수 있었습니다. 그렇다면 어떻게 이런 일이 가능했는지 내부적으로 살펴볼 차례입니다.


위 방법을 수식으로 다시 한 번 살펴볼까요? 우선 query, key의 내적 연산과 그 과정에서의 포지셔널 인코딩이 고려되는 내용을 포함한 수식은 다음과 같습니다.

RoPE equation (from [3])

위 수식의 첫 번째와 두 번째, 그러니까 query와 key에 대한 임베딩 및 위치 임베딩을 포함하는 함수는 e^imθ와 같이 표현됩니다. 이는 2차원 벡터 [x₁ x₂]를 [x₁+ix₂]와 같은 복소수로 표현한 후, 위치 p에 대해 회전시키면 허수 i의 특성인 i²=-1의 특징을 이용해 실수부와 허수부로 분리시킬 수 있으며, 이는 2x2 회전 행렬과 동일한 역할을 하게 됩니다. 즉, 우리가 원하는 회전 행렬을 자연스럽게 적용한 수식입니다. 이를 보다 자세히 알아봅시다.


간단하게 model의 dimension = 1인 경우로 살펴보겠습니다. 첫 번째 토큰은 모델의 차원에 따라 다음과 같이 embedding 및 reshape될 수 있습니다.

download.png ① 토큰의 embedding과 reshape 과정

이렇게 reshape된 벡터는 복소수를 이용해 다음과 같이 표현될 수 있습니다.

download.png ② 복소수를 이용한 표현

이렇게 임베딩된 벡터를 복소수를 이용한 표현으로 변환했다면, 여기에 위치 인코딩 함수를 적용합니다. 이때 각 주기 함수를 fn으로 표현할 수 있습니다.

download.png ③ 위치 인코딩 적용

이렇게 위치 인코딩 값이 적용된 임베딩 값을 실수부와 허수부로 분리하면 다음과 같습니다.

④ 실수부와 허수부로 분리

이 수식을 flatten해 (1, 4) 형태의 행렬로 변환할 수 있습니다.

download.png ⑤ flatten

이제 거의 끝났습니다. 위 수식에서 두 번째와 네 번째 수식을 단순히 순서만 바꿔주겠습니다.

download.png ⑥ rewrite

이렇게 도출된 행렬을 분배법칙에 근거해 분리해주고, fn으로 표현했던 위치 인코딩 함수를 다시 본래의 값으로 복원시켜 작성하면 다음과 같습니다.

download.png ⑦ split

그리고 이렇게 정리한 결과는 위에서 살펴본 수식과 일치하게 됩니다.

RoPE euqation (from [3])

즉, 여기까지의 과정을 통해 토큰의 위치 정보를 2차원 벡터로 표현하고, 각 단어 간의 거리를 각도로 측정할 수 있게 됨으로써 각 토큰의 절대적 위치 정보와 상대적 위치 정보를 모두 부여하는게 가능해졌습니다. 이제 RoPE에 대한 코드를 구현해보겠습니다.


우선 위치 인코딩을 하는 함수(fn)입니다. 이를 구현하는 과정에서 RoPE 논문에서는 위치 인코딩 함수에 적용되는 사이각인 θ를 다음과 같이 사전 정의하고 있습니다.

download.png θ에 대한 정의(from [3])

이를 고려해 코드를 작성하면 다음과 같습니다.

position encoding function for RoPE

위 코드에서 torch.polar 함수를 통해 편리하게 복소수와 주기함수를 이용한 위치 변환을 적용할 수 있습니다. 이를 통해 위치 인코딩된 값을 살펴보면 다음과 같습니다.

24개의 토큰에 대한 위치 인코딩

직관적으로 확인하기 위해 이를 그래프로 표현하면 다음과 같습니다.

download.png

아직은 임베딩된 토큰값이 곱해지지 않은 상태임을 확인할 수 있습니다. 이후엔 이 함수를 이용해 입력된 임베딩 벡터를 위치 임베딩이 적용된 값으로 변환하는 함수를 작성해보겠습니다. 이때에는 실제 Transformer 모델 구현 구조에 따라 multi head attention을 고려해야 합니다.

MSA이 고려된 RoPE 코드

이렇게 RoPE의 필요성부터 수식, 실제 코드에 이르기까지의 과정을 살펴보았습니다.


| Difference #2 - RMSNorm |

PE 다음으로 살펴볼 것은 바로 새로운 정규화 방식인 RMSNorm입니다. 이 부분의 코드는 여기에서 확인할 수 있습니다. 이는 긴 sequence를 한 번에 처리할 때 기존의 LayerNorm이 가진 연산비용을 약간 줄여주기 위해 적용됐습니다. 무슨 의미냐하면, LayerNorm의 수식은 다음과 같습니다.

LayerNorm의 수식

LayerNorm을 구현하기 위해서는 입력 벡터의 배치 전체에 대한 평균과 분산을 계산해야 합니다. 따라서 입력 벡터 x의 배치 크기나 차원이 커질수록 연산 속도가 느려지는 단점이 있습니다. 이러한 복잡도를 줄이기 위해 제안된 방법이 바로 RMSNorm(Root Mean Squared Norm)입니다. RMSNorm은 이름에서 나타나듯 입력 벡터의 각 원소를 제곱한 후 평균을 내고, 그 값에 제곱근을 취해 정규화를 수행하는 일종의 스케일링입니다. 따라서 입력 벡터 x의 크기가 커지더라도 평균이나 분산을 별도로 계산할 필요가 없어 연산이 더욱 효율적입니다. 이에 대한 수식은 다음과 같습니다.

RMSNorm 수식

물론 이러한 RMSNorm에도 단점이 있습니다. 바로 전체적인 평균과 분산을 통해 특정 배치로 구성된 데이터 내부에 존재할 일부 노이즈 데이터로 인해 편향이 발생할 수 있다는 것입니다. 이러한 이유로 이전에는 LayerNorm이 사용되었지만, LLaMA[2] 논문에서는 이러한 영향이 대규모 데이터셋과 모델을 통해 필터링되며 결과적으로 미미한 영향을 끼치거나 영향이 없는 것으로 관측된다고 하였으며, 이를 위해 FFN의 잔차합을 더하기가 아닌 곱하기로 변환해 적용합니다. 이에 대해서는 아래의 SwiGLU를 통한 FFN을 다룰 때 다시 확인하도록 하겠습니다.


(참고로 위 수식에서 γ는 학습가능한 파라미터이고, β는 학습가능한 편향(bias), ε은 0으로 나눠지는 경우를 방지하기 위해 추가해주는 아주 작은 수입니다.)


이러한 RMSNorm을 코드로 구현하면 다음과 같습니다.

RMS Norm 코드


| Difference #3 - KV Cache |

LLM을 비롯한 대부분의 transformer 모델은 자기회귀적인 형태로 작동합니다. 이게 무슨 의미냐 하면, 만약 '나는 사과를 좋아해'라는 원문을 'I like an apple'로 번역해야 한다면, 최근의 단일 블록으로 구성된 구조에서는 위 두 문장을 한 번에 입력으로 넣어 학습시키게 됩니다. 다만 입력 토큰과 출력 토큰을 구분하기 위한 구분자로 ' => '와 같은 문자를 넣을 수 있죠. 때문에 모델의 입력으로 들어가는 문자열은 다음과 같게 됩니다.


'나는 사과를 좋아해 => I like an apple'


그리고 모델은 위 입력 텍스트를 토큰화하고, 추가적으로 특별한 토큰을 추가해 받아들여 이해하는 과정을 거치며, 소위 자기회귀적인 방법을 통해 위 토큰을 이해하게 됩니다. 이를 간략해 요약하면 모델이 각 단어를 자기회귀적인 방법으로 생성되는 과정은 다음과 같습니다.


[STEP 1] 나는 나는 사과를

[STEP 2] 나는 사과를 나는 사과를 좋아해

[STEP 3] 나는 사과를 좋아해 나는 사과를 좋아해 =>

[STEP 4] 나는 사과를 좋아해 =>나는 사과를 좋아해 => I

[STEP 5] 나는 사과를 좋아해 => I 나는 사과를 좋아해 => I like

[STEP 6] 나는 사과를 좋아해 => I like나는 사과를 좋아해 => I like an

[STEP 7] 나는 사과를 좋아해 => I like an나는 사과를 좋아해 => I like an apple


위 학습 과정을 보면 현재 스탭의 출력 결과가 다음 스탭의 입력으로 돌아가는 것을 확인할 수 있습니다. 즉 모델에 입력된 결과를 전처리(마스킹)하여, 자체적으로 문맥과 자신의 역할을 이해하도록 하는 구조가 바로 자기회귀적인 구조인 셈입니다.


이러한 학습 과정을 통해 모델은 한국어와 영어, 두 언어 간의 관계, 번역이라는 작업에 대한 이해를 진행하게 되며, 추후 번역기로 사용할 때에는 위 과정에서 [STEP 4]에 해당하는 작업부터 진행하게 됩니다.


하지만 추론이 아닌 학습 과정에서 실제 입력되는 시퀀스는 어떻게 처리될까요? 실제로는 다음과 같습니다.


[STEP 4] 나는 사과를 좋아해 => [Mask] [Mask] [Mask] [Mask]

나는 사과를 좋아해 => I [Mask] [Mask] [Mask]

[STEP 5] 나는 사과를 좋아해 => I [Mask] [Mask] [Mask]

나는 사과를 좋아해 => I like [Mask] [Mask]

[STEP 6] 나는 사과를 좋아해 => I like [Mask] [Mask]

나는 사과를 좋아해 => I like an [Mask]

[STEP 7] 나는 사과를 좋아해 => I like an [Mask]

나는 사과를 좋아해 => I like an apple


모델의 각 입력과 출력의 최대 길이는 늘 같게 유지됩니다. 이를 사후적인 처리 등으로 최종 출력에서 제외할 수 있을뿐입니다. 이러한 문제로 인해 입출력의 시퀀스 길이가 길어질수록 자기회귀적인 구조의 모델은 의미없는 [Mask]를 고려한 합성곱 연산을 수행해야 한다는 것입니다. 이를 그림으로 살펴보면 다음과 같습니다.


attention 연산에 대한 그림 ① (from [6])

위 그림은 시퀀스 길이 9, 모델의 dim이 4096인 경우를 도식화한 것입니다. 만약 우리가 번역하고자 하는 토큰이 앞의 4개, 이후의 번역된 결과가 5개라고 할 때, 자기회귀적 구조를 반복하며 4번째로 연산된 결과물은 다음과 같을 것입니다.


attention 연산에 대한 그림 ② (from [6])

위 그림의 query@key 연산 결과인 16개의 결과물 중 mask가 적용되는 6개의 연산은 결국 이후의 연산에서 고려되지 않게 됩니다. 당연히 이러한 불필요한 값을 연산하는 것과 이후의 연산에 고려되는 것 모두 불필요한 연산이 되며, 이러한 문제는 시퀀스의 길이가 길어질수록 quadratic하게 상승하게 됩니다(불필요한 연산량 = 2*(seq_len ^ 2)/2 - seq_len/2)).


이러한 문제를 해결하기 위해 제시된 방법이 KV-Cache이며, 이를 그림으로 살펴보면 다음과 같습니다.

KV-cache에 대한 그림 (from [6])

위 연산의 요지는 굳이 불필요한 연산 자체를 하지 않고 스킵하는 것입니다. Query의 토큰을 하나씩 처리하며, Key와 Value에서는 이를 고려해 현재 query token의 인덱스 이하의 시퀀스에 대해서만 연산을 처리하는 식입니다. 이러한 방법을 통해 cache를 보유함에 따른 VRAM 메모리 부담은 다소 늘어나지만, 추론 속도의 향상을 꾀할 수 있습니다. 한 마디로, 트랜스포머 구조의 가장 큰 문제인 seqeunce의 길이가 늘어남에 따라 연산량이 quadratic하게 증가하는 문제를 어느정도 해소할 수 있습니다.


이를 구현하기 위해선 Q, K, V에 대해 각각의 토큰을 sequential하게 처리합니다. Sequential하게 처리하는 이유는 자기회귀적으로 next token prediction을 진행하기 위해선 다음 토큰을 예측하고, 예측한 토큰을 포함해 다시 입력으로 넣는 과정을 반복하는 과정을 구현하기 위한 것입니다.


즉, 학습 과정에선 다음 토큰을 이미 알고 있으니 당연히 matrix 연산을 통해 진행하는게 효율적이나, 이는 다음 토큰을 미리 알고 있지 못한 가정/환경에서 이뤄지는 추론/서비스 과정에선 사용할 수 없는 방법입니다. 때문에 Sequential한 구조를 이용하면서 불필요한 연산을 최소화하기 위해 이전 K, V를 기억해뒀다가 쓰는 방식이 적용되는 것이 효율적입니다. 이에 대한 실험 결과는 다음과 같습니다.

KV-Cache가 적용된 attention code


이를 간단히 코딩해 연산 복잡도를 측정한 결과는 다음과 같습니다. 연산복잡도는 직관적으로 연산에 들어간 시간으로 측정하였습니다.

download.png 입력 시퀀스 길이에 따른 attention 연산 복잡도 비교

위 그림을 통해 볼 수 있다시피, 시퀀스 길이가 길어질수록 일반적인 attention 연산이 sequence 길이가 길어질수록 quadratic하게 연산 복잡도가 늘어나는 반면, KV-Cache를 이용할 경우에는 이를 상쇄할 수 있는 것을 확인할 수 있습니다. 위 실험 결과의 코드는 여기서 확인할 수 있습니다.


| Difference #4 - GQA(Grouped Query Attention) |

KV cache 외에도 attention 자체의 연산 방법을 개선시키기 위한 전략이 사용되는데요. 기존에 Transformer에 사용되던 방식을 일반적으로 Batched Multi-Head Attention이라고 부릅니다. 배치 단위로 처리함으로써 효율성을 높이고, Multi-Head를 적용함으로써 다양한 관점에서 접근할 수 있도록 유도한 결과물입니다. 하지만 이러한 방법론은 여전히 seqence의 길이가 길어질수록 quadratic하게 연산량이 증가하는 문제를 해결하지는 못했는데요, 이를 개선하기 위해 새로운 어텐션 연산 방식들이 제안됩니다.


overview of various attention method (from [7])


기존에 우리가 사용하는 방식은 위 그림의 왼쪽 방식이었습니다. 즉, 동일한 shape의 Q, K, V를 사용하여 각 헤드별로 연산을 수행하는 방식입니다. 하지만 이 방식은 모든 head에 대해 동일한 크기의 key와 value를 반복적으로 계산하게 되어 불필요한 연산이 발생할 수 있어 비효율적입니다.


이에 대한 대안으로 제시된 방식이 위 그림의 오른쪽에 해당하는 Multi-Query Attention 방식입니다. 이 방식은 query에 대해서만 각 head별로 분리하여 계산하고, key와 value는 단일 헤드의 형태로 관리됩니다. 이후 attention 연산 시, key와 value는 명시적 복제를 하지 않고 내부적으로 broadcast를 통해 query와 동일한 shape로 확장되어 연산됩니다.


이러한 연산 과정을 shape으로 나타내면 다음과 같습니다.


1. Q, K, V 정의

Q = (B, n_heads, seq_len, head_dim)
K = V = (B, 1, seq_len, head_dim)

2. K, V broadcast - (연산시 자동으로 broadcast 됨)

K.repeat(n_heads) -> (B, n_heads, seq_len, head_dim)
V.repeat(n_heads) -> (B, n_heads, seq_len, head_dim)

3. Attention

Score: Q@K.T -> (B, n_heads, seq_len, seq_len)
Out: score @ V -> (B, n_heads, seq_len, head_dim)


비록 동일한 key와 value가 broadcast를 통해 연산되더라도, query의 multi-head가 가진 다양한 관점에 의해 투영되는 것이 같아도 출력(out)의 다양성을 어느정도 유지할 수 있다는 가정에 기반한 전략인 셈입니다. 하지만 이전의 Transformer 리뷰에서 보았다시피 모든 head가 각기 다른 관점을 취하는 것은 아니며, 어느정도 유사한 관점을 취하는 이들도 꽤나 많습니다. 즉, 모든 헤드에 걸쳐 동일한 KV를 투영하는 것은 다양성의 한계가 있다는 의미이며, 결과적으로 성능 자체가 떨어지는 문제를 야기했습니다.


이러한 한계를 보완하고자 제안된 것이 Grouped Multi-Query Attention(GQA)입니다. 이 방식은 모든 query의 head에 대해 동일한 key와 value를 공유하는 대신, 여러 head를 그룹으로 묶어 각 그룹 내에서는 key와 value를 공유하지만, 그룹 간에는 서로 다른 값을 사용함으로써 메모리 사용량과 연산 비용은 줄이면서도 어느 정도의 표현력 차별화를 유지할 수 있도록 설계되었습니다.


이러한 연산 과정을 마찬가지로 나타내면 다음과 같습니다.


1. Q, K, V 정의 - (2개의 Query Head씩 묶는다고 가정 n_heads/2)

Q = (B, n_heads, seq_len, head_dim)
K = V = (B, n_heads/2, seq_len, head_dim)

2. K, V broadcast - (연산시 자동으로 broadcast 됨)

K.repeat(2) -> (B, n_heads, seq_len, head_dim)
V.repeat(2) -> (B, n_heads, seq_len, head_dim)

3. Attention

Score: Q@K.T -> (B, n_heads, seq_len, seq_len)
Out: score @ V -> (B, n_heads, seq_len, head_dim)


이를 통해 key와 value를 작게 유지하면서 연산이 필요한 경우, broadcast를 통해 형상을 맞추는 식으로 적용함으로써 attention 연산 과정을 최적화할 수 있다는 것입니다. 이를 numpy를 이용한 간단한 코드로 살펴보면 다음과 같습니다.

GQA with numpy

참고로 본 논문에서는 이러한 KV cache와 GQA에 대한 언급은 없었는데요, 다만 official code를 살펴봤을 때 GQA는 적용돼있는 형태로 구현된 것을 확인할 수 있었습니다. 또한 우리가 테스트할 모델은 파라미터가 작고, 입출력 시퀀스의 길이도 짧은 편이라 KV cache로 인한 효율성을 체크하기엔 무리가 있어 해당 부분을 제외하였습니다.


torch로 작성한 attention 코드는 다음과 같습니다. (KV cache는 코드의 가독성을 위해 제외하였고, 최종 코드에서만 구현하였습니다만 별도의 generate 함수를 구현하지 않으면 사용되지는 않습니다)

Grouped Multi-Head Self Attention w/o KV cache

위 전체 코드에서 확인할 수 있다시피 query의 shape가 (Batch_size, Sequence_length, num_heads, head_dim)인 반면, key와 value의 shape은 query의 head 수를 group 수로 나눈 값으로 지정하게 됩니다. 여기서는 모델 파라미터를 별도의 Model_Args로 지정했는데, 그 일부를 살펴보면 다음과 같습니다.

Model Params and etc

위 그림에서 확인할 수 있다시피 kv의 head는 q의 head를 group 수로 나눈 값이 되게 됩니다.


다만 이렇게 GQA의 적용은 명확한 장단점을 만든다는 것을 알아야 합니다. 단점으로는 당연하게도 Key와 Value의 표현 다양성이 줄어들어 모델의 문제 해결 능력 자체는 떨어질 수 있다는 것입니다. 하지만 그럼에도 (특히 거대 모델의 경우) 저장/관리하는 파라미터의 수를 줄일 수 있다는 것은 큰 장점입니다. 비록 모델이 실행되는 과정에서의 파라미터가 증가할지라도 저장/관리되는 모델의 파라미터가 줄어들고, 이를 단순히 broadcasting함으로써 사용할 수 있다는 장점이 있습니다.


본 논문에선 제시되지 않았지만 최근 화재가 되고 있는 DeepSeek의 기술 보고서나 논문[9]을 보면, 이러한 Attention의 개선된 형태로서 MLA(Multi-Head Latent Attention)을 제시하기도 합니다. 이는 Q, K, V를 모두 낮은 차원의 Latent Vector로 변환해 적용함으로써 MHA에 필적하거나, 유사한 성능을 내면서도 보다 빠른 추론이 가능하도록 한 구조입니다. 이를 그림으로 살펴보면 다음과 같습니다.

download.png MLA Structure (from [9])

이에 대한 개념은 본 포스팅에선 자세히 다루진 않고, 추후에 별도의 포스팅으로 다루겠습니다. 해당 논문[9]에서 제시된 다양한 attention들의 성능을 살펴보는 정도로만 하겠습니다. 우선 위에서 살펴본 MQA(Multi-Query Attention), GQA(Grouped-Query Attention), 그리고 MHA(Multi-Head Latent Attention) 간의 성능 차이입니다.

download.png Comparison among 7B dense models with different attentions(from [9])

위 그림에서 알 수 있듯 모델의 규모가 더 클지라도, 모든 경우에 있어서 MHA가 가장 높은 성능을 보여줍니다. 그렇다면 MLA와는 어떨까요?

download.png Comparison between MLA and MHA with various MoE (from [9])

MLA 모델은 특히 FFN의 역할이 분업화되는 MoE와 결합할 경우 보다 높은 성능을 보이는데요, 작은 MoE와 결합될 경우, MHA와 비슷하거나 근소하게 앞선 성능을 보이지만 큰 MoE와 결합될 경우 오히려 보다 높은 성능을 내는 것을 확인할 수 있습니다. 또한 이러한 모델들의 경우 대부분 KV-Cache를 통해 추론 과정을 최적화하기 마련인데요, 아래 그림에서 추론 과정에서 쓰이는 각 attention 별 토큰 당 KV-Cache 사이즈가 도식화 돼있습니다.

Comparsion of the KV cache per token among different attentions (from [9])

위 그림에서 dh는 head_dim을 의미하고, l은 num_layers를 의미합니다. 이들은 모두 공통적이니 배제하고 고려한다면 MQA가 토큰 당 KV Cache가 2로 제일 낮고, 다음으로 GQA가 그룹수의 배수이고 MLA가 4.5, MHA가 num_heads(nh)의 배수입니다. 즉 MLA는 그룹수가 2.25인 GQA 혹은 head의 개수가 2.25개인 MHA와 같은 수준의 KV cache를 가짐으로써 아주 작은 모델이 아닐 경우 가장 빠르게 실행되면서도 가장 적은 수준의 리소스를 소모한다는 것입니다.


| Difference #5 - SwiGLU |

마지막으로 일반 Transformer와 LLaMA의 구조와의 차이는 바로 활성함수의 차이입니다. LLaMA에서는 SwiGLU라는 이름의 activation layer를 사용합니다. 기본적으로 우리가 사용하던 GLU 계열 활성화 함수는 다음 수식과 같이 Bilinear 함수에 추가적인 함수(σ)를 추가로 적용한 것입니다.

GLU 활성화 함수 수식 (from [8])


이러한 활성화 함수를 처음으로 제시한 논문[8]에서는 다양한 버전의 GLU 활성화 함수 및 이를 이용한 FFN를 제안하였으며, 이때 SwiGLU의 수식은 다음과 같습니다.

download.png 다양한 GLU 변주 수식 (from [8])

위 그림에서 볼 수 있다시피 어떤 함수를 적용하는가에 따라 뒤에 'GLU'가 붙는 식으로 이름이 정해집니다. 즉, SwiGLU란 swish(β) 함수가 적용된 GLU 활성 함수를 의미하는 식입니다. 이러한 수식에 기반한 성능 평가는 다음과 같습니다.

GLU 들의 성능 평가 (from [8])

위 그림에서 알 수 있듯 GEGLU를 적용한 경우와, SwiGLU를 적용한 경우에 성능이 제일 높았으며, 이후의 다른 벤치마크를 고려했을 때에도 GLU 계열들이 뛰어난 성능을 발휘했습니다. 이때 많은 경우에서 SwiGLU가 높은 스코어를 달성해서 최근 모델들에서 SwiGLU를 사용하는 것이 아닐까 싶습니다. 왜 이런 식으로 얼버무리듯이 성능이 뛰어난 이유를 설명하느냐하면,

GLU 계열이 뛰어난 이유를 알 수 없는 이유? (from [8])

해당 논문에서도 별다른 설명이나 근거없이, 신의 자비로 결론내렸기 때문입니다. 때문에 우리는 그저 실험적으로 성능이 뛰어났다는 내용을 믿고, 수식에 기반에 SwiGLU를 구현해보겠습니다. 우선 수식입니다. LLaMA 논문에서는 swish 함수의 굴곡 정도를 결정짓는 β를 1로 정의하였는데요, 이는 일반적으로 많이 사용되며 SiLU(Sigmoid Linear Unit)이라고도 부릅니다. 이에 대한 수식과 그래프는 다음과 같습니다.

download.png SwiGLU function (from [6])

그리고 이를 이용한 FFN를 구현하면 다음과 같이 작성할 수 있습니다.

FFN w/ SwiGLU

위 코드에서 특이한 점은 17번째 줄의 residual connection, 즉 잔차합을 더하는게 아닌 곱하기로 적용한다는 것입니다. 이는 LLaMA 구조에서 사용되는 RMSNorm의 한계를 해소하기 위한 것인데요, 기존 LayerNorm의 경우 평균과 분산 계산을 통해 기존과 같이 x+f(x)의 형태로 잔차합을 적용하더라도 벡터값의 크기가 크게 변화하지 않았습니다. 반면 RMSNorm의 경우, 평균을 빼는 과정없이 벡터의 크기 자체만을 정규화하기에 이를 기존과 같은 방식으로 residual connection을 적용할 경우, 입력 벡터가 과도하게 커지거나 작아질 수 있는 것입니다.


이로 인해 덧셈이 아닌 곱셈(x⊙f(x))의 방식으로 residual connection을 적용함으로써 입력 벡터 x에 스케일링을 적용한 형태로만 보정됨으로써 평균과 분산 계산이 없는 RMSNorm의 한계를 극복할 수 있게 되는 것입니다. 여기까지 구현이 완료됐다면 남은 것은 LLaMA[2]에서 제시한 구조에 맞게 결합하는 것뿐입니다. 즉, 여기까지의 과정을 통해 우리는 LLaMA를 구현하는데 성공했습니다!



| Difference #6 - Diff Attention |

from LLaMA to Diff Attention

이제 Diff Transformer를 구현하기 위한 마지막 과정입니다. 우리가 여태까지 구현한 현대적 트랜스포머 구조에 Diff Attention을 추가하는 것입니다. 다시 기존의 attention code를 살펴보면 다음과 같습니다.

Attention of LLaMA

GQA(Grouped Query Attention)을 적용하였고, Query와 Key에 대해 RoPE를 적용하였으며, Causal Mask를 적용한 형태의 attention code입니다. 여기에 Differential Attention을 적용해야 합니다. 이를 구현하기 위해 본 논문[1]에서 제시한 구조를 살펴보면 다음과 같습니다.

Diff Transformer Structure (from [1])

위 그림에서 살펴보면 Q, K를 2개로 나누고, V는 나누기 전의 Q, K와 같습니다. 즉 [B, L, 2D']의 형태입니다. 여기서 2D'은 일반적인 트랜스포머 구조의 embed dim이 2로 나눠질 수 있어야 함을 의미합니다. 조금 더 나아가 Multi-Head Attention의 구조를 고려하면 Q, K, V는 [B, 2H, L, H_D]의 형태여야 합니다. 여기서 H는 헤드의 수, H_D은 헤드당 차원수입니다.


이때 사전에 Q와 K를 나누고 두 개의 attention weight를 구하는 방법도 있지만, 한 번에 구한 다음 나누는 식으로도 진행가능합니다. 따라서 Q@K 연산을 수행하면 [B, 2H, L, L]과 같은 형상이 되며, 형상변환(reshape)을 통해 [B, H, 2, L, L]과 같이 바꾸고 세 번째 차원을 기준으로 분리해줍니다. 이때 나눠진 attention weight를 각각 attn1, attn2로 두고, 본 논문[1]에서 제시한 attn1 - λ*attn2를 연산해야 하는데, 이때의 λ는 다음과 같은 수식에 의해 결정됩니다.


Differential Attention Equation (from [1])

위 그림의 2번 수식과 같이 scalar λ는 학습가능한 파라미터들에 의해 re-parameterization됩니다. 이러한 내용을 코드로 살펴보면 아래와 같습니다.

Core Code of Diff Transformer (from official github)

즉 λ는 학습가능한 파라미터 네 개와 경험적 최적값인 λinit의 결합으로 구현되며, 이를 차등적 attention weight를 구할 때 고려되고, 추후 GroupNorm을 진행한 후 λinit 값을 고려해 scaling 되는 식입니다. 이러한 내용을 전체 코드로 살펴보면 다음과 같습니다.


Attention of Diff

GQA와 RoPE를 적용하기 위해 shape 변환 과정이 조금 복잡할 수 있지만, 그 부분 외에는 위에서 살펴본 수도 코드에 따라 구현된 Attention 코드입니다. 코드의 가독성과 이해가 쉽도록 추론에 필요한 KV cache와 Mask에 대한 코드는 제외하였습니다. 이렇게 구현한 코드를 모델의 구조에 맞게 조립해 사용하면 모델링은 마무리됩니다.


|Experiments

위에서 구현한 모델들을 직접 학습해 검증해보는 단계입니다. 이 내용을 위해 Hugman(SangKeun Jung) 님의 medium 포스팅[10]의 구조를 참조하였습니다. 참고한 코드는 참조 [10]에서 코드를 다운받거나 제가 조금 수정한 버전을 여기에서 확인할 수 있습니다. 또한 이를 기반으로 본 포스팅을 위해 전체적인 테스트 구조만을 가져와 직접 실험해본 코드는 여기서 확인할 수 있습니다.


구체적으로 해당 포스팅에서는 작은 사이즈의 shakespear 희곡 데이터셋을 이용해 text generation을 학습하며, 아주 작은(파라미터 수천 개 사이즈의) LSTM 모델과 단층 Transformer에서부터 멀티헤드, 멀티레이어에 본 포스팅에서 설명한 LLaMA의 구조들을 하나씩 이식하며 결과를 제시합니다.


다만 해당 포스팅 내용에서 데이터셋이 작은 것을 고려해도 너무 단순하고 작은 순환신경망(RNN)을 사용해 단층&단일헤드 Transformer와 비교해도 train loss가 낮은 결과가 제시되는데, 이는 적어도 train loss에 대해선 최신 모델도 RNN 계열을 넘어서기 힘들어한다는 점을 고려했을 때 비상식적입니다. 때문에 정상적인 모델 성능 측정 및 비교를 위해, 더 큰 사이즈의 GRU 모델에 layernorm을 적용한 비교적 현대적인 구조의 순환신경망을 구현해 비교하였습니다. 모델의 파라미터 사이즈는 약 80만 개로 기존의 4,000개 미만의 수준에서 200배 정도를 늘려서 테스트하며(수만 개에서 1억 개 수준으로 늘리며 테스트했을 때, 70~80만 이상이 되면 전체적인 학습 그래프 패턴이 유사했습니다), 굳이 단층&단일헤드 Transformer부터 시작하지 않고 Transformer의 Decoder만 분리한 GPT-2 구조와 LLaMA, 그리고 Diff Transoformer를 비교하였습니다. 추가적으로 AMP(Automatic Mixed Precision)을 적용해 경량화해 학습하였습니다.


|Description - Dataset

우선 데이터셋입니다. 이는 참고한 블로그 포스팅[10]의 내용을 그대로 가져온 것으로 셰익스피어의 희곡 데이터셋입니다. 위에서 설명한 바와 같은 next token generation을 위해 구축된 형태이며, 간단하게 sliding window와 같은 형태로 구축됩니다.

download.png

여기선 별도의 토크나이저 없이, 하나의 글자를 하나의 토큰으로 사용하는 식으로 간단히 구현하였습니다.


|Model 1 - GRU

RNN 계열의 모델 중 하나인 GRU에 layernorm을 추가하고, activation을 relu에서 gelu로 바꾼 모델을 성능 비교를 위한 base model로 구현합니다. 이에 대한 코드는 다음과 같습니다.

download.png

이렇게 구현한 모델 및 이하의 모든 모델들은 Cross Entropy Loss와 AdamW Optimizer를 통해 학습시키며, 그 외 별도의 학습 레시피(scheduler, gradient clipping, EMA, etc)는 사용하지 않습니다. 이렇게 구현한 모델의 파라미터 수는 다음과 같습니다.


download.png

이렇게 구현한 모델의 파라미터 수는 약 78만 개이며, 이 모델로 20 epoch(약 10,000 step) 학습시킨 결과는 다음과 같습니다.

download.png

위 그림에서 알 수 있듯 약 5,000step 이후부터 val loss가 더이상 하락하지 않고 커지며 과적합 현상이 발생하고 있는 것을 알 수 있습니다. 즉, 학습되지 않은 데이터 유형에 대해 취약하다는 RNN 기반 모델의 단점이 발생하고 있는 것입니다. 이를 Transformer 기반 모델들은 내적 연산을 통한 상대적 관계를 파악함으로써 이를 보완하고자 하였습니다.


|Model 2 - GPT2(Transformer Decoder Only)

GRU와 비교할 두 번째 모델은 GPT-2입니다. 정확히는 GPT-2를 구현한 것은 아니고 그와 유사한 구조로 구현한 모델입니다. 이 또한 Transformer 비교군이기 때문에 classic이라고 명명하였습니다. 가장 근본적인 차이는 LLaMA 모델의 주요 요소를 배제한 것입니다. (즉, LLaMA 모델을 먼저 구현하고, LLaMA의 구성 요소를 역순으로 빼서 해당 모델을 구현하였습니다)

참고로 위와 같이 모델을 구현할 때, 추후 attention weight를 시각화하기 위한 코드를 추가하였습니다. 이렇게 구현한 모델의 파라미터 수는 다음과 같습니다.


download.png

총 네 개의 layer로 구성하였으며, 전체 파라미터 수는 GRU와 같은 수준인 78만 개 가량입니다. 이에 대한 학습 결과를 살펴보면 다음과 같습니다.

위 그림에서 볼 수 있듯 왼쪽의 train loss는 GRU 기반 모델보다는 성능이 낮지만, 오른쪽의 validation loss의 경우엔 오히려 transformer 기반 모델의 경우 여전히 지속적으로 성능이 향상되는데 반해 GRU 기반 모델은 오히려 성능이 하락하며 과적합을 일으키고 있는 것을 확인할 수 있습니다. 이러한 경향은 이후 100 epoch 이상까지도 유지(GRU의 val_loss는 높아지고, Transformer의 val_loss는 지속적으로 하락하며 수렴)됩니다.


|Model 3 - LLaMA

다음으로는 LLaMA입니다. Attention까지의 코드는 위에서 설명했으니 넘어가고, 전체적인 모델 구조만 살펴보고 넘어가겠습니다.

위 코드로 구현한 모델의 구조 및 파라미터는 다음과 같습니다.


download.png

모델의 파라미터는 약 77만 개로, 이전의 두 모델보다는 약간 작게 구현했습니다. 성능이 뛰어나다면 약간 부족한 파라미터로도 더 높은 성능을 낼 수 있어야 하기 때문입니다. 이에 대한 성능은 다음과 같습니다.


위의 왼쪽 그림을 통해 볼 수 있듯, LLaMA 구조로 구현된 모델은 Transformer 보단 빠르게, GRU보단 느린 수준으로 성능이 향상됐습니다. 오른쪽 그림에선 Transformer보다 빠르게 초기에 수렴하면서 지속적으로 보다 높은 성능을 보이고 있는 것을 확인할 수 있습니다.


|Model 4 - Differential Transformer

마지막으로 Differential Transformer인데요, 이 모델은 attention 코드만 바뀌었을 뿐 Transformer Block이나 전체 모델 구조는 바뀌지 않았습니다. 바로 전체 구조와 파라미터를 확인해보겠습니다.

download.png

Diff Transformer도 LLaMA와 같은 논리로 77만 개의 파라미터로 구현했습니다. 이에 대한 성능은 다음과 같습니다.

위 그림을 살펴보면 학습 손실의 경우, LLaMA와 거의 같은 수준이었으며, 평가 손실의 경우 LLaMA보다 살짝 높았습니다. 이는 실험에 사용한 규모의 데이터셋은 규모가 작을뿐더라 별다른 노이즈 데이터도 적어 Differential Attention이 긍정적으로 개입될 여지가 적어서 발생한 문제로 보입니다.


|Attention Map

그렇다면 이러한 모델들의 Attention Map은 어떻게 나타났을까요? 우선 GPT-2와 LLaMA의 Attention Map입니다.

두 모델 모두 0에서 1까지의 attention weight 범위 안에서 분포돼있는 일반적인 attention map을 그리고 있는 것을 확인할 수 있습니다. 그렇다면 diff transformer는 어떨까요?

Diff Transformer의 경우, 첫 번째 맵과 두 번째 맵, 그리고 람다를 곱한 두 번째 맵과 이를 첫 번째 맵에서 뺀 맵을 시각화하면 위와 같습니다. 직관적으로 확인할 수 있는 것은 Diff Transformer의 attention weight는 음수와 양수의 범위 안에서 이뤄지고 있는 것을 확인할 수 있는데요. 이를 좀 더 큰 이미지로 수치와 함께 확인하면 다음과 같습니다.


LLaMA의 attention weight를 확대해서 살펴보면 위와 같습니다. 표시되지 않은 부분은 0이거나 아주 작은 attention이 이뤄지는 곳입니다. 가독성을 위해 100을 곱해 정수 자리까지만 시각화하였습니다. 그렇다면 다음으로는 Diff Transformer의 attention weight도 살펴보죠.


위 그림을 통해 알 수 있는 부분은 크게 두 개 입니다. 첫 번째로는 위에서 이미 살펴봤던 것과 같이 Diff Transformer의 attention map은 음수의 범위까지 분포한다는 것입니다. 이를 통해 단순히 양의 관계성이 얼마나 있냐를 넘어 음의 상관관계를 갖는 토큰 간의 관계까지도 고려할 수 있는 식으로 노이즈를 처리하게 되는 것으로 보입니다.


두 번째로는 상대적으로 LLaMA와 같은 일반 Transformer 모델에 비해 하나의 토큰에 대해 큰 attention value를 지닌 토큰이 소수라는 것입니다. 즉, 이러한 특징은 이는 람다(λ)로 스케일된 두 번째 attention weight에서 scaling 됐음에도 attn1보다 큰 경우엔 음수로 바꿔 음의 상관관계를 갖게 하고, 그렇지 않은 경우엔 절대치를 줄여 중요하지 않게 하는 효과를 유도하는 것으로 보입니다.


|Experiements 1 - GQA

위에서 진행했던 모델 학습 및 실험 결과는 GQA 그룹수를 1로 한 결과물이었습니다. 정확히 표현하면 MHA로 진행한 결과였죠. 때문에 추가적인 실험을 통해 GQA를 적용하기 전후의 결과를 간단히 살펴보도록 하겠습니다. 이에 대해서 GQA와 같은 압축적으로 attention을 수행하는 모델의 성능과 상호보완적으로 작동하는 MoE를 적용해보았습니다. 모델의 사이즈는 약 180만으로 이에 대한 코드 및 결과는 여기서 확인할 수 있습니다.

위 실험 결과에서는 GQA를 적용했을 경우 학습 성능은 유사하나, val_loss의 경우 초반엔 조금 부족하다가 이후에 오히려 앞지르는 모습을 보여주고 있습니다. 그렇다면 GQA와 같은 현대적 attention 방식에 더 친화적이라는 MoE를 적용하면 다음과 같으며, 이에 대한 코드는 여기서 확인할 수 있습니다.

MoE를 적용할 경우엔 초반부터 GQA의 성능이 MHA의 성능을 앞서는 것을 확인할 수 있습니다. 비록 현재 테스트에 사용한 데이터셋이 작고, 노이즈가 적은 형태이며 모델의 사이즈 또한 그리 크지 않다는 점을 고려해야 하지만, 그럼에도 작은 태스크에 어울리는 작은 모델에 활용하는 경우에도 GQA가 자원 활용 측면에서 효율적일뿐 아니라 성능 측면에서도 괜찮다 정도의 결론은 낼 수 있을듯 합니다.


|Experiements 2 - θ of RoPE

마지막 실험은 위치 인코딩 방식인 RoPE에서 상대적 위치에 대한 단위인 세타를 조정하는 것입니다. RoPE를 적용하는 논문이나 모델들을 보면 초기엔 10,000이며, 이는 본래 Transformer의 주기함수에서 유래된 값이며, 경험적으로 가장 최적의 값인 것으로 알려져 있습니다만, 이는 입력 시퀀스 길이에 따라 조정가능하지 않을까 해서 실험해본 결과입니다

위 그림에서 Diff1, Diff2, Diff3는 각각 세타값을 100, 1000, 10000으로 적용한 결과입니다. 현재 구현한 모델의 입력 시퀀스 길이가 짧음에도 여전히 10,000에서 높은 성능을 보이는 것을 확인할 수 있습니다. 위 결과는 여기 코드를 통해 확인할 수 있습니다.


|Reference

[1] Tianzhu Ye, et al. "DIFFERENTIAL TRANSFORMER".

https://arxiv.org/pdf/2410.05258.

[2] Hugo Touvron, Thibaut Lavril, Gautier Izacard, et al. "Llama: Open and efficient foundation language models". https://arxiv.org/pdf/2302.13971.

[3] Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. "Roformer: Enhanced transformer with rotary position embedding". https://arxiv.org/abs/2104.09864.

[4] Ofir Press, Noah A. Smith, Mike Lewis. "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation". https://arxiv.org/abs/2108.12409.

[5] Efficient NLP. "Rotary Positional Embeddings: Combining Absolute and Relative". https://www.youtube.com/watch?v=o29P0Kpobz0

[6] Umar Jamil. "Coding LLaMA 2 from scratch in PyTorch - KV Cache, Grouped Query Attention, Rotary PE, RMSNorm". https://www.youtube.com/watch?v=Mn_9W1nCFLo.

[7] Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints". https://arxiv.org/abs/2305.13245.

[8] Noam Shazeer. "GLU Variants Improve Transformer". https://arxiv.org/pdf/2002.05202.

[9] DeepSeek-AI. "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model". https://arxiv.org/pdf/2405.04434.

[10] Hugman Sangkeun Jung. "[Hands-On] Mastering LLaMA - Implementing LLama1 from Scratch". https://medium.com/@hugmanskj/hands-on-mastering-llama-implementing-llama1-from-scratch-1-3-8ba4b9e8da0e.

keyword
작가의 이전글Kolmogorov Arnold Network