Paper Review 1 : Transformer 이해, 코드 구현
Transformer가 나온 배경은 간단하다. 기존에 자연어 처리를 위한 방법들(RNN, LSTM, CNN) 등이 가진 한계가 있었기 때문. 이러한 한계는 크게 두 가지로 다음과 같았다.
1. 입력의 차원이 길어질수록 최초 입력에 대한 정보가 희석된다.
2. 입력의 차원이 길어질수록 결국 마지막 입력에 대한 가중치가 높아진다.
어찌보면 하나의 문제처럼 보이는 이것을 소위 inductive bias라고 정의하기도 한다. 기존에 사용되던 RNN과 LSTM과 같은 모델들은 결국 순차적으로 처리되기 때문에, 결국 데이터의 순서에 의존성을 지니게 된다. 이는 특정 시퀀스 패턴에 대한 Inductive bias가 높다는 것을 의미하고, 자연어 처리에 한계가 존재한다는 인식을 퍼뜨리는 주요 원인이 되기도 했다.
(이를 해소하고자 CNN을 결합할 수 있는 형태가 나오기도 했지만, 결국 이조차 기계적으로 단락을 나누는 과정에 불과해 본질적인 한계를 극복하지 못했다. 이는 추후 Transformer를 Computer Vision Task에 적용한 ViT 모델들 중 Swin Transformer가 당면한 문제와 유사하다)
이러한 문제를 해결하기 위해 Transformer, 정확히는 Attention Mechanism이 제시됐다. 이를 통해 입력 시퀀스 내의 모든 요소 간의 관계를 학습하고, 고려할 수 있게 되었다. 다만 이러한 방식은 귀납적 추론 방식과 크게 다르지 않기 때문에 기존에 사용되던 방식들에 비해 최소 몇 배는 많은 데이터를 요구하는 문제를 낳기도 했다.
본 논문의 핵심은 두 말할 나위없이 [Attention Mechanism]이다. 이는 입력 데이터의 순서에 의한 편향을 제거하고, 귀납적 논리에 의거해 입력 값들 간의 상대적 관계성을 파악하는 것에 초점이 맞춰져 있다. 다만 이러한 방법론을 제대로 정확하기 위해 함께 제시되는 개념이 Multi-head Attention, Positional Encoding, Cross Attention과 같은 방법들이며, 이에 대해 하나씩 정리해보자.
본 논문에서 제시된 핵심개념인 Attention의 공식이다. 일반적으로 입력값은 임베딩된 벡터의 형태로 전달되며, 이러한 입력값에 대해 쿼리(Q:Query), 키(K, Key), 밸류(V:Value)를 연산한다. 이러한 Q,K,V들은 가중치 행렬(Wq, Wk, Wv)들이 각각 곱해지며, 이 가중치들은 학습 과정에서 최적화된다.
(Attention Mechanism에서 유일하게 학습 중 최적화되는 부분이라고 봐야한다.)
어텐션을 구현하는 순서는 아래와 같다.
① Q, K, V 연산
② Query, Key 내적
③ Attention Score 연산
④ Attention Weight 연산
⑤ 최종 출력
①입력값에 대해 가중치 행렬을 구해 Q, K, V를 구했다면, ②Q와 K의 내적을 계산하고, ③이를 벡터 차원의 제곱근으로 나눠, 지나치게 큰 내적값으로 인해 softmax의 결과가 극단적으로 치우치는 것을 방지한다. (때문에 수식 이해를 위해선 우선 root(dk)를 빼고 생각해보면 더 쉽다.) 이 값을 Attention Score라고 하며, 입력 요소들 간의 상대적 중요도를 표현한다.
이후에는 ④Attention Score에 대해 softmax 함수를 적용해 Attention Weight를 계산하며, 이는 각 입력요소가 다른 요소에 얼마나 주목(attention)하는지를 표현한다. ⑤마지막으로 Attention Weight에 Value를 곱해 최종 출력을 생성하며, 이는 입력 시퀀스를 새롭게 표현한 형태로 간주한다.
(입력값을 단순 임베딩 벡터로 표현한 것에 입력 요소들 간의 관계가 반영된 것)
위 과정까지 진행했다면 위 attention 구조를 나타낸 그림에서 옵션으로 제시한 Mask를 제외한 부분을 완료했다. 위 그림에서 Mask는 task에 따라 선택적으로 적용하게 된다. 일반적으로 Transformer 구조는 복수의 attention layer의 결합으로 구성되며, 각각 조금씩 다른 옵션이나 형태를 취할 수 있다. Mask가 그 예시 중 하나이며, 일반적으로 Data Leakage를 방지하기 위한 목적으로 사용된다.
그 외에는 task에 따라 달라질 수도 있는데, 만일 chatGPT와 같이 특정 값을 입력하면 그 뒤의 값을 생성하는 task라면 정답 부분에 해당하는 뒷 내용에 대해 Mask를 생성해 진행한다. 반대로 본 구조를 채택해 진행하는 task가 전체 문장에 대한 분류나 번역 등의 문제라면 Mask없이 진행할 수 있다. (이러한 Mask에 대한 옵션은 같은 Task라도 구현하는 목적이나 의도, 다른 방법론과의 결합 등으로 변동할 수 있다.)
Scaled Dot-Product Attention에 대한 코드는 아래와 같다.
쉬운 이해를 위해 이미 벡터화된 데이터를 가정하고 진행한다. 다만 위 코드는 numpy만을 이용해 구현하고자 했기에, 가중치 행렬을 임의로 설정해 사용한다. 실제 코드에선 간단한 Linear Module을 이용해 초기화하고, 학습에 따라 업데이트가 진행된다. 위 코드를 통해 볼 것은 Scaled Dot-Product Attention의 수식이 어떤 과정을 거쳐 진행되는지를 확인하는 것이다.
실제 코드에선 보다 빠른 연산 및 gpu활용을 위해 tensor형태로 작동하는 tensorflow나 torch 등의 라이브러리를 활용하며, 이에 대해선 이하의 코드 설명 부분에 포함돼 있다. 위 코드의 입출력을 살펴보면 아래와 같다.
참고로 위의 코드에서는 scipy의 softmax 함수를 이용해 softmax를 바로 구현했다. softmax는 간단히 설명하자면 여러 값들로 구성된 그룹이 있을 때, 각 요소들을 전체에 대한 비중으로 나누는 것이다. 이러한 이유로 softmax 함수가 적용된 값들은 0과 1사이에 위치하며 합하면 항상 1이 된다. 때문에 너무 작거나 큰 값은 0 또는 1로 수렴하며, 이는 동시에 다른 값들에 대한 노이즈를 생성시킬 수 있어 Attention의 공식에서 root(dk)로 나눠주는 것이다.
위 코드에서 scipy의 softmax 함수를 numpy로 구현하면 아래와 같다.
Attention을 통해 각 입력값들 간의 상대적 관계성까지 고려한 시퀀스를 생성했다. 하지만 우리의 언어는 다른 말이지만 같은 뜻을 나타내는 동의어라거나, 같은 단어임에도 여러 뜻을 지니고 있기도 하고, 과거의 역사나 유행에 의해 새로운 단어를 만들거나 여러 단어가 합쳐진 결과 전혀 다른 뜻을 내포하기도 한다.
이는 단일 Attention, 정확히는 단일 Wq, Wk, Wv로는 우리가 실제 사용하는 언어의 상관관계를 명확하게 표현하기 어렵게 만드는 요소이고, 이를 해결하고자 하는 방법론이 추가로 적용되는데 이것이 바로 Multi-Head Attention이다.
위 구조도에서 보라색 부분까지가 이전에 진행한 Scaled Dot-Product Attention이고, 이 과정을 여러 부분으로 나눠 병렬적으로 처리하고, 나눠서 처리했기 때문에 최종적으로 합치는(Concat) 방법까지가 Multi-Head Attention이다. 참고로 하나의 Attention을 Head라고 칭한다.
만일 입력된 값의 Sequence Length가 10이고 배치 사이즈는 2, head의 개수를 4, 모델의 차원(입력차원을 새롭게 표현하는 벡터길이)를 64라고 한다면 입력차원은 (2, 10)이다. 이를 Multi-Head Attention이 적용되지 않은 Attention에 통과시키면, Q, K, V의 차원은 (2, 10, 64)이 된다. 하지만 Multi-Head Attention이 적용된다면 Q, K, V의 차원은 (2, 10, 4, 16)이 된다.
(이러한 연산을 위해 head의 개수는 모델의 차원, 흔히 d_model로 표현하는 값의 약수로 해야 한다)
이는 단순히 하나의 Attention을 작은 단위로 자른 것 이상의 효과를 보이게 되는데, 가장 직관적으로 확인할 수 있는 것은 속도 부분이다. head의 수만큼 나뉘어 병렬로 연산되기 때문이다. 또 다른 강점은 이렇게 나뉘어 연산되는 과정 속에서 각각의 head가 별개의 attention을 계산해 가중치를 지니게 되기 때문에 서로 다른 '관점'을 지니게 된다는 것이다. 이는 위에서 언급한 다양한 표현이나 예외적 사용 등에 대해서도 대응할 수 있는 근거를 마련해준다.
위 그림은 본 논문에 소개된 Multihead attention의 Attention Score를 가시화한 것이다. 오른쪽의 making이라는 단어와 높은 관계성을 지닌 단어들을 표현하고 있는데, 왼쪽에서 색칠된 블록들은 각기 8개의 서로 다른 head의 주목도를 나타낸 것이다.
코드를 살펴보면 아래와 같다. 일반적으로는 하나의 forward에 넣는 Multi-Head 부분을 split_heads라는 이름의 별도 함수로, 별도로 입력받거나 forward 내에서 처리하는 mask 영역 생성 부분을 create_look_ahead_mask라는 별도의 함수로 생성해 조금이라도 이해가 쉽도록 구현했다.
본래라면 디코더에만 Mask가 사용되기 때문에, 해당 부분을 조건문을 통해 제한해야 하지만 여기선 Mask의 효과를 확인하기 위해 한 번에 연산되도록 했다. 참고로 여기서 사용한 Mask는 look-ahead-mask라는 것으로 현재보다 미래의 정보가 반영되지 않도록 하는 방식이다.
이해가 쉽도록 위에서 언급한 예시와 동일한 파라미터를 적용해 클래스를 활성화하면 아래와 같다.
위 클래스는 마스크 이전의 Attention Score Matrix와 마스크가 적용된 Attention Score Matrix, Mask Area와 최종 결과를 출력한다. 이에 대한 결과는 다음과 같다.
일반적으로 Multihead Attention에 적용되는 계산복잡도는 위와 같이 표현된다. 위 복잡도는 softmax를 제외한 query, key, value 및 최종 출력에 대한 projection(각 값과 weight끼리의 내적) 복잡도와 Query, Key, Value 간 내적을 표현한 것이다. 이를 좀 더 자세하게 수식적으로 표현하면 아래와 같다.
transformer 구조를 통해 inductive bias를 없애 먼저 들어온 값과 나중에 들어온 값을 동등하게 분석할 수 있었지만, 우리들이 사용하는 언어는 그 위치에 따라 뜻이 정반대로 달라질 수도 있다. 간단히 보면,
1. 나는 고양이보다 강아지가 좋아
2. 나는 강아지보다 고양이가 좋아
단순히 강아지와 고양이의 순서가 뒤바뀌는 것만으로, 문장의 뜻은 정반대가 되는 것을 알 수 있다. 이러한 문제를 해결하고자 도입된 개념이 positional encoding으로, 간단히 말해 각 단어의 순서를 유지하는 것이다. 일반적으로 주기적인 함수인 cos과 sin 함수를 사용한다.
이렇게 표현된 positional encoding은 위 그림의 구조에서와 같이 embedding된 입력값에 대해 더해짐으로써 고려된다.
positional encoding의 코드는 아래와 같다.
코드 자체는 그리 어렵지 않고, 간단하게 docstring을 작성했다.
위 코드는 sin, cos 함수를 이용해 홀수 인덱스의 위치는 cos함수를 사용하고, 짝수 인덱스의 위치는 sin함수를 사용해 표현하는 식이다. positional encoding된 값의 형태는 당연히 transformer에서 출력되는 형태와 일치해야 하며, 여기선 위의 예시에서 사용했던 최대 길이(max_len) 10, 모델 차원(d_model) 64를 사용했다.
위 코드의 결과 출력되는 pos_encoding의 shape와 출력값은 아래와 같다.
출력형태는 (max_len, d_model), 즉 (10,64)이고, 주기 함수인 cos, sin 함수를 사용했기 때문에 출력된 array의 첫 번째 줄은 0과 1이 반복되는 것이 특징이다.
본 논문에서는 Encoder-Decoder Attention으로 표현하는, Encoder의 출력이 Key와 Value로, Decoder에서의 출력이 Query로 사용돼, Encoder 출력과 Decoder 출력 간의 상관관계를 변수화하는 Attention Layer이다. 일반적으로 Cross Attention으로 표현하지만, Cross Attention이 좀 더 넓은 개념으로 사용된다. (멀티모달 모델 등에서 다른 모델의 출력 시퀀스와 현재 시퀀스 간의 관계성을 정의하는 경우에도 cross attention이 적용된다고 표기하지만, 엄밀히 말해 encoder decoder 관계는 아니라 그렇다)
Attention이란 결국 주어진 시퀀스 내에서 상대적인 관계를 통해 각 요소들의 특징을 규정하는 것이다. 예컨데, '삶은 계란은 맛있다', '닭은 계란을 낳는다', '계란이 먼저일까, 닭이 먼저일까?' 등등 '계란'이라는 단어가 등장하는 수많은 문장들 속에서 '계란'이라는 게 무엇인지 이해하는 방식이다. 이를 통해 계란과 닭의 상관관계나, 선후관계, 좀 더 나아가서는 철학적인 개념까지도 이해하게 되는 것이다. 당연히 이 과정을 위해서는 각 단어들의 뜻을 충분히 이해할 수 있도록 충분히 많은 데이터가 필요하며, 일반적으로 Transformer가 적용된 알고리즘을 사용하려면 최소 수십 만 단위의 데이터셋을 구축해야 한다.
Cross Attention은 이러한 Attention을 변형한 개념으로, 예컨데 한글을 영어로 번역하는 인공지능 번역기를 만들고 싶다면 한글 문장 '닭은 계란을 낳는다'와 영어 문장 'Chickens lay eggs'의 관계성까지 고려해야 한다. 이를 위해 일반적으로 Encoder에 한글 문장을 넣어 출력한 값을 Decoder로 넘겨, Decoder의 영어 문장과의 관계성을 추가적으로 고려해야 하는 것이다. 이러한 Cross Attention은 Encoder와 Decoder 간의 관계성을 표현하는 근거가 되며, 이를 통해 번역이나 요약, 생성 등의 task에 적용된다.
cross attention의 코드 자체는 특별할 것이 없다. 일반적인 Attention Mechanism에서는 Encoder에 사용된 Attention이라면 Encoder에서 전달받은 값만을 사용하고, Decoder쪽에서 사용된 Attention이라면 Decoder에서 전달받은 값만을 사용하는 것이 특징이라면, Query는 본래의 프로세스(Decoder)에서, Key, Value에 해당하는 값은 다른 프로세스(Encoder)에서 가져온다는 차이가 있을 뿐이다.
이를 코드로 살펴보면 아래와 같다.
이해가 쉽도록 msa 등을 제외하고, 최대한 간단하게 구현했다.
잘 보일지 모르겠지만 15~17번째 줄이 cross attention이 일반적인 attention과 다른 부분이며, 나머지는 일반적인 attention과 동일하다.
Transformer 알고리즘은 현재까지도 가장 각광받는 알고리즘이다. 다양한 태스크나 도메인에 적용가능하며, 성능 또한 뛰어나다. 간혹 일부 태스크에 대해서는 Transformer를 사용하지 않는 방식이 채택되거나 추천되기도 하지만, 결국 AGI라는 인공지능의 목표를 고려할 때 Transformer를 완전히 포기할 순 없는 노릇.
특히 최근 LM(Large Model) 혹은 FM(Foundation Model) 등이 대두되면서, 간단한 Fine Tuning 방법들이 제시되고 있으며, 그 중 Attention Layer의 가중치 행렬만을 업데이트하는 방식인 Lora(Low-Rank Adaptation)와 같은 방법들이 주목받고 있음을 고려하면 Transformer의 중요성은 굳이 논할 필요가 없을 정도로 기본적인 사항이 되었다 생각한다.
본문에서 작성한 코드가 궁금하다면 저자의 github를 방문하거나, 여기를 클릭하면 된다.
[1] Ashish Vaswani, Noam Shazeer, Niki Parmer, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention Is All You Need. https://arxiv.org/pdf/1706.03762.pdf
[2] pytorch documentation, https://pytorch.org/docs/stable/index.html
2023.11.21 | 초고 작성 및 Attention, MSA 코드 작성
2023.11.22 | 1차 검수 및 PE, Cross Attention 코드 작성, 코드에 대한 링크 추가
2024.03.04 | MSA 계산복잡도 내용 추가