brunch

Were RNNs All We Needed?

논문 톺아보기 16

by Qscar

| INTRO |


이번에 톺아볼 논문은 바로 'Were RNNs All We Needed?'입니다. 저는 개인적으로 minRNN으로 부르는 논문인데요, 아쉽게도 ICLR 2025에서는 Reject된 논문이기도 합니다. 그 이유는 기존의 유사한 연구들과 뚜렷한 차이가 확인되지 않고, 몇몇 논문들을 종합해 정리한 것에 가깝다(즉, 참신성이 부족하다)는 것이었습니다. 하지만 그럼에도 불구하고, 해당 논문을 통해 확인되는 기존의 재귀적 모델을 병렬화 가능하도록 구현하고, 이를 통해 보다 적은 파라미터로도 더 효율적으로 학습이 가능하다는 장점은 이 논문을 구현하고, 리뷰하기에 충분한 가치가 있다 판단하였습니다.


또한 기존 RNN의 한계를 극복하기 위해 사용한 병렬 스캔 및 로그 스페이스에서의 연산과 이 과정에서의 수치 불안정성을 해소하기 위해 음수로 연산하는 방법 등이 바로 직전에 포스팅한 Mamba 논문과 같은 논리를 지니고 있어, 연달아 리뷰하기에 좋은 논문이라 판단했습니다. 바로 본론으로 들어가보겠습니다.



| Minimalize RNN |


| #1 전통 RNN의 연산 흐름 |

전통 RNN 셀은 다음과 같이 순차적으로 계산을 수행했습니다.


1. 입력 변환: x_{t}로부터 중간값 계산

2. 게이트 연산: 이전 은닉 상태 h_{t-1}과 결합

3. 은닉 상태 업데이트: 현재 은닉 상태 h_{t} 생성


이러한 과정은 필연적으로 매 타임스탭마다 이전 결과를 기다려야 하므로, 시퀀스 길이 T에 비례해 순차적 종속성이 발생하게 됩니다. 이는 시퀀스 길이가 길어짐에 따라 처리 속도가 늘어날뿐 아니라, 역전파 과정에서 특정 타임스텝의 오차들이 누적돼 악영향을 끼치는 문제를 발생시키게 됩니다. 당연히 병렬화 또한 불가능하므로 병렬 연산에 특화된 GPU를 충분히 활용할 수 없다는 문제 또한 가지고 있습니다.


이러한 문제를 Transformer에선 Attention을 통해 병렬화함으로써 해결했지만, attention 연산 자체가 가진 O(n²) 복잡도가 새로운 문제로 대두되었으며, Mamba에서는 Selective SSM 방법을 이용해 동적으로 이산시간 간격 및 이에 종속적인 파라미터들을 결정하지만 모델 학습 과정에서 사용되는 batch size에 비례해 SSM 사이즈가 커지는 등 메모리 복잡도가 의외의 포인트에서 증가하고 별도로 복잡한 병렬 스캔 알고리즘을 구현해야 하는 등의 어려움이 있었습니다.


| #2 Intro of minRNN |

minRNN은 본 논문에서 공식적으로 제안된 표현은 아닙니다. 본 논문에선 minLSTM 및 minGRU로 명시돼있으나, 본 포스팅에선 이 둘을 통합적으로 묶어서 minRNN으로 지칭하겠습니다.


download.png Abstract of Paper[1]

본 논문에서는 기존 LSTM/GRU 구조에서 종속적인 게이트를 제거하거나 통합함으로써 연산량과 파라미터 수를 대폭 줄이면서 병렬 연산이 가능하도록 함으로써 기존 LSTM/GRU 대비 수십 배에서 수백 배까지도 빠르게 학습시킬 수 있다고 주장합니다. 실제로 시퀀스 길이가 길어짐에 따른 학습 시간과 속도 증가 수준, 그리고 GPU 메모리 사용량에 대한 그래프를 살펴보면 다음과 같습니다.


Comparison with normal and minimal[1]

