brunch

You can make anything
by writing

C.S.Lewis

by Qscar Feb 28. 2023

[AI]Optimization Technique 03

Distillation

3. Distillation


What is Distillation

Distillation(지식정제)은 과거에는 모델 앙상블을 위해 처음 소개되고 사용됐지만, 최근에는 모델의 구조를 단순화, 경량화, 최적화하는 목적으로 많이 사용되고 있습니다(특히 가장 활발히 연구되고 있는 분야인 대규모 모델 등을 최적화하기 위해 많이 사용됩니다!).


Distillation은 teacher와 student 역할을 하는 두 모델의 상호작용이 핵심입니다. 더 크고 성능이 좋지만, 더 느린 teacher 모델의 동작을 모방할 수 있도록 student 모델을 학습시키는 것입니다. 이 과정을 불필요한 지식정보를 '정제Distillate'한다고 표현합니다.


Distillation Process

그렇다면 Distillation은 어떻게 진행될까요?

2015년 발표된 Geoof Hinton의 "Distilling the Knowledge in a Neural Network"에서는 이를 손글씨 분류 데이터셋으로 유명한 MNIST를 통해 설명하고 있습니다.


MNIST 3D PROJECTION - Image by Author

위 그림은 siam network를 통해 MNIST를 학습시키고, 이를 기반으로 다시 데이터를 해석한 값을 기준으로 3차원으로 투영한 결과입니다. 사후적으로 10개의 레이블을 legend로 하여 표시했습니다. 


참고로 여기서 siam network란 두 개의 head를 가진 모델로, 두 장의 input이 주어졌을 때 두 input이 같은 클래스인지 여부만을 판단하도록 학습하는 네트워크를 의미합니다. 


Siam Network Structure - Image by Kaggle

이를 통해 각 클래스 간의 거리는 최대화하고, 클래스 내의 거리는 최소화하도록 학습하게 되며, 이렇게 학습시킨 backbone에 추후 한 개의 이미지를 입력으로 받는 하나의 head와 여러 차원의 출력층(혹은 backbone의 마지막 출력차원이 적당하다면 이를 이용)을 통해 위 그림과 같은 차원축소가 가능해집니다.


보다 직관적으로 확인할 수 있도록 siam network로 해석한 결과를 tensorboard를 통해 출력하면 아래와 같습니다. siam network로 해석한 결과를 11개의 fully connected layers로 출력시켰고, 이를 텐서보드를 통해 분석한 결과입니다. 11개의 features 중에서 데이터들을 가장 잘 표현할 수 있는(분산을 제일 잘 유지할 수 있도록 하는) 주요 변수 3개로 데이터들을 표현한 결과입니다.


PCA by Tensorboard - Image by Author

위 그림에서 살펴보면 각각의 실제 작성한 데이터들이 3차원 맵에 투영돼 있습니다. 조금 더 자세히 살펴보면 숫자마다 서로 가깝고 먼 데이터가 있음을 알 수 있는데, 위 그림에서는 5와 6, 8이 가깝게 위치하는 것을 확인할 수 있습니다.


이는 표면적으론 저러한 데이터들이 서로 유사하구나라고 받아들일 수 있고, 5라고 적힌 손글씨는 5라는 레이블에 대해선 가장 높은 확률값을 가지지만, 6이나 8이라는 클래스에 대한 작은 확률도 가지겠구나라고 생각할 수도 있습니다.


Dark Knowledge

이를 Geoof Hinton의 "Distilling the Knowledge in a Neural Network"에서는 soft target들에 대한 아주 작은 확률이라고 표현했으며, 이후 다른 강연 등을 통해 Dark Knowledge라는 표현으로 정의했습니다. 이를 위해 지식정제에는 하이퍼파라미터 역할을 하는 T(Temperature, 온도)가 필요해집니다.


