논문 톺아보기 15
이번에 톺아볼 논문은 바로 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces'입니다. 거의 전부라고 해도 좋을 정도로, 모델 구조에 대해 다룬 논문은 보다 큰 입출력 사이즈를 다루기 위해 제시됩니다. 이로 인해 Transformer 구조를 기반으로 하는 대다수의 모델들은 어떻게 하면 attention 연산이 가진, 시퀀스 길이에 quadratic하게 비례하여 증가하는 연산량 문제를 해소하고자 합니다. 하지만 종종 다른 관점을 제시하는 논문들도 있습니다. 바로 Transformer의 대안이 되는 구조를 통해 시퀀스 길이에 선형적으로 비례하는 수준의 구조를 제시하는 것입니다.
하지만 이러한 구조들은 지난 번 포스팅의 GRU와 같은 모델을 통해 볼 수 있었듯이, train loss와 val loss 간의 차이가 조기에 벌어지는 과적합 현상이 벌어집니다. 학습시키는 데이터셋이 실존하는 모든 케이스를 포함한 것이 아니라면, attention 연산을 통해 각 토큰 간의 상대적 관계를 파악해 해당 task 자체를 이해함으로써 보여주는 Transformer 기반 모델에 비해 뒤쳐질 수밖에 없는 것입니다.
이를 해결하기 위해 기존의 빠르고 가볍게(컴퓨팅 리소스적으로) 학습되는 RNN 계열의 모델들(RNN, LSTM, GRU 등)에 대한 대안으로 SSM이라는 방법이 제시되었습니다. SSM이란 State Space Model의 약자로 기존의 Recurrent 모델들이 암묵적으로 과거 토큰을 반영하는 과정을 명시적으로 만들어 토큰 간의 거리 제약을 줄이는 전략이었습니다. 하지만 이러한 SSM에도 여러 문제가 존재했고, 여러 단계에 걸쳐 발전해왔으며, 결국 Mamba에 이르렀습니다.
오늘 포스팅은 이러한 과정을 살펴보고, 각 과정에 담긴 개념을 이해한 뒤, 최종적으로 Mamba 모델을 구현해서 지난 포스팅에서 했던 것과 같은 식으로 GRU 및 현대적 트랜스포머 구조라 할 수 있는 Llama 모델과 비교해보겠습니다.
Mamba를 이해하기 위해 먼저 SSM이 무엇이고, 어떻게 발전해왔는지를 살펴보아야 합니다. 우선 RNN은 왜 근본적인 대안이 될 수 없었을까요? RNN의 구조를 그림으로 살펴보면 다음과 같습니다.
위 그림에서 볼 수 있듯, RNN 계열의 모델들은 그 이름(Recurrent Neural Network)처럼 하나의 구조를 재귀적으로 반복사용합니다. 이를 통해 모델의 사이즈가 작고, 빠르게 실행될 수 있지만 다음 토큰을 에측하는 과정에서 현재와 시점이 먼, 그러나 고려되어야 하는 토큰들과의 관계를 온전히 반영하기 어렵습니다. 예컨데 위 그림의 오른쪽에서 t+1 시점의 y를 예측할 때에는 t-1 시점의 x가 t시점의 x보다 온전히 고려되지 않을 것입니다. 비록 Hidden State의 형태로 어느정도 내포되긴 하겠지만, 그조차도 시퀀스의 길이가 길어지면 희석되다못해 사라지곤 합니다. 이를 그림으로 살펴보면 다음과 같습니다.
위 그림에서 마지막의 Maarten 토큰을 예측할 때에는, 처음의 Hello 토큰이 여전히 반영되지 않는 모습을 볼 수 있습니다. 즉, RNN은 일정 길이 이상의 시퀀스를 다루는 문제에는 적합하지 않았습니다.
이러한 문제를 해결하기 위해 나온 것이 SSM입니다. 이는 마치 이전 포스팅의 KV Cache처럼 재귀적으로 모델이 동작하는 가운데 이전 입력들의 정보를 압축해서 가지고 있다가 출력에 반영하는 방식입니다. 이를 수식으로 살펴보면 다음과 같습니다.
위 그림의 수식에서 h는 이전의 입력들이 고려된 latent space를 의미하며, 이를 고려해 아래의 t시점의 y를 예측하게 됨을 의미합니다. 또한 위 수식에서 A는 이전의 잠재 공간(latent space)이 현재의 잠재 공간으로 어떻게 반영되는지를, B는 현재의 입력에 대한 가중치를, C는 현재의 잠재공간이 출력에 어떻게 영향을 미치는지를 의미하는 학습가능한 파라미터입니다(이후엔 출력에 현재 입력을 고려하기 위한 D가 추가로 등장하며, 일종의 residual connection과 같이 feedthrough matrix처럼 작동합니다).
이때 잠재적 공간의 크기를 N, 모델의 차원을 D라고 할 때, 위 수식을 구조적으로 살펴보면 다음과 같이 시각화할 수 있습니다.
위와 같은 구조로 구현된 초기의 SSM은 이를 통해 시퀀스 길이가 길어져도 과거 토큰을 보다 잘 고려할 수 있게 되었지만, 추가적인 문제가 발생했습니다. 다름 아닌 위 과정을 수행하느라 RNN의 장점인 보다 빠르고, 가볍게 실행된다는 장점을 잃어버린 것이었습니다.
간단히 말해, 위 수식을 기반으로 모델을 구현하면 심지어 Transformer 구조보다도 많은 메모리와 긴 실행시간이 필요해지는 문제가 발생했으며, step이 반복되는 과정에서 A나 C와 같은 잠재적 공간을 고려하기 위한 파라미터들이 온전히 학습되지 않으며 그래디언트가 소실되는 문제가 발생하기도 했습니다.
학습 리소스가 더많이 필요해지는 문제는 현실적으로 비용을 높이면 해결되는 문제지만, 아예 그래디언트가 소실되며 학습 자체가 되지 않는 문제는 치명적입니다. 이를 해결하기 위해 이산화(Discretize)시키는 전략이 제안되었는데요. 이를 그림으로 살펴보면 다음과 같습니다.
이러한 이산화를 위해 선결되어야 하는 것은 하나의 시퀀스를 어떤 간격으로 자를 것인지를 결정하는 것이었습니다. 이렇게 결정된 시간간격을 Δ(delta)로 정의하며, 위 수식의 A와 B를 연산할 때 이를 고려하게 됩니다. 이를 통해 본래 이전 상태에 의존적이었기 때문에 순환적으로 연산되어야 했던 문제를, 전체 시퀀스에 대한 출력을 하나의 긴 convolution으로 표현이 가능해집니다. 이를 수식으로 간단히 풀어보면 다음과 같습니다.
우선 간단히 timestep 0에서 2까지를 풀어 상태공간 h와 출력 t를 풀어쓰면 다음과 같습니다.
위 연산 과정에서 우리는 k시점의 y는 현재의 입력들의 가중합으로 표현되며, 이때의 가중치는 시간간격을 고려한 A와 B로 표현할 수 있음을 알 수 있는데요. 정리하면 다음과 같습니다.
즉, k시점의 y를 convolution 연산으로 표현할 수 있음을 확인할 수 있습니다. 이를 수식이 아닌 단순한 그림으로 나타내면 이해가 쉬울 겁니다. 기존의 recurrent하게 처리되는 방식과 convolution으로 처리하게 되는 과정을 그림으로 간단히 나타내면 다음과 같습니다.
위 그림이 설명하는 바는 10개 길이의 시퀀스를 통해 다음 시퀀스를 하나 예측하는 과정을 묘사한 것입니다. 왼쪽의 재귀적 방식을로는 100번째 토큰을 예측하기 위해 이러한 과정을 90회 반복해야 하지만, 오른쪽의 convolution 방식을 통해 한 번에 허용된 사이즈만 크다면 한 번에도 연산해낼 수 있게 되는 것입니다(실제로는 여러 번에 나눠서 연산되며, 위 그림은 그저 이해를 돕기 위함입니다).
이를 보다 단순화한 수식으로 표현하면 다음과 같습니다.
여기까지의 연산을 통해 확인한 것은, 이산화 과정을 통해 더이상 recurrent하지 않게 문제를 풀 수 있게 된다는 것입니다. 이렇게 convolution 연산을 통해 기존에 재귀적으로 하나씩 수행하던 작업을 여러 개를 묶어서 한 번에 처리하는 식으로 수행가능해졌고, 이러한 과정은 재귀적으로 쌓이고 쌓이며 연산되던 gradient가 소실되는 문제를 완화했습니다.
추가적으로 convolution 연산에 적용될 수 있는 고속 푸리에 변환(FFT, Fast Fourier Transform)을 적용함으로써 기존의 O(L x State_Size²) 혹은 O(L x Dim²)의 복잡도를 가지던 연산을 O(L logL) 복잡도로 하드웨어에 따라 거의 선형에 가까운 복잡도로 연산이 가능해진다는 장점이 있습니다.
하지만 이때 정의되는 현재의 입력(u)에 대해 필터링하는 K에는 입력이 고려되지 않는다는 것이 문제인데요. 이는 간단히 말해 입력이 어떤 것이 되든, 해당 입력 시퀀스의 가운데나 양쪽 끝만 보도록 한다는 것입니다. 만약 잘 정제된 이미지 데이터(주요 대상이 반드시 가운데 있고, 해당 사진에 나온 대상이 누군지 분류하는 태스크에 사용되는)라면 이러한 방법이 잘 통할 수도 있으나, 그 외에 복잡한 상관관계를 주고받는 자연어나 신호처리에서는 원하는만큼의 성과를 얻을 수 없었습니다.
또한 convolution으로 처리하는 방식으로 최적화를 했음에도 여전히 막대한 수준의 연산량이 필요했습니다. 여전히 Transformer보다 몇 배는 많으면서도 성능은 부족했습니다. 추가적으로 convolution을 사용하는 방법조차 학습 과정에서는 이미 이후의 시퀀스에 대해 정답을 알고 있으니 상요할 수 있지만, 실제 현실에서 문제를 푸는 과정에서는 사용할 수 없었는데요. 즉, 단순히 학습 과정에서만 사용할 수 있는 반쪽짜리 성과에 불과했습니다.
이전 단계의 SSM이 가지고 있던 문제는 크게 세 가지였습니다. 하나는 입력 시퀀스가 무엇이든 고정된 필터(K)를 적용하게 되는 것이었고, 다른 하나는 convolution으로 대체했더라도 여전히 해결되지 않은 컴퓨팅 리소스 문제였고, 마지막 하나는 convolution 연산을 추론 시에는 사용할 수 없다는 것이었습니다. 새로 제시된 S4는 이 중에서 여전히 지나치게 많이 소모되는 컴퓨팅 자원을 줄이고자 하였습니다.
구체적으로 기존에는 상태 크기(N)을 고려해 A matrix는 NxN 사이즈의 Dense Matrix 형태였습니다. 이는 적절한 초기 시작지점을 찾거나 온전히 학습되는 게 어려워 종종 gradient 소실의 원인이 되곤 했습니다. 이러한 상황에서 시간이 지남에 따라 들어오는 연속적인 신호를 효과적으로 압축해 기억하는 방법에 대한 이론인 HIPPO에 대한 적용이 고려되었고, 이 적용을 통해 과거 정보를 다항식들의 조합을 통해 근사가능함이 확인되었습니다. 또한 이러한 다항식 계수들의 변화를 추적한 결과, 그 변화가 선형 상미분 방정식(Linear ODE)의 형태 h'(t) = A h(t) + B u(t)로 정확하게 표현됨을 발견하였다고 합니다.
그 결과로서 A와 B 행렬을 특정한 구조를 가진 형태로 유도할 수 있으며, 이때 유도된 A 행렬의 경우 대각행렬(Diagonal Matrix) 혹은 대각+저랭크 행렬(Diagonal Plus Low-Rank Matrix)이라는 특별한 형태를 가졌다고 합니다. 그렇다면 이러한 A 행렬을 구조화된 형태로 유도함으로써 얻는 이점은 무엇이 있을까요? 이를 다시 간단히 정리하며 살펴보겠습니다.
우선 S4에서는 '연속적인 신호를 효율적으로 압축하고 기억'할 수 있는 HIPPO 이론에 입각해 기존 Dense Matrix의 형태로 정의되던 A(과거 상태를 현재 상태에 반영하는 레이어)를 Diagonal Matrix로 정의합니다. 이를 수식으로 살펴보면, 기존 이산 시간 SSM에서는 A와 B를 다음과 같이 정의하였습니다.
time step Δ를 고려해 이전 상태 행렬을 반영하는 A와 현재 입력을 고려하는 B 행렬을 구현한 것입니다. 이를 HIPPO 이론에 입각한 대각행렬로 변환할 경우, 아래와 같이 전개될 수 있습니다.
즉, 구조화된 A 행렬을 통해 상태 공간을 업데이트하는 연산량을 대폭 줄일 수 있다는 장점을 가지게 되었으며, 이를 통해 보다 긴 sequence에 대해서도 도전해볼만해졌다는 것입니다.
하지만 그럼에도 우리는 아직 큰 문제를 남겨두고 있습니다. 바로 추론에 사용되기 힘든 convolution 구조와 입력 시퀀스와 무관하게 적용되는 필터 역할을 하는 K에 대한 것입니다. S6 혹은 Mamba로 불리는 이번 포스팅의 대상이 되는 구조에서는 이를 해결하고자, selection mechanism을 통해 입력 토큰의 중요도에 따라 관련 정보를 흘려내거나, 어떻게 업데이트될지 등을 제어하고자 합니다.
예컨데 S4까지의 진화 방식은 이전 정보를 최대한 효율적으로 압축해서 현재의 출력에 고려하기 위한 방법이었습니다. 하지만 이 과정에서 중요성이 높은 토큰과 낮은 토큰을 동일하게 처리하는 한계가 존재했고, 이는 convolution 연산을 통해 시퀀스 길이에 따른 연산 효율성을 담보하더라도 성능적인 측면에서 여전히 상당한 비효율성을 야기했습니다. 때문에 불필요한 정보는 잊고, 더 중요한 정보가 들어오면 기존 정보를 업데이트할 수 있는 선택적인 정보처리 방식이 필요해졌고, 이에 내용이 바로 Mamba인 것입니다.
이를 본 논문에서 제시된 수도 코드를 통해 확인해보면 다음과 같습니다.
위 수도 코드를 통해 확인할 수 있듯이, Mamba에서는 각 매트릭스를 입력에 의존적인 형태로 변화시킵니다. 구체적으로 Mamba에서는 이산 시간 스텝 Δ를 입력 의존적으로 업데이트하며, 이는 이산화된 상태 전이 행렬 Ā(기존 상태 h_{t-1}이 다음 상태 h_t에 미치는 영향을 조절)와 입력 행렬 B(현재 입력 x_t가 상태 h_t에 미치는 영향을 조절)에 영향을 줍니다. 또한, 출력 행렬 C(상태 h_t가 현재 출력 y_t에 미치는 영향 조절) 역시 직접적으로 입력에 의존하여 업데이트됩니다.
결과적으로 기존의 시불변(LTI, Linear Time-Invariant) 시스템이 아닌 선형 시변(LTV, Linear Time-Varying) 시스템이 되게 되며, 이를 통해 각 타임스텝마다 이산화된 Ā와 B, 그리고 C까지도 달라지게 하는 식입니다. 조금 더 구체적으로 입력 토큰 t시점의 입력을 x_t라고 할 때, 이 입력에 의존해 상태 공간 모델의 핵심 파라미터라 할 수 있는 Δ(discretization step size)와 B, C 등을 '동적으로' 계산합니다. 일반적으로 작은 선형 레이어를 통해 타임스텝에 사용할 Δ_t, B_t, C_t를 '실시간으로' 생성하는 방식입니다.
하지만 이러한 LTV 방식을 도입함으로써 각 타임스텝별로 파라미터(Δ_t, B_t, C_t)가 달라져, 컨볼루션 연산의 기본 가정인 시불변성(Time-Invariance)이 깨지므로 더 이상 컨볼루션 연산을 사용할 수 없게 됩니다. 이는 원칙적으로 각 타임스텝의 상태를 순차적으로 계산해야 하는 재귀 방식으로 돌아가는 것을 의미하며, 초기 SSM 모델에서 대두되었던 시퀀스 길이에 따른 연산 비효율성 문제를 다시 야기할 수 있습니다.
이를 해결하기 위해 Mamba에서는 Scan 알고리즘을 사용합니다. Scan 알고리즘은 (a * b) * c = a * (b * c)
와 같은 연관 법칙(associative property)이 성립하는 연산에 대해 계산 순서를 바꿔도 결과가 같은 성질을 이용합니다. 이 성질을 활용하여 상태 공간 모델의 순차적인 상태 전이(state transition) 과정을 효율적으로 계산할 수 있습니다. 구체적으로 하나의 긴 시퀀스를 여러 개의 작은 chunk로 나눠 chunk 내부의 계산을 병렬적으로 수행합니다. 이렇게 각 chunk에서 병렬적으로 계산된 중간 결과들을 다시 효율적으로 결합해 최종 결과를 얻는 식이며, 이는 Parallel Prefix Sum과 유사한 원리를 적용합니다.
이렇게 병렬 스캔으로 시퀀스를 여러 chunk로 잘라 효율적으로 수행하는 과정에서 추가적으로 하드웨어 최적화가 요구됩니다. 병렬 스캔 알고리즘 (고전적인 예로 Blelloch 알고리즘 등이 있음)을 활용하고 GPU 아키텍처(SRAM/HBM)에 맞게 최적화할 경우, 이론적으로 로그 시간 복잡도(병렬 깊이)로 수행될 수 있으며, GPU의 많은 코어를 활용하면 전체 작업량은 시퀀스 길이에 대해 선형 복잡도에 가깝게 수행될 수 있다고 합니다. 사실 위 부분과 관련해 본 논문[1]을 읽고, 이해하고, 구현하는데 참 많은 애를 먹었는데요. 결과적으로 CUDA 및 triton을 이용한 기법과 관련 알고리즘에 대해서는 이해하지는 못했습니다. 만일 이에 대해 궁금하다면 이 포스팅을 읽어보시는 것을 추천합니다.
다시 본래의 주제로 돌아와서, Mamba에서는 선택적 Scan 알고리즘을 통해 LTV라는 동적 시스템을 도입했습니다. LTV 시스템은 컨볼루션 연산을 사용할 수 없지만, Scan 알고리즘 덕분에 효율적인 계산이 가능합니다. 이 Scan 방식은 학습과 추론 모두에 동일하게 적용됩니다. 결과적으로 Mamba는 (LTI 시스템의 한계인) 어떤 시퀀스가 들어오든 동일하게 처리하는 문제를 해결하고 입력 내용에 따라 동적으로 반응하는 Content-Awareness를 확보했습니다. 하지만 이러한 방법에도 단점이 있는데요, 바로 구현이 복잡하다는 것입니다.
예컨데 Mamba를 구현하기 위해 적용되는 Scan 알고리즘은 필연적으로 하드웨어에 따른 최적화를 요구하게 됩니다. 그렇지 않을 경우, LTV의 표현력은 얻을 수 있지만 Transformer 대비 실행 시간 등에서 경쟁 우위를 확보하기 어려울 수 있습니다. 때문에 공식 코드에서는 이를 triton 등을 이용해 해결했습니다. 이러한 점에 대해서도 아래의 코드 구현 부분에서 직접 구현한 Mamba와 triton 구현체를 활용하는 경우를 비교해 살펴보도록 하겠습니다.
위에서 이론적으로 설명한 Mamba의 모델링 부분입니다. 이 부분에 대한 코드는 여기를 통해 확인할 수 있습니다. 우선 Mamba를 구성하는 Seletive SSM 구조를 그림으로 살펴보면 다음과 같습니다.
위 그림을 통해 먼저 직관적으로 알 수 있는 것은 입력 x_t를 통해 Δ_t, B_t, C_t가 업데이트된다는 것입니다. 두 번째로는 이전의 상태공간 h_{t-1}이 구조화된 대각행렬 A가 적용되고, 입력 x_t가 업데이트된 B_t 행렬을 통과해 최종적으로 h_t를 구성한다는 것이죠. 그리고 이렇게 구성된 h_t는 C_t 행렬을 통해 최종적으로 반환됩니다.
그리고 이때 각 행렬에 대한 수식은 다음과 같습니다.
위 수식을 기반으로 하나씩 구현해보겠습니다.
처음으로 구현할 부분은 이산 시간 간격 Δ에 직접적으로 영향을 받지 않는 A와 D를 정의하는 것입니다. 참고로 Mamba는 재귀적으로 작업이 수행되는 과정에서 여러 수치적 불안정성이 발생하고, 이를 제어하기 위한 다양한 방법이 적용되는데요. 구체적으로 각 변수의 최대최소 범위를 clamp로 제어하고, 학습 과정에서도 낮은 학습율과 스케줄러뿐 아니라 gradient clip 등을 통해 안정성을 높이고자 합니다. 이 때문에 낮은 수준의 정밀도로 연산이 진행되는 amp 등으로 인한 불안정성 문제도 심화되곤 하기 때문에, normalization 등의 과정에서는 좀 더 높은 수준의 정밀도(ex. FP32)로 진행해야 합니다.
이때 시스템 안정성을 위해 A는 음수의 범위로 초기화하고, log를 통해 연산한 뒤 추후에 -exp을 적용해 음수의 범위를 유지합니다. 이를 이해하기 위해선 A의 역할(이전 상태공간을 현재 상태 공간으로 업데이트)을 이해해야 합니다. A만을 고려한 관점에서, h(t) = h(0) x At입니다. 이때 A > 0 범위에서 연산되면, 재귀적으로 연산되는 과정에서 기하급수적으로 커져 발산하는 문제가 발생합니다.
하지만 반대로 A < 0 범위에서 연산되면 시간이 지날수록 0에 수렴하며 h(t)가 안정적으로 수행되는 것입니다. 또한 학습 과정에서 A를 음수로 제어하기 위해 log를 이용해 연산한 뒤, 추후에 exp 연산을 수행하면 반드시 양수의 범위로 제한되고 -1을 곱함으로써 음수의 범위로 유지되도록 유도합니다. 또한 D 행렬은 현재의 입력 x에 대해 학습가능한 행렬로 단순합니다. 이를 코드로 살펴보면 다음과 같습니다.
다음으로는 이렇게 구현한 A를 기반으로 B를 구하고, A와 B를 이산화시켜야 합니다. 이를 위해 discretization 메소드를 구현합니다. 이를 위해 참고해야하는 수식은 다음과 같습니다.
위 수식에서 볼 수 있다시피 이산화를 위해선 ΔA를 이용해 Ā(delta_A)와 delta_B를 연산합니다. 이를 코드로 구현하면 다음과 같습니다.
위까지의 과정들을 통해 delta_A와 delta_B, 그리고 D 행렬을 정의했다면, 이제는 이들을 이용해 위의 수식대로 구현할 차례입니다. 바로 코드로 살펴보면 다음과 같습니다.
위 코드에서 주목할 부분은 바로 연속 시간 시스템을 이산화시키기 위한 시간 간격 Δ(delta)입니다. selective mechanism으로 인해 시간간격은 일정하지 않을 수 있지만, 결코 음수로 갈수는 없습니다. 따라서 delta를 정의할 때, softplus 함수를 이용해 정의해주어야 하는 것입니다. 참고로 softplus 함수는 softplus(x) = log(1 + exp(x))와 같이 정의되는 함수입니다.
이렇게 Mamba를 구성하는 SSM(S6)를 구현했다면 전체 구조를 구현해 하나의 Mamba 모델로 구현해야 합니다. 구조에 대한 그림을 살펴보면 아래와 같습니다.
위 그림의 실질적인 구조는 이전 포스팅에서 살펴봤던 현대식 Transformer 구조와 같습니다. GLU 계열의 활성화 함수와 RMSNorm을 사용하고, 잔차합을 사용할 때 덧셈이 아닌 곱셈으로 처리하는 식입니다. 이때 Mamba만의 특이점이라면 SSM 모듈에 넣기 전에 Causal Convolution 층을 통과시킨다는 것인데요, 이는 재귀적으로 수행되는 작업 과정에서 미래 정보 유출을 막기 위해 적용된 방식입니다.
이를 코드로 살펴보면 다음과 같습니다. 디테일한 구조는 이전 포스팅에서와 크게 다르지 않습니다. Attention 대신 SSM이 적용되었고, 그 전에 Causal Conv가 적용된 차이가 있습니다.
이렇게 구현된 Mamba 모델을 학습하며 성능 차이를 확인해보겠습니다. 이전 포스팅에서와 거의 유사하게 next token prediction task를 수행할 것이고, 비교군은 대표적인 RNN 모델인 GRU와 현대적인 Transformer인 Llama, 그리고 직접 구현한 Mamba 모델이며, 모델의 파라미터는 20만 정도로 작게 했습니다. (이전 포스팅에서는 각 모델의 장단점을 뚜렷하게 보이기 위해 70만 이상의 파라미터를 사용했지만, 이번 포스팅에서는 다소 작은 길이와 사이즈의 데이터셋으로 테스트하기 위해 더 작은 모델을 이용함으로써, 긴 시퀀스에 대한 성능을 테스트하는 컨셉입니다)
이 부분에 대한 코드는 크게 네 개로 구성돼있습니다.
첫 번째 코드는 여기에서 확인할 수 있으며, Mamba의 delta_B를 구하는 과정에서 ZOH의 공식 B̄ = (exp(ΔA) - I)A⁻¹B 에서 exp(ΔA)를 테일러 급수의 첫 두 항 I + ΔA로 근사할 때 유도된 간단한 1차 근사 공식을 적용한 것입니다. 이는 간단하게 동작할 수 있으나 delta 값이 클 경우 성능이 하락할 수 있습니다.
두 번째 코드는 여기에서 확인할 수 있으며, 정석적으로 ZOH의 근사공식을 사용한 경우입니다. 세 번째 코드는 여기에서 확인할 수 있으며, 공식 github의 triton 모듈을 이용해 하드웨어 최적화한 학습 코드입니다. 네 번째 코드는 여기에서 확인할 수 있으며, 추론 시의 vram 사용량 및 모델 실행 속도를 측정하기 위한 것으로 batch size를 1로 하고, 아주 작은 데이터셋으로 간단히 결과만 확인할 수 있도록 한 것입니다.
우선 torch를 이용해 구현한 결과입니다. 다만 이때는 triton 등으로 구현하지 않은 형태이기 때문에 병렬 scan이 제대로 구현되지 않아 다소 비효율적이긴 합니다. 전체 코드를 확인하시려면 여기를 통해 확인할 수 있으며, 학습 결과는 다음과 같습니다.
Llama 모델이 근소하게 GRU보다 좋았으며, Mamba은 둘보다 학습과 추론 모두에서 확연하게 뛰어났습니다. 심지어 수천 개 차이긴 해도 Mamba 모델이 가장 적은 파라미터를 가지고 있음에도 말이죠.
다음으로는 triton을 이용해 구현된 하드웨어 최적화된 코드를 구현한 결과입니다. 아쉽지만 이 부분은 직접 구현하지 못했고 공식 repo의 구현체를 사용했습니다. 이에 대한 코드는 여기를 통해 확인할 수 있습니다. (참고로 이때 해당 코드를 구현하기 위해서는 공식 mamba 공식 github를 clone한 뒤, setup.py를 이용해 빌드 후 설치해야 합니다. pip install mamba-ssm 을 이용해 설치할 경우, 온전히 설치되지 않고 서버가 다운되기도 했습니다.)
성능은 torch로만 구현됐을 때와 거의 같습니다. 다만 이때는 말 그대로 하드웨어 최적화된 연산으로 인해 실행 속도에서 차이가 발생했는데요, 이를 표로 정리하면 다음과 같습니다.
위 표를 통해 확인할 수 있다시피 triton을 이용한 최적화를 수행했을 때, Transformer 기반의 모델보다도 훨씬 빠르게 동작했습니다. 다만 학습 시 GPU 메모리(VRAM) 사용량은 여전히 막대한 수준인데요. 이는 SSM의 구성요소인 상태 공간이 입력 데이터의 배치 사이즈에도 영향을 받아 커진 영향이기도 합니다. 위 학습은 배치 사이즈 512인 경우이며, 추론 및 서비스 상태와의 비교를 위해 배치 사이즈가 1인 경우로 테스트한 아래 두 개의 결과를 살펴보면 GRU보다 가볍고, Llama보다 빠르게 수행되는 것을 확인할 수 있습니다.
[1] Albert Gu, Tri Dao. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces".
https://arxiv.org/abs/2312.00752.
[2] Maarten Grootendorst. "A Visual Guide to Mamba and State Space Models". https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state.
[3] Vivian C. "Mamba, SSMs & S4s Explained in 16 Minutes". https://www.youtube.com/watch?v=SUQPeQNy1mE.