또한 추가적으로 minRNN의 성능은 Mamba 이전에 제시되었던 S4보다 뛰어나며, Mamba에 견줄만한 수준이고, 강화 학습까지 적용될 경우 Mamba보다도 뛰어나질 수도 있다고 합니다. 뿐만 아니라 일반화 성능 또한 Transformer에 견줄만하다고 하는데요, 아래 그림에서 볼 수 있다시피 (결국 과적합이 발생하지만) Transformer보다 적은 리소스로 성숙(과적합이 일어나기 직전의 최적화 지점)되는 것을 확인할 수 있습니다.

일반화 성능 비교[1]

즉, 여기까지 살펴보면 "minRNN은 기존 RNN보다 가볍고, 수십에서 수백 배까지도 빠르면서, Mamba에 견줄만하고(강화학습이 적용될 경우, 더 뛰어나질 수도 있으며), (과적합이 발생하지만) Transformer와 견줄만한 성능으로 Transformer보다 빠르게 일반화가 가능하다." 정도로 요약할 수 있습니다. (당장 구현해보자! 라는 마음이 들기에 충분한 내용입니다)


| #3 Intro of LSTM(Long Short-Term Memory) |

minRNN 계열을 구현하기에 앞서 기존의 LSTM과 GRU에 대해 간단히 살펴보고 어떻게 이들을 업데이트했는지 알아보며 단계적으로 진행해보도록 하겠습니다. GRU는 LSTM을 단순화한 모델이기 때문에 우선 LSTM 먼저 살펴보도록 하겠습니다. LSTM의 수식은 다음과 같습니다.

LSTM의 수식[1]

LSTM은 기본적으로 네 개의 게이트를 통해 정보를 선택적으로 유지하거나 제거합니다. 구체적으로 위 수식의 첫 번째 f_{t}는 Forget Gate라 하여 잊어야 할 정보를 결정하고, 두 번째 i_{t}는 Input Gate라 하여 새로 추가할 정보를 결정합니다. 세 번째는 Candidate memory cell이라 하고 새로 들어오는 정보를 저장하는 후보를 결정하고, 마지막 네 번째 o_{t}는 Output Gate라 하여 현재 히든 상태를 고려해 출력할 정보를 결정합니다.


참고로 위 수식에서 σ는 sigmoid 함수를 의미하고, ⊙는 요소별 곱셈(element-wise multiplication)을 의미합니다. 그리고 이러한 수식을 코드로 구현하면 다음과 같습니다.

LSTM Cell 코드

위 코드를 보면 네 개의 게이트를 우선 통합적으로 구현합니다. 이는 위 수식에서 네 개 게이트에 공통적으로 포함된 Linear([x_{t}, h_{t-1}])을 한 번에 구한 뒤, 네 개로 나누는 식으로 간단히 구현한 것입니다. 이후에는 각 게이트별로 sigmoid 혹은 tanh를 적용하고, c_t와 h_t를 구해 반환합니다.


이렇게 수식에 기반해 구현한 LSTM Cell을 이용해 LSTM 모델을 구현하면 다음과 같습니다.

LSTM code

혹시 모르니 이렇게 구현한 모델이 잘 구현됐는지 확인해보겠습니다. 가장 간단한 방법은 torch 등에서 공식적으로 제공하는 api 모델을 불러와 가중치를 복사한 다음 직접 구현한 모델과의 출력값을 비교하는 것입니다. 이러한 과정을 통해 모델의 구조가 같으며, 동일한 입력에 대해 동일한 파라미터 구조를 가진 모델의 출력 결과가 같은 수준임을 통해 제대로 구현되었음을 확인할 수 있습니다.


다만 출력 결과가 완전히 같지는 않은데요, 그 이유는 torch에서 제공하는 모델의 경우 최대한의 최적화가 적용되어(CPython 등으로 최적화, 기타 커널 연산 등) 연산 방식이 조금 달라지기 때문입니다. 이러한 이유로 아주 낮은 오차가 발생한다면 정상적으로 구현된 것으로 간주할 수 있습니다. 그 결과는 아래와 같습니다.

