RNN의 효율적 학습을 위한 고급 접근법

by 송동훈 Hoon Song

딥러닝 강의에서 배운 Recurrent Neural Network의 고급 학습 방법에 대한 인사이트를 정리해보았다. 특히 긴 시퀀스를 효율적으로 학습하는 방법과 상태 유지의 중요성이 인상적이었다.


1. Truncated BPTT의 등장 배경

Screenshot 2025-05-01 at 4.11.43 PM.png

일반적인 Backpropagation Through Time은 시퀀스 길이가 길어질수록 두 가지 심각한 문제가 발생한다. 첫째, 계산량이 너무 많아져 학습 속도가 느려진다. 둘째, 그라디언트가 너무 커지거나(exploding gradient) 너무 작아지는(vanishing gradient) 현상이 발생한다. 이런 문제들을 해결하기 위해 Truncated BPTT가 개발되었다.


2. Weight Sharing의 효율성과 도전. RNN의 핵심 특징은 모든 시간 단계에서 동일한 Weight를 공유한다는 점이다. 이는 학습해야 할 파라미터 수를 크게 줄여준다. 예를 들어, 시퀀스 길이가 1,000이어도 학습 파라미터는 3개의 Weight Matrix(Wxh, Whh, Why)로 제한된다. 이는 학습 가능성을 높여주는 중요한 요소다.


3. 계산량 vs 파라미터 수의 균형. 딥러닝에서는 두 가지 리소스 제약이 있다. 계산량과 학습 파라미터 수다. 계산량이 많아지면 학습이 느려지지만, 파라미터 수가 너무 많아지면 학습 자체가 불가능해질 수 있다. 일반적으로 파라미터 수를 줄이는 것이 더 중요한 최적화 방향이다.


4. Truncated BPTT의 핵심 메커니즘

Screenshot 2025-05-01 at 4.19.59 PM.png

이 방법은 전체 시퀀스를 다 보지 않고, 현재 시점에서 일정 길이(k 단계)만큼만 거슬러 올라가 그라디언트를 계산한다. 이렇게 하면 계산량이 줄어들고, 그라디언트 소실/폭발 문제도 완화된다. k값은 하이퍼파라미터로, 과거 정보를 충분히 활용할 만큼 크되 학습이 어려워지지 않을 정도로 작게 설정해야 한다.


5. 서브시퀀스의 개념과 배치 학습. Truncated BPTT에서는 긴 시퀀스를 고정 길이 k의 서브시퀀스로 나누어 처리한다. 각 서브시퀀스는 독립적인 샘플로 취급되며, 배치 학습 시 여러 서브시퀀스를 모아 처리한다. 중요한 점은 서브시퀀스 내부의 시간적 순서는 유지해야 하지만, 서로 다른 서브시퀀스 간에는 순서가 없어 셔플이 가능하다는 것이다.


6. Stateful 모드의 장점

Screenshot 2025-05-01 at 4.22.22 PM.png

RNN의 첫 타임스텝에서는 Hidden State가 없어 보통 랜덤 값으로 초기화한다. 그러나 Stateful 모드에서는 이전 배치의 마지막 Hidden State를 다음 배치의 첫 Hidden State로 전달한다. 이 방식은 배치 간에도 시간적 연속성을 유지해 더 긴 시간 의존성을 학습할 수 있게 해준다.


7. Stateful 모드 사용 시 주의점. Stateful 모드에서는 배치 간 순서와 샘플 수의 일관성이 중요하다. 각 배치에 동일한 수의 샘플이 포함되어야 하며, i번째 샘플은 다음 배치의 i번째 샘플로 정확히 연결되어야 한다. 따라서 Stateful 모드에서는 데이터 셔플링을 할 수 없다.


8. CNN과 RNN의 Weight 공유 비교. CNN은 공간적 차원에서, RNN은 시간적 차원에서 Weight를 공유한다. CNN에서는 커널 필터를 이미지 전체에 적용하여 파라미터 수를 줄이고, RNN에서는 동일한 Weight를 모든 시간 단계에 적용한다. 둘 다 입력 크기에 관계없이 일정한 수의 파라미터만 학습하는 효율적인 방법이다.


RNN 학습의 핵심은 시간적 의존성과 효율성 사이의 균형을 맞추는 것이다. Truncated BPTT는 계산 효율성을, Stateful 모드는 시간적 연속성을 향상시켜 더 좋은 성능을 얻게 해준다. 이러한 방법들을 적절히 조합하면 긴 시퀀스 데이터도 효과적으로 학습할 수 있다.

keyword
일요일 연재
이전 27화RNN에서 시간을 거슬러 학습하는 방법