일반적으로 딥러닝 모델이 학습하는 과정에서 정답 레이블을 제외한 다른 레이블에 대한 softmax 결과를 0에 수렴하도록 학습되는데, 이 과정에서 Dark Knowledge를 잃기 때문에 softmax 함수에 T값을 이용해 로짓 스케일을 조정해줍니다. 이를 수식으로 살펴보면 다음과 같습니다.


Distillation Softmax with T - Image by author

위와 같이 기존의 softmax에 하이퍼파라미터 T를 적용했을 경우, T가 1일 때는 (당연하게도) 본래 softmax 분포와 동일하며 T값이 커질수록 완만한 확률분포를 생성합니다. 이러한 개념을 DeepMind의 Chris J. Maddison 외 2명의 논문에서는 Gumbel softmax라 정의했고, 이를 발전시켜 보다 명확히 정의된 결과를 그래프로 정리한 Google Brain의 논문에 실린 그래프를 살펴보면 아래와 같습니다.


T값에 따른 Gumbel softmax의 결과 및 학습 후 예측 - Image by Paper[3]

위 그림을 살펴보면 상단의 a)는 모델이 학습한 후 추론한 결과이며, b)는 T값에 따른 Gumbel softmax distribution의 데이터셋으로부터 추출한 샘플입니다. 제일 왼쪽의 Categorical이라고 표시된 것은 일반적으로 우리가 알고 있는 one-hot encoding의 결과이며, 점차 T를 늘려갈수록 전체 분포가 완만(uniform)해지는 것을 확인할 수 있습니다. 


다만 실무적으로 T값이 무작정 높아야 좋은 것은 아니었으며, 데이터셋의 성질에 따라 T가 높아질수록 선형적으로 높아지는 경우도 있었으나, 그렇지 않은 경우도 존재했습니다. T를 높일수록 student 모델은 teacher 모델에 더 큰 영향을 받게 되며, 때문에 teacher 모델의 성능이 이미 충분히 높아 이를 기반으로 조금 더 성능을 개선시키고자 하거나 최적화, 경량화를 위한 목적으로 사용할 수 있습니다.




How to Distillate

Distillation은 보통 fine tuning을 위한 과정에서 많이 사용됩니다. 이미 특정 데이터셋에 학습시켜 일정 수준 이상으로 잘 동작하는 모델이 있고, 이를 정제해 더 나은 모델을 학습시키는 방식입니다(물론 처음부터 진행하는 학습과정에서도 사용할 수 있습니다). 이를 코드를 통해 하나씩 살펴보도록 하겠습니다.


Trainer

Distillation은 복수의 모델을 사용합니다. 모델의 구조가 상이하며, teacher 모델의 경우 이미 학습된 모델인 경우도 있습니다. 일반적으로 student 모델은 teacher 모델과 동일한 종류지만 더 가벼운 모델을 사용하지만, 간혹 다른 모델을 사용하기도 합니다. 이러한 것들을 가능하게 해주는 것이 바로 Trainer인데요, 이러한 기능 또한 transformers, keras, torch 등 다양한 라이브러리에서 지원하고 있습니다. 


Distillation에서 student loss - Image by author

여기서는 transformers 라이브러리와 BERT 모델을 사용하도록 하겠습니다. 지식정제모델을 만들기 위해서는 새로운 하이퍼파라미터 a, T가 추가됩니다. T는 이전에 봤던 softmax를 uniform하게 만들어줌으로써 dark knowledge에 대한 손실을 막아주기 위한 가중치이며 1보다 커야 합니다(1이라면 softmax와 동일하므로), a는 정제 손실의 상대적인 가중치를 제어하는 값입니다. 위 그림에서처럼 a를 통해 student의 크로스 엔트로피 손실과 teacher 모델과의 분포 차이를 최종적인 loss에 반영합니다. 코드는 아래와 같습니다.

Distilation을 위한 Trainer class - Image by author