직접 구현한 LSTM 체크

오차가 아주 작은 것을 보니 정상적으로 구현된 것을 확인할 수 있습니다. 그렇다면 이를 기반으로 어떻게 Minimalizing을 해 minLSTM을 구현할 수 있는지 확인해보도록 하겠습니다.


| #3 Minimalize LSTM |

LSTM을 Minimalize하는 과정은 세 단계를 통해 구현되는데요, 이를 수식으로 먼저 살펴보면 아래와 같습니다.

minimalizing LSTM[1]

LSTM을 간소화(Minimalizing)하기 위한 첫 번째 단계는 실험적으로 그 기능이 유의미하지 않다고 증명된 출력 게이트 o_{t}를 제거하는 것입니다. 이에 따라 기존에 hidden_size를 네 개의 게이트를 고려해 4배만큼 만들어뒀던 것을 3배만큼만 생성하는 식으로 변경해주면 됩니다. 이에 대한 코드는 다음과 같습니다.


두 번째 단계는 이전 은닉 상태의 의존성을 제거하는 것입니다. 정확히는 수식상에서의 h_{t-1}, 위 코드 상에서 h_prev를 제거하는 것입니다. 이는 모든 게이트에 공통적으로 고려되어 한 번에 연산했던 gates를 구현할 때 적용되어 아래와 같이 최소화됩니다.

마지막으로는 Input Gate와 Forget Gate를 통합하는 것입니다. 구체적으로 Input Gate i_{t}와 Forget Gate f_{t}를 분리 계산 후 비율로 합쳐 f'_{t}와 i'_{t}를 구하게 됩니다. 코드를 통해 살펴보면 다음과 같습니다.


이러한 과정을 통해 기존의 LSTM 수식을 어떻게 간소화할 수 있는지 살펴보았습니다. 요약하면 제대로 사용되지 않는 게이트를 제거하거나 줄이고, 이전 상태 의존성을 제거함으로써 병렬 연산이 가능토록 했다 정도로 요약할 수 있을 것 같습니다. 이러한 과정을 아래의 GRU에 대해서도 동일하게 진행하며 복습해보겠습니다.


| #4 Intro of GRU(Gated Recurrent Unit) |

GRU는 LSTM을 간소화한 버저입니다. 기존 LSTM의 네 개나 되는 복잡한 게이트 구조를 줄여 간결하게 만들면서도 성능은 유사하게 유지합니다.

GRU에 대한 설명[1]

구체적으로 LSTM 모델에서 사용됐던 네 개의 게이트 중 제거해도 큰 지장이 없던 출력 게이트 o_{t}와 셀 상태를 저장하는 c_{t}를 없애고 은닉 상태(h_{t})만으로 시퀀스 정보를 유지합니다. 결국 남는 것은 현재의 입력을 고려해 이전 정보를 얼마나 유지할지 결정하는 Update Gate(z_{t})와 이전 상태를 얼마나 잊을지 결정하는 Reset Gate(r_{t})뿐입니다.


이러한 간소화를 통해 동일한 작업을 수행하는데 적은 파라미터와 연산량을 통해 유사한 성능을 달성한 것이 바로 GRU 모델입니다. 이를 코드로 살펴보면 다음과 같습니다.

GRU 코드

코드의 구현 방식은 LSTM과 동일하게 수식 기반으로 작성했습니다. 이렇게 구현한 코드를 실제 torch 구현체와 비교하면 아래와 같습니다.


torch 구현체와의 비교

torch의 최적화된 연산으로 인해 아주 작은 차이가 있지만 무시할만한 수준으로, 잘 구현된 것을 확인할 수 있습니다.


| #5 Minimalize GRU |

이렇게 구현한 GRU를 마찬가지로 간소화해보겠습니다. 이에 대한 본 논문에서 제시하는 수식은 다음과 같습니다.

how to minimalize GRU

