논문 톺아보기 17
이번에 톺아볼 논문은 바로 'Hymba: A Hybrid-head Architecture for Small Language Models'[1]입니다. 본 논문은 2023년 제시된 'Efficient Streaming Language Models with Attention Sinks'[2] 논문에서 제시된 attention sink 문제를 개선하기 위한 방법으로써 제시됐습니다.
이번 포스팅에선 Hymba를 구성하는 고유 구성 요소들을 단계적으로 구현하고(llama block과 mamba는 이전에 구현했던 내용을 그대로 사용합니다. llama block에 대한 내용은 Differential Transformer에 대한 포스팅을, mamba block에 대한 내용은 Mamba에 대한 포스팅을 각각 참조하세요), 이전에 했던 것과 동일하게 작은 셰익스피어 희곡 데이터셋에 학습시켜 성능을 확인해보겠습니다.
본 논문[1]은 마치 이전의 ConvNeXt 논문 때와 같이 꽤나 친절하게 쓰여진 편에 속하는 논문인데요, 특히 아래와 같이 Hymba 구조를 step-by-step으로 구현하는 과정 및 그 과정으로 인해 얻을 수 있는 효과를 잘 정리해두었습니다.
뿐만 아니라 본 모델을 학습 시킬 때 적용한 Training Process도 아래와 같이 공개되어 있습니다. (일반적으로 사용되는 방법론 - 대규모 비지도 학습 > 지도 학습 > 정제 지도 학습 > 강화 학습 - 이지만, 그 데이터의 규모나 특징이 잘 정리돼 있어 유용합니다)
본 리뷰에서는 위와 같은 수준의 학습 단계를 직접 실행해보긴 현실적으로 어려우므로, 모델 구조를 온전히 구현하고, 각 구성 요소별로 어떤 효과를 얻을 수 있는지 확인하는 정도로 마치도록 하겠습니다.
우선 처음으로 살펴볼 것은 왜 해당 구조가 제시되었나 하는 것입니다. 위의 서론 부분에서 언급한 바와 같이 본 논문의 Hymba 구조는 attention sink 문제를 해결하기 위해 제시되었습니다. attention sink란 attention이 적용되는 모델들에게 일반적으로 적용되는 causal mask로 인해 가장 많이 참조되는 토큰, 즉 제일 앞부분의 토큰들에 대해 거의 대부분의 attention score를 할당하는 문제입니다. 이를 그림으로 살펴보면 다음과 같습니다.
참고로 이러한 문제는 합이 항상 1이 된다는 softmax의 특징과 attention이 가진 causal mask의 특성상 앞부분의 토큰은 항상 참조하도록 하며 가장 많이 노출되기 때문에 발생하는 문제이며, 이는 다음 그림과 같이 학습이 길어질수록 심화됩니다.
이러한 문제는 이전에 다뤘던 Differential Transformer에서 다뤘던 attention noise와도 유사한 현상입니다. 즉, 중요하지 않은 token에 대해 과분한 attention score를 할당함으로써 연산이 낭비되고, 성능 향상을 저해하는 문제가 발생한다는 것입니다. Differential Transformer에선 이를 attention noise canceling 기법을 통해 query와 key를 두 개로 나누고, 한쪽에서 다른 한쪽을 뺌으로써 noise를 최소화하는 전략을 적용했었습니다.
본 논문[1]에선 다른 전략을 채택합니다. 어차피 공통적으로 그리고 반복적으로 참조하는 토큰에 대해 과도하게 많은 attention score가 할당된다면 이곳에 의미있는 토큰을 둬보는 건 어떨까? 하는 방식으로 접근하는 것이죠. 이를 Meta Token이라고 합니다. learnable한 parameter를 생성해 전체 시퀀스의 앞부분에 붙여서 모든 attention 과정에서 반드시 참조하게 만드는 식입니다.
하지만 이럴 경우 발생하는 문제가 있습니다. 우선 첫 번째로 단순히 앞부분에 learnable token을 만들어 prepend하는 방식만으로는 시퀀스 앞부분의 토큰에 대한 의존성을 떨어뜨리는데 한계가 있다는 것이며, 시퀀스 길이에 quadratic하게 비례해 연산량이 늘어나는 transformer의 특징으로 인해 연산량 부하 문제가 심화된다는 것입니다.
그래서 이를 해결하기 위해 처음과 중간, 끝 부분의 레이어에 대해서만 global attention을 적용하고 그 외의 레이어에 대해서는 local attention을 적용하는 hybrid attention 방식을 채택하며, 단순히 transformer만을 활용하는게 아니라 mamba까지 함께 병렬로 사용함으로써 연산량 부하를 최소화시키려 합니다. 자세한 사항은 이후에 하나씩 구현해가며 다뤄보도록 하겠습니다.
우선 전체적인 구조를 살펴보고 가는 것이 좋습니다.
가장 핵심은 당연히 위 그림에서도 볼 수 있듯 SSM과 Attention을 병렬로 연결해 Hybrid Head를 구성한다는 것입니다. 이전에도 Zamba[3]와 같이 mamba와 transformer 구조를 혼합해 사용하는 방식은 존재했습니다. 하지만 기존에는 두 개의 모델을 직렬로 연결함으로써, mamba가 가진 고질적 문제인 정보의 압축 과정에서 발생하는 정보 손실 문제가 transformer로 이어지는 문제가 있었습니다.
이로 인해 두 구조가 가진 단점은 해소하면서도 장점을 결합하기 위해 병렬 구조가 더 근본적인 문제를 해결할 수 있는 구조로 제안되었습니다. 이를 통해 Mamba의 전체 시퀀스에 대한 흐름과 방향성을 빠르게 파악하는 능력과, Transformer의 느리지만 특정 정보에 대한 정밀한 작업 능력을 합쳐 성능은 향상시키면서도 실행 속도는 그 중간에 위치하게끔 한다는 전략적 구조가 탄생한 것입니다.
그리고 추가적으로 Meta Token의 추가 등으로 인해 더욱 심화된 Transformer의 고질적인 문제를 해결하기 위해 크게 세 가지 대안이 추가되기도 하는데요, 각각 'Meta Token', ' Hybrid Attention', 그리고 'KV Cache sharing'입니다. 이를 단계적으로 설명한 후, 코드로 구현해 그 성능을 살펴보도록 하겠습니다.
처음으로 살펴볼 것은 바로 Meta Token입니다. Streamling LLM[2]과 같은 논문 및 기타 연구들에서 Transformer 기반 모델들이 모든 인덱스에서 반복적으로 참조하는 제일 앞의 토큰에 큰 attention score를 분배하는 문제, 즉 attention sink를 확인했습니다. 어느 정도냐면, 시퀀스의 앞에 추가되는 BOS 토큰('Begin of Sentence', 즉 시퀀스의 시작 부분을 의미하는 무의미한 토큰)에 과반의 attention score가 집중되는 문제가 있었습니다.
즉, 큰 의미가 없는 토큰에 attention score를 낭비하는 문제가 발생하고 있고 이는 모델의 성능이 온전하게 발휘되지 못하는 문제를 야기하는데요, 아래 그림에서 확인할 수 있듯 현대 Transformer 구조라 칭해지는 Llama에서는 56%가, 2024년 기준 가장 일반적으로 사용되는 Foundation Model Structure 중 하나인 Jamba(attention과 mamba를 직렬로 연결한 hybrid 구조)에서의 attention에서는 무려 62%의 attention score가 할당된 것을 아래 그림을 통해 확인할 수 있습니다.
이전에 리뷰했던 Differential Transformer에서는 이러한 문제 - 의미없는 토큰에 attention score를 할당하는 Attention Noise 문제 - 를 해결하기 위해 Differential Attention이라는 차등주의 어텐션을 통해 무의미한 attention을 제거했는데요. Hymba에서는 조금 다르게 이 문제를 해결합니다. 바로 Meta Token이라는 전체 시퀀스에 영향을 주는 토큰을 임의로 제일 앞 시퀀스에 추가하는 것입니다. 이를 그림으로 나타내면 아래와 같습니다.
위 그림에서 볼 수 있듯이, 세 개의 메타 토큰이 <BOS> 토큰 앞에 추가되었습니다. 즉, 제일 앞에 있다는 이유만으로 <BOS> 토큰에 집중되던 attention score를 메타 토큰에 집중시킬 수 있게 되었는데요. 이는 몇 가지 장점이 있습니다.
자연어 모델이라면 단순 언어능력뿐 아니라 이를 기반으로 하는 특정 도메인(ex. 법률, 의학, 수학, 과학, 코드 등) 별로 상이한 입출력 스타일 및 결과를 요구받게 됩니다. 이러한 문제는 단순히 자연어 모델뿐 아니라 컴퓨터 비전이나 시계열, 이상탐지와 같은 영역에도 해당할 수 있는 문제인데요, 이때 본 논문에서는 각기 다른 도메인에서 프람프트가 입력됐을 때 각기 다른 메타토큰들이 활성화되는 현상이 관찰되었으며 이는 메타토큰들이 각 도메인 별 정보 등을 캡슐화한 것으로 보인다라고 말합니다. 이를 그림으로 살펴보면 다음과 같습니다.
즉, 제일 앞에 있는 무의미한 토큰에 집중되던 attention score를 활용하기 위한 발상의 전환으로, 제일 앞에 의미있는 토큰을 두는 방식을 채택함으로써 단순히 일반적인 성능을 향상시킬뿐 아니라 여러 도메인에 걸쳐 특화될 수 있는 능력까지도 향상시킨 것입니다. 다만 이러한 전략적 선택은 한 가지 큰 문제를 발생시키는데요. 바로 시퀀스의 길이가 길어짐에 따라 가뜩이나 고질적인 문제로 지적받던 Transformer의 시퀀스 길이(n)에 quadratic하게 증가하던 연산 복잡도를 (n+len_meta_token)에 quadratic하게 증가시키는 문제를 발생시킨 것입니다.
때문에 다음의 두 가지 방법이 더 적용되었습니다.
첫 번째로 적용된 방식은 바로 기존의 Global Attention(시퀀스의 모든 토큰들과의 attention) 방식과 Local Attention(Sliding Window Attention, 특정 길이의 윈도우 길이 내의 토큰들과만 attention) 방식을 혼합해 사용하는 방식입니다.
이를 통해 attention의 대상이 되는 sequence 길이를 meta token의 길이와 sliding window의 크기를 더한 값 이하로 하도록 함으로써 meta token의 추가로 인한 연산량 증가를 줄이는 전략입니다. 이를 시각화하면 다음과 같습니다.
위 그림은 두 개의 메타토큰이 추가된, window size = 3의 Local Attention을 시각화한 그림입니다. 위 그림에서 알 수 있듯 모든 토큰에 대해서 보라색으로 표시된 메타토큰이 attention score 계산 대상임을 확인할 수 있으며, 현재 토큰 인덱스 이전의 것들 중 window size의 범위 내에서 attention이 이뤄지고 있는 것을 확인할 수 있습니다.
이를 통해 기존의 첫 번째 토큰이었을 idx=3 토큰은 메타토큰과 자기 자신까지 3개의 토큰에 대해 attention score를 계산하게 되며, idx=4 토큰은 메타토큰 2개와 자기 자신 및 그 이전 토큰의 2개를 합쳐 4개 토큰에 대해, 그 다음은 최대 윈도우 크기인 3이 적용돼어 5개 토큰에 대해 attention score를 계산하게 되는 식입니다.
당연히 이러한 attention 방식은 정보 손실을 야기할 수 있으며, 이로 인해 성능의 하락을 피할 수 없는데요. 때문에 본 논문[1]에서는 모든 레이어에 SWA를 쓰는 게 아니라, 제일 처음과 중간, 그리고 마지막에는 Global Attention을 적용해 정보 손실을 막거나 복원하도록 해 성능 하락은 최소화하거나 없애면서도 실행 속도는 빠르게 만들 수 있었다고 합니다.
다음으로는 KV Cache Sharing(정확히는 'Cross-layer KV Cache Sharing'이 본 논문에서의 명칭)입니다. 이는 간단히 말해 Local Attention이 적용되는 모든 레이어마다 별도의 KV Cache를 생성해 활용하는 기존의 방식 대신 두 개 이상의 레이어마다 하나의 Cache를 활용하는 방법입니다. 이를 시각화해 나타내면 다음과 같습니다.
위 그림에서 볼 수 있다시피 처음과 중간, 마지막에는 global attention이 적용되고, 그 사이사이에 local attention이 적용됩니다. 위에서는 2개의 local attention 레이어를 하나로 묶어 사용하는 방식을 시각화한 것이며, 따라서 layer1이 owner 혹은 producer로서 kv cache를 만들어 저장하고, 이를 그 다음 레이어와 공유하는 식으로 반복됩니다. 기존에는 11개의 kv cache를 저장해야 했다면, 이제는 global attention 3개와 local attention 4개 총 7개의 kv cache만을 저장하는 식으로 줄어들게 됩니다. 이러한 방법론은 당연하게도 layer가 길어질수록 유리하게 되며, 하나로 묶는 local attention이 많아질수록 연산은 효율적으로 이뤄지게 됩니다(다만, 너무 많은 local attention을 한 번에 묶으면 속도는 빨라져도 성능 자체는 떨어질 수 있습니다). 이를 그림으로 살펴보면 다음과 같습니다.
이에 대한 근거는 인접한 레이어 간의 KV Cache가 거의 유사했다는 기존의 실험 결과에 기반합니다 (본 논문과는 관련성이 떨어지지만 이와 관련해서 최근 여러 모델 간의 소통을 이러한 중간 cache를 통해 정보를 주고 받는 방식에 대한 연구가 꽤 활발하게 이뤄지고 있으며, 효과가 어느정도 입증되기도 했습니다). 개인적으로 추측컨데, 이는 residual connection으로 인한 레이어 간 유사성이라는 특징으로 인한 것으로 보여집니다. 잔차 연결이라는 방법은 딥러닝 모델이 그 이름답게 모델이 깊고, 복잡해질수록 성능이 향상되는 것을 담보하는 역할을 했지만, 반대로 각 레이어들이 이전 레이어와 아주 미세한 차이만을 만들어내도록 학습되는 문제가 발생하기도 했습니다. (이로 인해 2026년 현재에는 이러한 residual connection이 딥러닝 모델들의 성능을 비효율화한다는 주장이 제기되었으며, residual connection을 대체하기 위한 구조가 제안되고 있기도 합니다)
즉, 모델의 깊이와 크기에 비해 그 성능을 온전히 내고 있지 못한 것이기도 한 것입니다. 하지만 그렇다고 잔차 연결이라는 방식을 무턱대고 제거하기엔 작금의 딥러닝 모델 구조에서 오는 이점을 제대로 살리지 못하는 문제가 있어, 아직은 실험적으로 일부 잔차 연결을 제거하거나 학습 레시피 상의 후반 단계에서 일부 연결 제거, 띄엄띄엄 적용하는 등의 실험이 이뤄지고 있는듯 합니다.
여튼 다시 본래의 주제로 돌아와서 본 논문[1]에서는 이러한 점을 고려해 기존의 재귀적으로 수행되는 next token prediction의 과정에 필수적으로 자리잡은 KV Cache를 굳이 모든 레이어마다 저장하지 않고, 인접 레이어들끼리 공유하는 전략을 취함으로써 KV Cache를 만들고, 저장하는 과정을 최적화하였습니다. 이는 아래 표를 통해서 확인할 수 있습니다.
위 표에서 볼 수 있듯이 KV Cache Sharing을 적용함으로써 처리 속도 및 캐시 사이즈는 줄어든 반면, 성능 하락은 뚜렷하지 않습니다(정확히는 어느정도 trade off가 있는 것 같지만, meta token을 통해 무시할 수 있는 수준이 되는듯 합니다).
우선 hymba의 transformer 부분을 먼저 구현해야 합니다. 이때 다음의 사항을 고려해야 합니다.
1. 학습가능한 meta token을 생성하고, prepend
2. global attention과 local attention을 각각 구현하고, 이를 위한 mask 생성3. kv cache sharing을 위한 producer와 consumer 레이어 구분 및 적용
이를 하나의 코드로 살펴보면 다음과 같습니다. 다소 길지만 하나하나 읽다보면 그리 어렵진 않습니다. 참고로 모델을 구현하고 학습시킨 코드 일체는 여기를 통해 확인할 수 있으며, 본 포스팅에 쓰인 코드 및 시각화/정리 등만을 확인하길 원하신다면 해당 경로의 5번 노트북을 확인하시면 됩니다.
사실 공식 코드에선 flex attention 모듈을 이용해 마스크 생성 및 local attention을 최적화하고, 더 간단히 작성해 구현하지만 여기선 직관적인 구현과 검토를 위해 직접 코드를 작성해 구현했습니다. 이렇게 구현한 attention의 마스크 영역을 시각화하면 다음과 같습니다.
참고로 실제 학습시킨 모델의 attention map을 시각화하면 다음과 같습니다.
여기에 이전에 구현했던 mamba 블록 등을 결합해서 모델을 학습시켰고, 학습시킨 모델을 불러와 global attention과 local attention의 attention score 분포를 살펴보면 다음과 같습니다. (모델의 구현과 학습 코드는 여기를 통해 확인할 수 있습니다. Mamba의 구현과 결합 등은 그리 어렵지 않아, 본 포스팅에선 따로 다루지 않겠습니다. - 사실 엄밀히 말해 해당 hymba 모델로 기존의 다른 모델들보다 뚜렷하게 높은 성능을 확인하지 못했는데요, 아무래도 간단한 데이터셋과 학습 레시피로는 충분히 성능 측정을 하기 어려운듯 합니다. 본 포스팅에선 hymba의 특징을 제대로 드러냈는지에 초점을 맞춰 진행하겠습니다)
위 그림에서 주목할 부분은 meta token 이후의 token들입니다. 기본적으로 모든 토큰에 걸쳐 meta token에 대한 할당이 높게 이뤄지고 있으나 이는 기존의 과반에 해당하던 수준보다는 낮습니다(local attention으로 인해 희석). 즉 일반 토큰에 대한 attention 할당이 자연스레 높아지며, 기존의 attention sink 문제를 완화함과 동시에 meta token에 유의미한 정보를 저장하고 활용하는 knowledge capsule화를 통해 단일 모델을 통해 다양한 태스크와 도메인에 대해서도 처리할 수 있게 되는 것입니다.
Hymba는 결국 가장 많이 반복적으로 살펴보게 되는 시퀀스의 앞부분에 위치한 토큰에 대해 무의미하게 많은 attention score가 할당되는 attention sink 문제를 해결하기 위한 것이며, 이는 항상 모든 합이 1이 되는 softmax와 causal mask가 복합적으로 작용한 결과로 보이며, 조금 낭만적으로 표현하면 자주 보는 것에 관심이 가고(more attention), 결국 사랑에 빠져버렸다(attention sink) 정도로 풀어낼 수 있습니다.
이를 해결하기 위해 본 논문[1]에선 meta token을 앞부분에 배치함으로써 이러한 attention sink를 오히려 유용하게 사용할 방법을 제시하였으며, 이 과정에서 발생하는 연산량 증가를 hybrid attention 및 Transformer와 Mamba를 병렬로 연결한 hybrid architecture, kv-cache sharing이라는 방법을 통해 효율화함과 동시에 성능 향상도 노리는 전략을 구현한 것입니다.
다만 직접 구현해본 결과, 이는 일정 규모 이상의 데이터셋과 다양한 태스크, 여러 단계에 걸친 학습 레시피를 통해 논문에서 제시된 성능을 검증해볼 수 있을 것 같습니다.
[1] Xin Dong, et al. “Hymba: A Hybrid-head Architecture for Small Language Models”. https://arxiv.org/abs/2411.13676.
[2] Guangxuan Xiao, et al. "Efficient Streaming Language Models with Attention Sinks". https://arxiv.org/abs/2309.17453.
[3] Paolo Glorioso, et al. "Zamba: A Compact 7B SSM Hybrid Model". https://arxiv.org/abs/2405.16712.