위 코드를 통해 student의 예측결과와 logits를 계산하고(12~16), 이를 teacher의 logits와 비교해 분포 차이를 측정합니다(18~25). 이때 정규화인자인 T를 제곱한 값(소프트 레이블이 생성한 그레디언트의 크기가 1/T^2으로 스케일 조정되기 때문에)과 쿨백-라이블러 발산(Kullback Leibler Divergence) 값을 이용해, teacher와 student 간 확률분포 차이를 형상화합니다.

이후에는 student의 크로스 엔트로피 로스에 a를 곱하고, 그 여분(1-a)을 확률분포 차이 로스에 곱해 student의 loss를 정의합니다(28~).


또한 위 코드를 온전히 사용하기 위해선 새로운 파라미터인 a와 T를 추가할 수 있는 클래스가 필요합니다. 이에 대한 코드는 아래와 같습니다.

a, T를 추가하기 위한 class 정의 - Image by author


Model Training

Trainer를 정의한 이후에는 적절한 student를 고르는 과정이 진행됩니다. trainer 클래스는 범용적으로 사용할 수 있지만 student는 데이터셋이나 task에 따라 달라지곤 합니다. 여기선 huggingface에 올라와있는 BERT를 teacher로 할 것이기 때문에, 이를 Distillation한 DistilBERT를 student로 해보겠습니다.


1) Load Models(Teacher & Student)

이전의 첫 시간에 사용했던 것과 동일한 BERT 베이스의 의도 분류 모델을 불러오도록 하겠습니다. teacher로 사용할 모델은 BERT, student로 사용할 모델은 이를 distillation한 DistilBERT입니다.

Bert모델과 DistilBert모델의 parameter 비교 코드 - Image by author
출력결과 - Image by author


위 코드의 실행결과에서 확인할 수 있듯, DistilBERT는 BERT 대비 약 40% 가량 적은 파라미터를 지니고 있습니다.


2) Tokenizer 초기화 & Simple Preprocessing & Define Score Function

본격적으로 모델을 불러오고, 학습시킬 준비를 하도록 하겠습니다. 아래 코드를 통해 토크나이저를 초기화하고, 필요없는 text 칼럼을 삭제한 후, Trainer가 자동으로 인식할 수 있도록 기존의 'intent' 칼럼을 'labels' 칼럼으로 바꿔줍니다.

토크나이저 초기화 및 사전처리 작업 - Image by author


이후에는 성능을 측정하기 위한 함수를 정의합니다.

정확도 측정 함수 정의 - Image by author


3) Setting Trainer Args

훈련을 위한 하이퍼 파라미터를 지정하고, 사전에 정의한 trainer class로 전달합니다. 가지고 있는 머신 성능에 따라 batch size를 조정합니다. 충분한 학습이 이뤄질 수 있도록 100epochs 동안 학습시켜보도록 하겠습니다. 학습과 관련된 파라미터들은 다음과 같습니다.

trainer args - Image by author

위 그림에서 alpha가 1에 가까울수록 teacher의 영향으로부터 자유롭습니다. 즉, alpha=1로 설정하면 student가 독립적으로 학습하는 것과 동일합니다. logging_steps를 100으로 지정했기 때문에 100 steps, 여기선 약 21번째 epoch부터 학습 로스가 출력되게 됩니다.

이는 대규모 모델 학습, 전이학습, 지식전이와 같은 학습에서는 초반의 epoch는 데이터 자체를 이해하는 과정이기 때문에 성능이 낮고, 다소 의미가 떨어져 로그를 출력하지 않는 식으로 진행할 수 있습니다. 다만 이 과정에서도 테스트 로스 및 추가적으로 부여한 성능지표(metrics)에 대한 평가는 이뤄지기 때문에 학습과정을 이해하는데에는 충분합니다.


4) configuration