위 수식 변화를 보면 잊어야할 정보를 관리하던 reset gate가 사라지고, 이전 상태에 대한 의존성 h_{t-1}을 제거한 것을 확인할 수 있습니다. 이를 순차적으로 적용해보겠습니다.


우선 reset gate를 제거합니다.

LSTM 때와 동일하게 게이트가 하나 줄었으니 이를 고려해 이외 코드들을 업데이트해주어야 합니다. 이렇게 reset gate를 제거했다면, 이후에는 이전 상태에 대한 의존성을 제거할 차례입니다. 코드로 살펴보면 다음과 같습니다.


이전 상태(h_prev)에 대한 의존성을 제거해, 모두 현재 입력에 대한 연산만으로 수정한 코드입니다.


여기까지의 과정을 통해 수식 기반으로 기존의 LSTM과 GRU 셀들을 간소화(Minimalize)하는 과정을 수행해봤습니다. 그렇다면 이렇게 구현한 셀들의 파라미터 수와 시간을 측정해보면 왼쪽과 같은 결과를 얻을 수 있었습니다.


하나의 작업을 수행하기 위한 셀 단위의 파라미터에 대해 파라미터 수는 약 1/4로 줄었고, 실행속도 또한 1/2로 줄어든 것을 확인할 수 있습니다. 물론 이는 긍정적인 결과이지만, 본 논문에서 제시한 수십~수백 배 이상의 성능 향상과는 다소 거리가 있는 결과인데요, 그 이유는 우리가 아직 해당 논문에서 제시한 병렬 스캔과 관련된 내용을 적용하지 않았기 때문입니다.


때문에 우선 여기선 동일한 작업을 수행하는 셀 단위의 간소화를 이뤄냈고, 이전 상태에 대한 의존성을 제거함으로써 이후에 추가되는 병렬화가 가능하도록 만들었다 정도로 정리하면 될 것 같습니다.


| #6 Parallel Scan |

이전의 수식 간소화를 통해 우리는 파라미터 수를 25%로 줄이고, 실행속도는 두 배로 늘릴 수 있었습니다. 하지만 기존의 알고리즘이 가진 이전 상태에 대한 의존성으로 인해, 병렬화되지 못하고 순차적으로 계산되어야 했던 구조가 여전히 남아있습니다. 즉, 이전까지의 과정을 통해 우리는 병렬화가 가능하도록 구조를 개선한 것이지, 아직 병렬화를 적용한 게 아닌 것입니다.


그렇다면 남은 것은 병렬화 연산을 적용하는 것입니다. 우리가 위에서 구현한 minRNN에서 t시점의 h를 구하는 수식은 다음과 같았습니다.



이를 본 논문에서 제시한 수도 코드를 통해 살펴보면 더 직관적으로 이해라 수 있습니다. 가독성을 위해 minGRU 버전으로 살펴보겠습니다.

Parallel Mode

위 수식을 통해 알 수 있듯, 가장 중요한 변화는 마지막 줄의 h_t를 구하는 부분입니다. 기존에는 이를 순차적으로 하나씩 구현하는 Sequential 한 방법이었다면, 이제는 이를 병렬 스캔으로 구현할 수 있게 되는 것입니다.


이는 아래와 같은 과정을 거쳐 정리가 가능합니다(brunch에선 수식 작성이 어려워, 제가 작성한 jupyter notebook 내용을 그대로 가져왔습니다).


이를 코드로 구현하면 다음과 같습니다.

Parallel Scan Code

이렇게 구현한 병렬 스캔을 적용함으로써 본 논문에서 제시된 실행속도와 같은 효율성을 충분히 달성할 수 있습니다. 이렇게 구현한 병렬 스캔을 적용한 모델의 구조를 코드로 살펴보면 아래와 같습니다.

minGRU for raw space

위 코드들을 통해 알 수 있다시피 시퀀스 길이 T에 대해 누적합용 텐서 B를 계산하기 위해 Z와 H를 미리 계산해둡니다. 하지만 이전 포스팅인 Mamba에서와 같이 위 방식으로 병렬 스캔을 적용할 경우 시퀀스의 길이가 길어지고, 데이터 사이즈가 커짐에 따라 오차가 누적되거나 본래 의도와 다르게 연산되는 등, 이로 인해 수치 불안정성을 야기할 수 있습니다.


위 수식을 기준으로 설명하면 위 수식에서 0~1 사이의 값으로 결정될 At가 반복적으로 곱해짐에 따라 0으로 떨어지거나 밑이 언더플로우하는 문제가 발생할 수 있으며, 이 부분은 Bt에 가중합되며 Bt에도 오버플로우 문제를 야기할 수 있습니다. 이러한 이유로 해당 연산과정을 로그로 변환해 진행하는 것이 안정성 측면에서 낫습니다.


이를 위해 로그 공간에서 연산하기 위한 수식을 다음과 같이 정의할 수 있습니다.

이를 코드로 살펴보면 다음과 같습니다.

mingru for log space

그리고 이러한 코드들에 적용되는 병렬 스캔 코드를 살펴보면 아래와 같습니다.

parallel scan for log space

여기까지 구현한 모델들의 실행속도를 살펴볼 수 있는데요, 그 결과는 아래와 같습니다.

참고로 이러한 성능 향상을 위해선 시퀀스 길이가 충분히 커야 하는데요, 논문에서는 그 기준을 512로 잡고 있습니다. 실제로 시퀀스 길이가 줄어들수록 이러한 속도 향상 효과가 줄어드는 것을 확인할 수 있었습니다.


| #7 Model Structure |

그렇다면 이렇게 구현한 minGRU 셀을 어떻게 wrapping해서 하나의 모델로 만드는 것이 좋을까요?


minGRU 셀 자체를 구현하는 방법은 본 논문에서도 자세히 설명하고 있었기에 큰 어려움은 없었지만, 문제는 이 셀을 활용하는 모델 구조를 어떻게 구현하느냐에 대한 것이었습니다. 이를 위해 크게 세 개의 선택지가 있었는데요, 하나씩 살펴보겠습니다. 참고로 이 부분에 대한 코드는 여기서 확인할 수 있습니다.


첫 번째는 minGRU 셀을 하나의 Transformer Block처럼 적용하는 것이었습니다. 그러니까 이전의 Llama 구조에서 사용한 Transformer Block을 대체하는 식입니다. 이를 구조로 살펴보면 아래와 같습니다. (구조를 간단히 확인할 수 있도록 num_layers=2로 구현했습니다.)

Structure Like Llama

RMSNorm을 이용한 Pre-Norm 전략 및 SwiGLU FFN을 이후에 통과시키는 이중 구조로 하나의 블록을 구현했습니다. 성능은 이후에 세 개의 구조를 모두 살펴본 뒤에 일괄적으로 확인해보고, 다음으로 넘어가겠습니다.


두 번째는 Mamba의 구조를 차용하는 것이었습니다. 바로 살펴보면 다음과 같습니다.

Structure Like Mamba

Mamba의 구조와 동일하게 Pre-Norm > Linear > Causal Conv 1d > [Main Cell] > Linear > Dropout의 순서로 구현한 결과입니다.


마지막으로 세 번째로는 이들을 융합하는 구조인데요. 이 구조를 생각한 이유는 위의 Mamba 구조를 차용한 모델의 성능이 예상과 달리 매우 부족했기 때문입니다. 저는 그 원인을 정규화없이 연속된 레이어들의 연결에 있다고 보았고, 이를 최소로 줄이는 방식을 채택하고자 했습니다. 이를 살펴보면 아래와 같습니다.

Hybrid Structure

위 구조는 기존의 Mamba Block 안에 있던 Casual Conv1d를 초기 단계에서 1회만 수행하고, 블록을 별도로 구성하진 않고 minGRU 셀을 연속으로 반복해 자체적으로 블록이 되도록 구현했습니다. 이렇게 구현한 모델까지 총 세 개 모델들의 성능을 확인하면 다음과 같았습니다.