student 모델에 intent와 label id의 맵핑을 부여해야 합니다. 이를 위해 transformers의 AutoConfig 기능을 사용할 수 있습니다. 또한 새로운 하이퍼파라미터가 입력됐을 때마다 새로운 모델을 정상적으로 만들 수 있도록 student를 초기화하는 함수(student_init)를 정의합니다. 이를 통해 train 메서드가 호출될 때마다 새로운 모델이 학습될 것입니다.

student configuration code - Image by author


5) train

모든 준비가 끝났습니다. 우선 의도 분류 task에 사전학습된 student와 teacher를 불러오고, 파라미터를 다시 비교해보겠습니다. 만약 해당 코드를 실행하는 과정에서 huggingface 계정 및 권한과 관련한 에러가 있다면 위 3번의 TrainingArguments의 마지막 'push_to_hub'를 False로 수정하거나, huggingface login을 통해 권한 설정을 해줘야 합니다. 만약 hugginface login을 통해 이후 과정을 진행하고 싶다면, write 권한을 가진 토큰을 등록해줘야 합니다.

distillation code - Image by author


위 과정은 GPU 머신의 성능에 따라 꽤 긴 시간이 소요될 수 있습니다. NVIDIA A100 머신 4개를 이용했을 경우엔 약 15분 정도가 소요됐습니다. 결과 코드는 아래와 같습니다.

a=1일때, 학습결과 - Image by author


위는 a가 1인 경우로, 약 85번째 epoch에서 0.937097로 가장 높은 성능(정확도)을 보였습니다. 로스는 지속적으로 떨어지고 있지만 정확도 측면에서는 오히려 떨어지고 있으며 학습과 평가의 로스값이 커지기 시작하는 등 과적합의 의심되는 정확이 포착되고 있습니다(실제로 200epochs까지 학습시킨 결과 과적합이 확인되었습니다). 만약 단일 모델로 사용한다면 85번째 epoch의 checkpoint를 사용하는 것이 적절해 보입니다.


그렇다면 teacher 모델로부터의 지식정제 효과를 발생시키기 위해 a를 조금 낮춰보도록 하겠습니다. a를 0.7 정도로 낮춘 결과입니다.

a=0.7인 경우, 100epoch 학습 - Image by author


정확도 기준의 성능은 유사하지만 차이라면 아직 성능이 선형적으로 높아지고 있다는 것입니다. 또한 로스값을 기준으로 했을 경우엔, 특히 평가 데이터셋에 대한 성능이 아주 준수한 것을 확인할 수 있습니다. 추가적인 실험결과 약 20 에포크 정도는 더 학습시킬 수 있었고, 최고 정확도는 a가 1인 경우와 유사한 수준이었지만 loss값은 훨씬 낮았습니다.


그렇다면 T를 같이 조정해보면 어떨까요? a를 0.7로 둔 상태에서, T를 기존의 2에서 10으로 높여보겠습니다.

a=0.7, T=10인 경우, 100epoch 학습 - Image by author

확인 결과 로스와 정확도 모두 개선되는 것을 확인할 수 있었습니다!

이처럼 데이터셋이나 task에 따라 최적의 a와 T를 찾아야 하는 경우가 발생하는데요, optuna와 같은 라이브러리를 이용해 찾는 것이 일반적입니다. 이와 관련해서는 추후 기회가 있다면 다뤄보도록 하겠습니다.


마지막으로 a는 0.5, T는 5로 두고 과적합이 확실히 드러날 때까지 200 에포크를 학습시킨 결과를 살펴보겠습니다.

a=0.5, T=5인 경우, 200epoch 학습 - Image by author

중간평가입니다, 이전의 학습들과 비교하기 위해 위의 왼쪽 그림을 살펴보면 100에포크 즈음에는 훨씬 낮은 로스와 0.94를 넘는 성능을 보이고 있는 것을 확인할 수 있습니다. teacher로 사용한 BERT 모델의 성능이 0.94 초반에 머물고 있다는 것을 생각하면 스승보다 나은 제자가 나온 셈입니다. 심지어 위의 오른쪽 그림을 살펴보면 126에포크에서는 0.946774로 거의 0.005정도 높은 성능을 보이고 있는 것을 확인할 수 있습니다.