학습 손실만 따질 경우, Mamba 구졸르 차용한 모델이 여러 조건 하에서 항상 뛰어났습니다. 다만 평가 손실의 경우 위 오른쪽 그림처럼 빠르게 과적합되는 경향을 보이고, 과적합 직전까지의 성능도 다른 모델 미치지 못했습니다.


그리고 Transformer 구조를 차용한 모델에 비해 Hybrid 구조를 차용한 모델이 다양한 조건 하에서도 미세하게 더 성능이 좋았고, 이는 모델의 최고 성능 및 일반화 등 제가 확인한 모든 지표에서 앞서는 결과를 확인할 수 있었습니다. (두 모델 모두 충분히 길게 학습시킬 경우, 둘의 일반화 성능은 유사하긴 했습니다)


위 결과를 기반으로 이후의 다른 모델들(Original GRU, Llama, Mamba)과의 비교 작업은 세 번째 Hybrid 구조를 기반으로 진행하겠습니다.


| #8 Model Comparison |

그렇다면 마지막으로 이렇게 구현한 모델과 torch 내부적으로 구현된 gru 기반 모델 및 현대 Transformer(llama), 그리고 이전에 구현한 Mamba와 비교했을 때 어떤 결과를 얻을 수 있을까요? 시퀀스 길이별로 성능을 확인해보겠습니다. 이에 대한 코드는 여기에서 확인할 수 있습니다.


우선 시퀀스 길이가 64인 경우이며, 모델의 파라미터 수준은 약 5만 개로 맞추었습니다.

T=64 comparsion

왼쪽이 학습 손실, 오른쪽이 평가 손실을 나타내는 곡선입니다. Transformer 모델이 상대적으로 낮은 성능을 보이고 있고, 그 외 3개(GRU, Mamba, minGRU) 모델은 유사한 수준의 성능을 내고 있습니다.


다음으로는 시퀀스 길이 128, 모델 파라미터 수준은 50만 개의 경우입니다.

T=128 comparison

시퀀스 길이가 길어짐에 따라 Transformer 모델은 경쟁력을 갖추기 시작하고, 그 외 모델들은 작든크든 과적합의 경향을 보기에 됩니다. 다만 val_loss 기준으로 살펴보면 최고 성능은 GRU가 차지했고, 학습 로스와 평가 로스를 모두 고려했을 경우 Transformer와 minGRU과 비슷한 수준이었습니다.


마지막으로 시퀀스 길이가 256개인 경우입니다. 모델 파라미터는 90만개 수준입니다.

T=256 Comparison

시퀀스 길이가 길어짐에 따라 GRU와 Mamba는 이전보다 더 빠르게 과적합 경향을 드러냈습니다. 물론 여전히 평가 손실을 기준으로 하면 GRU가 가장 높은 성능을 나타냈고, 학습 로스와 종합적으로 평가하면 초중반부에선 minGRU가, 그 뒤로 학습이 진행될수록 Transformer 점차 우위를 얻는 결과를 확인할 수 있었습니다.


| #9 Model Optimization |

마지막으로 하나 더 수행하고 싶은 것은 바로 모델의 최적화입니다. 이는 이전에 리뷰했던 Mamba와 같이 병렬스캔 및 그 외 가능한 부분들을 trition 모듈로 대체하면 어떨까하는 의문으로 시작한 것입니다. 이를 위해 적용할 수 있는 방법은 크게 두 가지입니다. 하나는 병렬 스캔 등을 Mamba 때와 동일하게 triton을 이용한 모듈로 구현해 적용하는 것이고, 다른 하나는 torch.compile을 이용하는 것입니다. 또한 이 둘을 동시에 적용할 수도 있습니다.


이에 대한 코드는 두 개로 나뉘어있는데요, triton 및 torch compiler를 이용하는 과정에서 커널 간 충돌 등이 발생해 두 개의 코드로 분할했습니다. 기존의 모델과 이를 triton 적용이 가능하도록 dynamic한 부분을 수정한 모델, 그리고 이를 compile한 결과를 먼저 확인해보겠습니다. 이에 대한 코드는 여기에서 확인할 수 있습니다.


기존에 위에서 우리가 정의한 모델은 dynamic한 구조로 인해 triton 등의 적용이 어렵고, 모델 사이즈에 비해 실행속도가 느리다는 단점이 있었는데요. 대신 그만큼 파라미터 수 대비 성능이 높다는 장점이 있었습니다. 하지만 여기선 그러한 장점을 포기하는 대신, 파라미터수를 훨씬 늘려서 컴퓨팅 자원 측면에선 비슷한 수준의 소모를 유지하면서 성능은 그 이상이 나오도록 구조를 수정했고, 이를 torch compiler까지 적용한 결과입니다.


이를 위해 기존에 사용했던 모델의 파라미터보다 두 배 많은 파라미터를 가지도록 했고, 위 결과에서 확인할 수 있다시피 loss 측면의 성능은 충분히 달성한 것을 확인할 수 있습니다. 또한 실행속도를 기준으로 살펴보면, 기존의 5만 개 파라미터를 가진 모델이 에포크 당 약 31초의 실행시간과 65.19mb의 vram 사용이 이뤄졌고, 구조를 개선한 10만 개 파라미터를 가진 모델이 에포크 당 약 19초의 실행시간과 28.35mb의 vram 사용이 이뤄져 파라미터가 두 배로 늘었음에도 실행시간과 vram 소모를 기존 대비 61.29%, 43.49% 수준으로만 사용하였습니다.


또한 여기에 torch.compiler를 추가로 적용할 경우, 에포크 당 실행속도는 약 12초, vram 사용은 3.53mb로 기존 대비 38.71%의 실행시간과 5.41%의 vram 사용만이 이뤄졌습니다. 그렇다면 여기에 triton을 이용한 병렬 스캔을 적용하면 어떻게 될까요? 이에 대한 실험 결과 및 코드는 여기에서 확인할 수 있으며, 그 결과는 다음과 같습니다.


결론을 먼저 말하자면 triton을 이용한 속도 향상과 자원 소모 감소는 가능하나, 수치적 불안정성을 개선하기 위한 추가적인 방법이 필요하다고 정리할 수 있습니다. triton을 적용한 모델은 에포크 당 약 14초의 실행시간과 27.64mb의 vram 사용을 보여, 적용하기 전 대비 73.68%의 실행 시간과 97.50%의 vram 사용량을 보였습니다. 또한 이렇게 triton이 적용된 모델에 torch compiler를 적용할 경우, 첫 에포크를 통해 수행되는 compile 작업이 포함된 에포크 시간이 22.67초에서 19.07초로 16% 감소했고, 이후 에포크 당 실행 속도는 12초에서 10.5초로 12.5% 감소했습니다. 이는 기존의 5만 개 파라미터 모델 대비 에포크 당 실행속도는 33.87%, vram 사용량은 5.65%만이 사용된 결과입니다.


하지만 이런 실험 결과를 확인하기 위해 여러 차례 코드를 수정하며 진행한 결과, triton을 이용해 일부 모듈을 대체할 경우, 이후 torch compiler를 적용할 때 compile에 걸리는 시간을 줄일 수 있고 추가적인 실행속도 개선을 얻을 순 있지만, 수치 불안정성으로 인한 성능 문제 및 커널 이슈가 종종 발생해 특수한 경우가 아니라면 torch compiler만을 적용하는 정도에서 만족하는게 대부분의 경우 최선의 선택지가 되지 않을까 싶습니다.



|Reference

[1] Leo Feng, et al. “Were RNNs All We Needed?”. https://arxiv.org/html/2410.01201v1.

[2] Rick Fritschek, Rafael F. Schaefer. "MinGRU-Based Encoder for Turbo Autoencoder Frameworks". https://arxiv.org/abs/2503.08451.




keyword
작가의 이전글Mamba