과적합이 관찰되며 학습 종료 - Image by author

물론 이 이후로는 아쉽게도 추가적인 성능의 향상은 관찰되지 않습니다. loss의 하락은 어느정도 있으나 학습 로스만이 독단적으로 낮아져 평가로스와의 격차가 점점 벌어지고, 정확도 성능도 100에포크 즈음과 유사한 수준입니다.


6) validation

첫 시간에 사용했던 Benchmark 함수를 통해 학습에 사용됐던 train, validation셋 외에 test셋을 통한 성능을 확인해보도록 하겠습니다. 이처럼 인공지능 학습 과정에서 학습용 데이터셋, 학습 시 평가용 데이터셋, 학습 종료 후 최종 평가를 위한 데이터셋을 구분하는 과정은 달리 Holdout이라고도 하며, 모델의 성능을 객관적으로 확인하기 위해 필수적인 과정입니다. Benchmark class입니다.

성능평가를 위한 benchmark class - Image by author

위 클래스는 모델의 사이즈(mb), 평균 추론 시간, 그리고 정확도 스코어를 출력하는 클래스였습니다. 그렇다면 이를 이용해 지식정제의 효과를 최종적으로 점검해보도록 하겠습니다.


우선은 Bert original 모델의 학습 시 사용한 평가 데이터셋에 대한 성능입니다.

teacher model의 validation set에 대한 성능 - Image by author


Bert 모델의 사이즈는 418mb, 평균 추론 속도는 21.8초이며, validation dataset에 대한 성능은 0.9429입니다. 그렇다면 test dataset에 대한 성능은 어떨까요? 위 코드에서 clinc['validation']을 clinc['test']로만 바꿔주면 됩니다.


teacher model의 test set에 대한 성능 - Image by author

test set에 대해선 일반적으로 성능이 낮아지기 마련이며, 이 지표를 통해 모델이 학습 시 과적합됐는지를 평가할 수 있습니다. 그렇다면 우리가 학습시킨 정제모델(student)을 확인해볼까요? 아쉽게도 자동으로 저장된 모델은 마지막 학습 과정의 500번째와 1,000번째 스탭에서 저장된 것으로 최고 성능보다는 살짝 낮은 0.946의 수준이었습니다. 하지만 저것조차 BERT 모델보다는 높았으니 test set에 대해서도 확인해보도록 하겠습니다.


local에 저장된 1,000 step의 student model의 성능 측정 코드 - Image by author
student model의 성능

위 그림에서 확인할 수 있듯, student는 teacher 대비 약 60% 사이즈이며, 속도는 두 배 가량 빠르고, 성능은 오히려 조금 더 높습니다. 이처럼 지식정제가 적용된 student 모델은 일반적으로 teacher 모델과 비슷하거나 더 높은 성능을 보이기도 하는데요, 이렇게 학습시킨 student 모델에 이전 시간들을 통해 진행했던 양자화와 가지치기를 추가로 적용할 경우 모델은 더 가볍고, 더 빠른 모델을 구현할 수 있습니다.


어쩌면 성능도 더 빠를수도 있구요.

다음 포스팅에서는 전체 정리의 느낌으로, 이번 포스팅을 통해 정제시킨 모델을 양자화하고, ONNX라는 라이브러리를 이용해 추가적인 성능 최적화를 하는 과정을 다뤄보도록 하겠습니다.




REFERENCE

[1] G. Hinton, "Distilling the Knowledge in a Neural Network" (https://arxiv.org/abs/1503.02531), (2015).

[2] C. J. Maddison, "The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables" (https://arxiv.org/pdf/1611.00712.pdf), (2016).

[3] E. Jang, "Categorical Reparameterization with Gumbel-Softmax" (https://arxiv.org/pdf/1611.01144.pdf), (2017).

브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari