brunch

You can make anything
by writing

C.S.Lewis

by Qscar Jan 02. 2024

Swin Transformer V1

Paper Review 4 : 모델링, 학습, 튜닝, 평가

|Intro

이번에 리뷰할 논문은 'Swin Transformer: Hierarchical Vision Transformer using Shifted Windows'입니다. 이 논문은 읽을수록 저 이름을 아주 잘 지었다는 생각이 들 정도로 저 내용이 핵심입니다. 기존에 ViT의 문제점을 개선하고자 계층적인(Hierarchical) 구조를 윈도우를 움직여 가며(Shifted Window) 구현한 것입니다. 이를 통해 ViT 대비 계산 효율성과 부족한 데이터로도 더 좋은 효율성을 달성했고, 다양한 형태의 입력과 단순 분류 외 여타 태스크에도 활용할 수 있는 구조의 모델입니다.


또한 Swin Transformer는 V1, V2로 나뉘어 제시되는데요, 본 포스팅에선 V1에 대해서 리뷰를 진행한 후 이후의 포스팅에서 V2를 리뷰하도록 하겠습니다.



|Paper Summary

Swin Transformer의 구조



Swin Transformer의 핵심은 계층적인 구조라는 것입니다. 한 장의 이미지를 일정한 사이즈의 패치로 나눠, 패치 간의 어텐션을 계산하던 것이 이전에 리뷰한 ViT였습니다. 이는 이미지를 고정된 사이즈의 패치로 나눠야 하기 때문에, 이미지 사이즈가 달라지면 그에 따라 입력층, Patch Embedding, Positional Embedding과 같은 층이 통째로 교체되어야 하고, 이 과정에서 사전학습된 가중치까지 같이 사라지며 큰 데이터셋에 사전학습하며 성숙시킨 성능을 함께 잃어야 했습니다. 


또한 너무 작은 사이즈에서는 고유의 강점을 보이기가 어렵지만 반대로 너무 큰 사이즈의 이미지를 입력으로 주면, 어텐션 계산을 위한 패치가 너무 많아 비효율성이 발생하게 됩니다. 한 장의 이미지에서 분할된 패치의 개수가 n이라고 하면, ViT의 어텐션 연산 복잡도는 O(n^2)이 되기 때문입니다.


이를 개선하고자 Swin에서는 한 장의 이미지를 여러 개의 윈도우로 구분하고, 그 윈도우 내의 패치 간의 어텐션만 계산하도록 하였습니다. 이는 이미지의 사이즈가 달라져도 윈도우의 개수만 달라질뿐, 윈도우 내의 패치 수는 동일하기 때문에 다양한 사이즈의 이미지를 처리할 수 있습니다.

ViT와 Swin의 MSA 연산복잡도

이에 대해서 본 논문에서 제시한 연산 복잡도(computational complexity)는 위와 같습니다. 왼쪽의 붉은색으로 지운 복잡도선형 변환을 계산하는데 필요한 연산 복잡도로, 입력과 가중치 매트릭스를 곱하는 복잡도이며, 이 부분은 ViT와 Swin가 동일합니다(코드적으로 설명하면 nn.Linear로 q, k, v를 각각 projection한 후 최종적으로 한 번  더 Linear Projection해서 총 4번의 Linear Projection이 이뤄지는 것에 대한 계산 복잡도 입니다). 


주목할 부분은 오른쪽의 복잡도 부분입니다. 이 부분은 query와 key의 내적 및 value와의 내적을 의미하는 부분입니다. 위의 ViT에서는 시퀀스 길이(이미지의 사이즈를 패치로 나눈 값)의 제곱에 비례하고, 아래의 Swin에서는 시퀀스 길이의 사이즈에 비례하지만 그보다 작은 윈도우의 사이즈(M)의 제곱에 비례하는 식입니다. 또한 여기서는 간단하게 hw, M, C와 같이 동일한 값을  나타내는 듯한 기호로 작성돼 있지만 이 부분에서도 꽤 큰 차이가 납니다.


위 연산복잡도에 대한 공식 중 MSA와 관련된 사항은 여기를 참조하시고, W-MSA와 관련된 사항은 다음과 같습니다. 참고로 두 연산복잡도 계산에서 softmax에 대한 고려는 배제되었습니다.

W-MHA의 연산복잡도 설명


구체적으로 각 값을 대입해보면 ViT에서 h와 w는 이미지를 패치로 나눈 시퀀스의 길이로 h=w=14가 됩니다. 채널은 768로 12개의 트랜스포머 블록을 통해 반복되었습니다. 이를 수식적으로 계산하면 아래와 같습니다. 

12 x (4x196x768x768 + 76,832x768) = 12 x (462,422,016 + 59,006,976) = 6,257,147,904

반면 swin은 계층적인 구조로 이미지의 공간적 크기를 줄이고 채널을 늘려가며 진행하게 되는데요, 이에 대한 복잡도를 수식으로 표현하면 다음과 같습니다.

stage 1 (공간차원 56x56→hw: 총 패치 수=3,136, 2회 반복, dim=96)
: 2 x ( (4x3,136x96x96) + 98x3,136x96) = 2 x ( 115,605,504 + 29,503,488) = 290,217,984

stage 2 (공간차원 28x28 → hw=784, 2회 반복, dim=192)
: 2 x ( (4x784x192x192) + 98x784x192) = 2 x ( 115,605,504 + 14,751,744) = 260,714,496

stage 3 (공간차원 14x14 → hw=196, 6회 반복, dim=384)
: 6 x ( (4x196x384x384) + 98x196x384) = 6 x ( 115,605,504 + 7,375,872) = 737,888,256

stage 4 (공간차원 7x7 → hw=49, 2회 반복, dim=768)
: 2 x ( (4x49x768x768) + 98x49x768) = 2 x ( 115,605,504 + 3,687,936 ) = 238,586,880

Total  = 1,527,407,616


이렇게 계산한 두 모델의 MSA 연산량 차이는 약 4.01배에 이르게 됩니다.

참고로 이러한 복잡도의 차이는 실제로는 더 커지는데요, 여기선 복잡도 계산 규칙에 의해 ViT에서만 적용되는 class token으로 인한 차원수 증가를 반영하지 않았기 때문입니다.


연산 복잡도에 대한 설명

이에 대해서 본 논문에서는 기존 ViT는 이미지 사이즈에 대해 연산이 quadratic(제곱적)하지만, Swin은 linear(선형적)이며, Local Window를 통해 지역적 특성도 잘 반영한다라고 합니다. Swin이 Small Window를 통해 계산적 효율성과 지역적 특성을 반영할 있는 구조가 됨으로써 적은 컴퓨팅 리소스와 적은 데이터로 모델을 학습시킬 있다는 이 얼추 이해가 되었다면 조금 더 자세하게 들어가보도록 하겠습니다.


architecture of Swin Transformer

Swin의 구조를 살펴보면 위와 같습니다. 위 그림의 오른쪽 (b)를 먼저 살펴보면 이미 친숙한 LN과 MLP, 그리고 Residual Connection이 진행되고 있음을 확인할 수 있으며, 새로 등장한 개념인 W-MSASW-MSA가 있습니다. W-MSA(Window based Multihead Self Attention)은 이미지를 고정 크기(M)의 윈도우로 분할하고, 윈도우 내의 패치에 대해 MSA를 진행하는 것입니다.


 SW-MSA(Shifted Window based Multihead Self Attention)W-MSA와 유사하지만 윈도우의 위치를 조금씩 바꿔가며 MSA를 진행합니다. 구체적으로 W-MSA에서 하나의 윈도우의 경계 부분에 있던 픽셀들은 다음의 SW-MSA를 진행하며 다른 윈도우의 중심 부분에 위치하게 됩니다. 이를 통해 윈도우 경계를 넘는 객체나 패턴에 대한 인식이 가능해지며, 이는 각 윈도우가 독립적으로 연산&처리되며 고립되는 것이 아니라, 각 윈도우 간의 정보를 교환하고 관계성을 구축하는 효과를 만들어 내게 됩니다. SW-MSA를 통해 고정된 형태와 위치에 큰 영향을 받던 ViT의 한계를 극복하고, 지역적 특성을 반영하게 됩니다. 조금 더 쉽게 말하자면 이미지를 더 잘 이해하게 되는 것입니다. 


cyclic shift와 relative position bias에 대한 설명 그림

이들에 더해 본 논문에서는 순환 이동(Cyclic Shift)상대적 위치편향(Relative Position Bias)이 모델의 효율성과 성능을 증가시켰다고 합니다. 순환 이동(Cyclic Shift)은 이미지의 가장자리를 위 그림에서와 같이 반대쪽 이미지에 합치는 방법(본 논문에서 제시된 워딩은 'cyclic-shift toward the top-left direction'으로 오른쪽or아래에 있는 가장자리 영역을 왼쪽or위로 이동시키는 것입니다)으로 단순한 패딩으로 처리하던 기존 방식보다 효율적으로 연산하고, 이미지의 문맥 이해와 정보의 연속성 유지에 기여했습니다. 이에 대해선 아래의 Cyclic Shift를 구현하는 과정에서 자세하게 다루도록 하겠습니다.


상대적 위치편향은 윈도우 내에서 각 토큰들 사이의 상대적 거리에 따라 편향값을 부여하고, 이 학습가능한 편향값을 어텐션 계산에 추가하여, 데이터의 문맥과 패턴에 최적화되도록 하는 것입니다. 이는 기존에 Positional Encoding 혹은 Positional Embedding을 통해 패치간 절대적인 위치를 기준으로 파악하던 방식을 대신 이미지의 상대적 위치에 대한 이해를 돕고, 인접 영역과의 상관관계를 효과적으로 포착합니다. 결과적으로 모델이 입력 이미지의 사이즈 변화에 유연하면서도 성능이 향상되는 결과를 확인할 수 있었다고 합니다.


|Modeling

Swin Transformer의 개념 자체는 복잡하지 않은 모델이지만, 여러 레이어들이 교차적으로 들어가거나 순차적으로 들어가야하고 익숙치 않은 개념이 적용되기 때문에 구현을 위해선 명확히 짚고 넘어가야 합니다. 우선 각각의 레이어를 구축하기 전에 사전 구현된 모델을 통해 전체 구조를 살펴보고 넘어가보겠습니다. 이 부분에 대한 전체 코드는 제 github를 참조하시면 됩니다.


|Simple Structure

Swin Transformer 구조

위 코드는 Transformer library의 사전구현체를 불러와 해당 모델의 메인 모듈부터 하위 모듈을 순차적으로 가시화하는 코드입니다. 오른쪽의 출력된 결과를 살펴보면, Swin Transformer는 크게 Swin과 Classifier로 구분됩니다. Classifier는 불러온 모델이 이미지 분류용이기 때문에 붙은 것이고, 태스크별로 다른 메인 모듈이 붙습니다. Swin은 위 그림을 통해 보면 크게 임베딩과 인코더, 레이어 정규화, 풀러라는 이름의 하위 모듈로 구성돼 있는 것을 확인할 수 있습니다. 


Swin Structure

이전에 한 번 살펴봤던 Swin의 구조를 살펴보면 위와 같은데요. 크게 네 단계(Stage)에 걸쳐 진행되며, 각 단계를 통해 채널 수를 늘려가며 반대로 해상도는 낮춰가며(Down Sampling) 진행하게 됩니다. 첫 번째 단계에서는 선형 임베딩 이후 Swin Transformer Block을 지나는 과정을 두 번 거치고, 이후부턴 패치 병합(Patch Merging)이라는 과정을 거치며 Down Sampling되고, Swin Transformer 블록을 지나는 과정을 반복하게 됩니다.


다만 위 그림에서 다소 애매한 부분이 있습니다. 일반적으로 패치 분할(Patch Partition)은 단순히 이미지를 지정된 사이즈의 패치로 나누는 것을 의미합니다. 이는 단순히 이미지를 나누는 것에 불과하며 그 어떤 학습가능한 가중치가 있는 것이 아닙니다. 다만 이후 Linear Projection하는 것을 고려하면 이 과정은 이미지를 단순히 분할하는 게 아닌 패치 임베딩(Patch Embedding)을 의미하는 것입니다. 또한 패치 분할과 패치 병합은 Cyclic Shift가 구현되기 위한 전후처리 과정입니다. 때문에 이를 보다 사실적으로 표현하면 Swin의 구조는 아래와 같습니다.

Swin Structure

패치라는 용어가 너무 반복적으로 등장해 패치 파티션과 패치 리버스를 동일한 의미로 사용되는 윈도우 파티션과 윈도우 리버스로 대체하였습니다. 위 그림에 기초해 프로세스를 설명하면 아래와 같습니다.


[Swin Transformer Process]

1. Pre Stage : 이미지를 임베딩해 스테이지 안으로 전달
2. Resize : 임베딩 결과를 이미지 사이즈로 다시 리사이즈
3. Window Partition: 윈도우 사이즈로 분할
4. W-MHA : Window based MHA 진행
5. Window Reverse : 분할된 이미지를 하나의 이미지로 다시 합치기
6. Cyclic Shift + Window Partition : 윈도우의 위치를 조정하고, 그에 맞게 윈도우 분할
7. SW-MHA : Cyclic Shift 진행 후 W-MHA
8. Post Stage : 2n회의 반복이 끝나면 Patch Merging을 하며 다음 스테이지로


본 논문을 구현하면서 가장 헷갈렸던 부분이기도 했는데요, 위 프로세스를 구성하는 핵심은 cyclic shift입니다. 아래 그림의 제일 왼쪽을 보면 cyclic shift에서 원본 이미지는 총 9개의 조각으로 나뉘게 됩니다. 하지만 작은 좌상단의 윈도우들을 우하단으로 이동시킴으로써 9개의 조각은 4개의 윈도우만으로 처리할 수 있어 연산효율성을 얻을 수 있습니다. 

cyclic shift structure

이러한 효과를 누리기 위해 각 스테이지는 쉬프트되지 않은 이미지를 분할하고 W-MSA를 진행한 뒤, SW-MSA를 위해 이미지를 다시 합쳐놓을 필요성이 있습니다. 이를 위해서 각 스테이지에서는 Window Partition→(Cyclic Shift)→Window Reverse 과정이 반복적으로 발생하는 것입니다. 또한 각 스테이지는 W-MSA 이후 cyclic shift를 통해 SW-MSA가 진행되기 때문에, 각 transformer block은 2의 배수로 적용됩니다.


|Embeddings

embedding layer에 대한 설명

embeddings의 구현은 기존 ViT와 크게 다르지 않습니다. 우선 위의 Simple Structure에서 봤다시피 Embeddings에는 이미지를 여러 개의 패치로 분할하는 patch_embedding과 Layer Normalization, 그리고 dropout이 적용됩니다. 내부 파라미터로는 윈도우 사이즈(default:7)보다 작은 패치 사이즈(default:4)와 이를 고려한 더 적은 모델 차원수(embed_dim:default=96)를 가지고 있습니다. 이때 ViT와의 차이라면 크게 두 가지가 있을텐데요, 하나는 ViT보다 많은 패치 수와 현격히 적은 채널 수입니다. 


Swin Variants Structure

ViT의 경우, 가장 작은 모델이었던 ViT-Base가 224 사이즈의 이미지를 대상으로 16 사이즈의 패치로 분할했을 경우, 196개의 패치와 768개의 embed dim을 가졌었습니다. 반면 Swin의 경우, 가장 작은 모델인 Swin-Tiny가 96개의 채널(embed_dim)과 3,136개의 패치를 가지게 됩니다.


ViT와 Swin의 다른 점 또 한 가지는 입력 이미지의 사이즈가 고정됐던 ViT와 달리, Swin은 Window를 이용한 구조를 통해 입력되는 이미지 사이즈에 유연한 구조여야 한다는 것입니다. 이를 고려해 코드를 작성해보도록 하겠습니다. 

embeddings code

embeddings는 크게 patch embedding, layer normalization, dropout의 세 개 레이어들로 구성돼 있고, img_size는 정수형 혹은 튜플 타입으로 받을 수 있으며, 정수형으로 받을 경우, 이를 2차원 튜플로 확장하는 코드입니다.


|Swin Transformer Block

Swin의 핵심이 되는 Swin Transformer Block은 아래의 수식과 같이 구현됩니다.

Transformer block 수식

크게 네 단계로 구성되는 swin transformer block은 매 단계마다 Layer Normalization과 Residual Connection이 진행되고, 그 외에는 W-MSA→MLP→SW-MSA→MLP의 순으로 진행되게 됩니다. 이를 간단히 요약하면, 1)윈도우 내의 패치에 대해서 상관관계 분석 후 2)정리, 3)윈도우 위치 바꿔서 다시 상관관계 분석 후 4)정리 정도의 프로세스입니다.


다만 이 블록에는 위의 수식 외에도 추가적으로 언급된 개념들이 함께 포함되어야 하는데요, 그 대표적인 것들이 바로 순환 이동(Cyclic Shift)와 상대적 위치 편향(Relative Position Bias), 그리고 패치 병합을 통한 다운 샘플링(Down Sampling)입니다. 


|순환 이동(Cyclic Shift)

이전의 전체 구조를 설명하면서 언급한 바와 같이 순환이동은 W-MSA가 진행되고, 윈도우로 분할된 이미지를 다시 되돌린 후 진행됩니다. 이를 간단히 생성한 예시 이미지로 살펴보면 아래와 같습니다.


Cyclic Shift의 결


이렇게 cyclic shift가 적용된 이미지에 window partition을 적용하면 아래와 같습니다.

Window Partition의 결과

왼쪽은 Cyclic Shift가 적용되지 않은 W-MSA가 적용된 것이며, 오른쪽은 Cyclic Shift가 적용되어 SW-MSA가 적용된 것으로 이해할 수 있습니다. 이러한 이유로 우리는 아래에서 하나의 W-MSA 클래스를 통해 W-MSA와 SW-MSA를 구현할 것입니다.


|상대적 위치 편향(Relative Position Bias)

상대적 위치편향은 토큰 간의 상대적인 거리를 고려할 수 있도록 함으로써 패치 분할 혹은 윈도우 분할로 제대로 고려되지 않는 영역정보를 고려하기 위한 방법입니다. 이를 위해 패치 간 상대적인 위치 정보를 x축과 y축을 기준으로 위치행렬을 생성해 두 개의 좌표를 통해 상대적인 좌표를 구현하는 식입니다. 


본 논문에서는 상대적 위치편향을 구현하는 과정에서 다음과 같이 M^2 사이즈의 상대적 위치 편향 B를 (2M-1)x(2M-1) 사이즈의 B hat으로부터 추출한다고 되어 있습니다. 참고로 2M-1이라는 단위는 상대적 위치편향의 범위(-M+1, M-1)을 고려한 최대 범위입니다.

이를 위해 각 패치간 위치를 인코딩할 수 있는 방법을 제시합니다. 구체적으로 상대적 위치편향을 고려할 수 있는 위치행렬을 생성한 뒤, 윈도우 사이즈를 고려해 보정하고, 종합적으로 고려할 수 있는 단일 행렬을 생성합니다. 이 단일 행렬은 각 패치의 상대적 위치정보가 담겨있으며, attention 계산 시 함께 고려되게 됩니다. 이를 하나씩 살펴보면 아래와 같습니다.


STEP 01. 이미지의 X, Y축에 대한 위치행렬 계산

전체적인 맥락 이해가 쉽도록 window size는 2로 작게 설정하였습니다. 이렇게 할 경우, 각 윈도우는 4개의 패치로 구성됩니다. 이를 왼쪽 상단에서부터 우하단의 순서로 flatten했다고 가정하고, 각 패치 간의 가로축과 세로축 기준 위치를 단순 인코딩한 것입니다. 

Image를 flatten

이해를 돕기 위해 위 그림과 인코딩된 값을 살펴보며 설명하겠습니다. 위치행렬 계산 결과 중 첫 번째 행렬, 그 중에서도 첫행을 살펴보면 그  값이 [0, 0, -1, -1]로 되어 있는데요. 이는 1번 패치와 1, 2, 3, 4번 패치들 간의 X축에 대한 위치를 표현한 것입니다. 이를 풀어서 쓰면 [1번 패치와 같은 행에 있음, 1번 패치와 같은 행에 있음, 1번 패치보다 아래 행에 있음, 1번 패치보다 아래 행에 있음] 정도로 이해할 수 있습니다. 이런 식으로 각 패치 간의 관계를 가로세로축을 기준으로 인코딩하는 것이 첫 번째 단계입니다.


STEP 02. Window Size를 고려한 위치행렬 보정

윈도우 사이즈를 고려해 보정된

두 번째 단계는 가로축과 세로축의 값을 두 개의 행렬이 아닌 단일 행렬로 나타내기 위해 범위를 조정하고(+Window Size-1), 가로축을 표현하는 값을 스케일링 해줌으로써 단일 행렬만으로 모든 값을 표현할 수 있도록 하는 것입니다. 위 과정을 통해 윈도우 내 모든 범위(거리와 방향)를 표현할 수 있게 됩니다.


STEP 03. 종합적으로 고려된 위치행렬 계산

최종적으로 생성된 상대적 위치행렬

위 그림들을 통해 생성된 결과물을 어떤 식으로 상대적 위치편향을 위한 인덱스 테이블이 최종적으로 생성되었는지 알 수 있습니다. 이는 마치 Transformer를 처음으로 구현할 때, sin cos 주기함수를 통해 구현하는 것과 비슷한 원리이나 절대적인 위치정보를 제시하느냐, 상대적인 위치정보를 제시하느냐의 차이가 있습니다.


참고로 이렇게 생성한 위치행렬을 해석하면 다음과 같습니다.

위의 3단계는 이러한 규칙의 Bias Index를 구하기 위한 공식과도 같은 것이고, 실제 모델링을 위해선 이렇게 구한 Bias Index를 학습가능한 파라미터로 구성된 Bias Table에 곱해 일정한 패턴을 띈 상대적 위치 테이블을 구성해 사용하도록 합니다.


상대적 위치편향은 위 수식과 같이 어텐션 스코어를 계산하는 과정에서 추가적인 정보를 제공하는 역할을 하게 됩니다. 또한 이를 통해 위치정보를 전달했기 때문에 이전의 트랜스포머 구조와 달리 위치 인코딩이나 위치 임베딩을 사용하지 않습니다. 또한 이미지 사이즈가 커지거나, 패치 사이즈가 줄어 패치 수가 늘어나면 이전에 진행했던 방법으로 상대적 위치정보를 전달하는 위치행렬 내부의 값 일부가 커져 제대로 연산되지 않는 문제가 생길 수 있습니다. 


위 과정까지 진행했다면 남은 것은 각 head의 수와 표현가능한 모든 범위에 걸쳐 인코딩된 값을 투영하는 것입니다. 이를 위해 학습가능한 파라미터 Relative Position Bias Table을 만들고, 여기에 기존에 만들어둔 Relative Position Bias를 통해 각 head 별로 상대적 위치 편향을 이해하면서도 학습가능하도록 해 최적화할 수 있도록 합니다.


STEP 04. 위치행렬 인덱스를 위치행렬 테이블에 적용

상대적 위치 인덱스를 조정하기 위한 편향 테이블

위치행렬 인덱스를 테이블에 적용하는 코드와 그 결과는 위와 같습니다. 이를 통해 우리는 토큰의 상대적 위치에 다른 어텐션 매커니즘 조정에 긍정적인 역할을 기대할 수 있습니다. 위 계산식은 본 논문의 아래 이미지 강조된 부분을 참고하여 작성하었습니다.


relative position bias matrix(table)에 대한 설명


|W-MSA & SW-MSA

남은 것은 위에서 구현한 코드들을 잘 합쳐 하나의 클래스로 구현하는 것입니다. 이에 대한 구조도를 그리면 아래와 같습니다.

Window Attention Structure

위 구조도에서 다루지 않은 부분은 Mask입니다. 여기서 마스크가 사용되는 이유는 윈도우의 경계를 정확하게 처리하기 위함입니다. 이는 Cyclic shift의 결과로 인해 하나의 윈도우 내에 본래라면 같이 붙어있을 수 없는 이미지 쌍들이 존재하게 되기 때문에 적용된 것입니다. 이때 마스크를 적용하지 않으면 본래라면 붙어있지 않았을 픽셀/패치 정보 교환이 일어날 있습니다. 이는 Swin의 W-MSA 기능을 저해하게 되기 때문에, 본래 붙어있는 윈도우끼리연산이 일어나도록 처리하는 것입니다.


이에 대한 코드는 아래와 같습니다.

Window Attention Code

코드는 좀 길지만 복잡하진 않습니다. 하나씩 살펴보면, 이미지에 대한 q, k, v를 생성하고 스케일링을 적용하고, 상대적 위치 편향을 고려할 준비를 합니다. 그 이후에는 전달받은 mask가 있다면 적용(SW-MSA인 경우)하고, 그 외에는 그대로 Attention Score Matrix를 구성하는 식입니다. 이 과정에서 적절히 dropout 등을 넣어주고, 그 다음 Transformer Block으로 넘겨줄 수 있도록 선형변환을 진행해야 합니다.


|etc - MLP, Layer Scale

Swin Transformer Block은 크게 W-MSA, SW-MSA와 함께 Layer Normalization, 그리고 MLP, 세부적으로는 dropout과 residual connection, 추가적으로 이전 ViT에서 소개한 Drop Path와 Layer Scale을 적용할 것입니다. MLP 코드는 아래와 같습니다.

mlp 코드

mlp 레이어는 그다지 특별한 것이 없습니다. ViT와 동일합니다. Drop Path는 timm 라이브러리 구현체를 사용하며, Layer Scale은 이전에 사용했던 구현체를 그대로 사용합니다. 코드는 아래와 같습니다.


Layer Scale 코드


|Swin Transformer Block

여태까지 구현한 클래스들을 이용해 Transformer Block을 구현해보겠습니다. 이 코드에서는 cyclic shift와 이를 고려한 SW-MSA, 마스크 생성 등이 추가됩니다. 전체 코드가 길기 때문에 크게 세 개로 나눠서 살펴보도록 보겠습니다.

Swin Trasformer Block - init

첫 번째는 init입니다. [row30~35] 다양한 이미지의 사이즈가 입력될 수 있기 때문에 이를 고려해, 윈도우 사이즈보다 작은 이미지가 입력될 경우 등을 대비한 코드를 추가해줍니다. [row37~53] 그 외에는 순서대로 layer normalization, attention layer, drop path, mlp 등을 정의합니다. [row56] 마지막으로 cyclic shift가 진행될 경우를 대비해 어텐션 마스크를 계산하는 함수를 호출합니다. 다음으로 Attention Mask를 작성하는 코드를 살펴보면 아래와 같습니다.


Swin Transformer Block - calculate attn mask

위 코드는 입력된 이미지 사이즈, 윈도우 사이즈, 쉬프트 사이즈에 따라 마스크를 생성합니다. 이때 -100으로 변환된 곳들은 어텐션 계산시 제외되고 0인 곳들끼리만 계산되는 식입니다. 이러한 마스크 연산은 shift가 적용될 때만 계산되며, shift size와 window size가 동일한 경우 cyclic shift가 의미없기 때문에 사이즈가 다를 경우에만 계산됩니다. 또한 마지막으로 이러한 attn mask를 buffer로 등록하며, 이렇게 등록된 buffer는 현재의 이미지 사이즈와 윈도우 사이즈, 쉬프트 사이즈를 유지한다면 반복적으로 계산하지 않고 한 번 계산한 것을 재사용해 계산 효율성을 높일 수 있습니다. 


마지막으로 forward입니다. foward는 한 번 본 적이 있는 아래 구조도의 Stage 부분을 참고하면서 진행하면 이해가 쉽습니다.

transformer block structure
Swin Transformer Block - forward


|Patch Merging

Patch merging에 대한 설명

본 논문에서 제시된 사항에 따르면 네트워크가 깊어질수록 Patch Merging을 통해 4개의 인근 패치(2x2 neighboring patches)를 하나로 합치면서 채널 수를 두 배로 늘려나갑니다. 이를 통해 계층적 표현(hierarchical representation)이 가능해게 됩니다. 구체적으로 채널 수가 두 배로 늘어나지만 해상도가 4배로 줄어들기 때문에 계산효율성이 늘어나고, 점차 수용영역(Receptive Field)이 넓어짐으로써 다양한 해상도에서의 특징을 파악하는 것이 가능해집니다.


Patch Merging Code

이에 대한 코드는 위와 같습니다. [row 20~23] Patch Merging을 진행하려면 입력된 이미지가 지정된 사이즈로 변형가능해야하며, 다운샘플링할 수 있는 사이즈(2의 배수)여야 합니다. 그 뒤로는 인근 네 개의 패치 그룹으로 구분하고, 이를 연결합니다. [row 32] 이를 연결하고 형상을 바꾸는 과정에서 4배로 늘어난 채널을 [row 34~35] 1/2로 줄여 4배로 늘어난 채널의 수를 2배 수준으로 조정합니다.


|Stage Layer 구성

이제 모든 구성요소가 정의되었습니다. 이제 이를 이용해 스테이지 레이어를 구축한 뒤, 다시 이를 이용해 Swin Transformer 전체를 구현해보도록 하겠습니다.

Stage를 구성하는 코드

이미 사전에 필요한 것들을 모두 구성해뒀기 때문에 여기서는 이것들을 적절히 배치하면 됩니다. [row 32] 여기서 주의해야할 것은 depth가 깊어질수록 이에 따라 shift size를 조정하는 코드를 작성해야 한다는 것입니다. 이를 위해 단순하게 포함된 모듈을 순서대로 실행하는 nn.Sequential이 아닌 명시적으로 호출해서 사용해야 하는 nn.ModuleList를 적용합니다.


|Swin Transformer

Stage를 구성하는 Layer까지의 구현이 완료됐다면 이젠 그 앞부분의 embeddings와 뒤의 classifier와 결합하는 것만 남았습니다. 다만 이때 추가적으로 고려할 사항이 있습니다.

gap에 대한 설명

우리는 swin을 구성하면서 class token을 추가하지 않았는데요, 그 이유는 바로 여기에 있습니다. 위 그림에서 볼 수 있듯 본 논문에서는 마지막 스테이지 레이어의 결과물을 global average pooling을 적용해 class token을 대체했고, 추가적인 class token을 사용하는 것만큼의 성능을 보였다고 합니다. 이는 class token만큼의 연산을 아낌과 동시에 단순한 객체 인식 외의 태스크에도 적용가능하다는 장점이 있습니다. 때문에 우리는 이전까지 구현한 embeddings, stage layers, patch merging 클래스를 이용해 최종 피처맵의 출력을 만들고, 이에 대해 gap을 적용한 후 마지막 클래스 분류를 위한 분류기(head)를 만들어주도록 하겠습니다.


stage를 거치며 변화하는 x의 형상

참고로 classifier 역할을 하는 head는 별 다를 것 없이 최종 피처맵의 출력 사이즈를 분류할 class의 수로 연결하는 Linear Layer이기 때문에 전체 코드를 통해 간단히 소개하고, 여기선 먼저 gap에 대해서 알아보도록 하겠습니다. 우선 gap층을 지나기 전의 x의 shape는 (batch, HxW÷32÷32, dim x 8)입니다. 이는 의도적으로 (batch, window size, dim x 8)의 사이즈가 되도록 유도함으로써, 마지막 스테이지에서는 이미지 전체를 하나의 윈도우만큼의 사이즈로 줄여 전체적인 맥락을 살펴보기 위함으로 생각됩니다. 이러한 x의 형상에 gap를 적용하는 방법은 크게 두 가지가 있습니다.


gap code 두 가지

왼쪽의 첫 번째 방법은 Adaptive Avg Pool1d 매서드를 사용하는 것이고, 오른쪽 방법은 그냥 간단하게 첫 번째 차원에 대해서 평균을 취하는 것으로 둘의 연산 결과 차이는 없습니다. 속도 차이도 거의 비슷한 수준입니다. 엄밀히 말하면 배치사이즈나 차원 사이즈가 작을 때에는 인스턴스 매서드인 mean을 그냥 사용하는게 소폭 빨랐고, 배치 사이즈나 차원 사이즈가 커질수록 Adaptive Avg Pool1d 매서드를 사용하는 게 미세하게 빠르거나 비슷해졌습니다. 본 논문에선 Global Average Pooling을 의미하는 문구만 있을뿐 구체적으로 어떻게 구현했다는 언급은 없습니다.


성능적 차이가 없다면 코드의 간결성이나 가독성을 따져봐야 합니다. 구현하는 코드로는 인스턴스 매서드인 mean을 사용하는 것이 편하지만, 본 포스팅의 목적인 모델의 구조를 제대로 구현했는가를 생각하면 pooler를 제대로 구현하는 것이 더욱 좋은 선택일 것 같습니다. 때문에 여기서는 Adaptive Avg Pool1d 매서드를 사용하겠습니다. 전체 코드는 아래와 같습니다.

최종적으로 구현된 Swin Transformer Code

forward 부분을 살펴보면 입력된 이미지 (B, C, W, H)에 대해 embedding을 진행하고, 차례대로 Layer를 지난 뒤, layernorm, global average pooling, 그리고 classifer를 지나며 Swin Transformer가 완성됩니다. 위의 클래스의 기본값은 가장 작은 모델인 tiny를 기준으로 지정하였습니다. 이 모델의 출력결과를 살펴보면 다음과 같이 정상적으로 출력되는 것을 볼 수 있습니다.

Swin Transformer를 통과한 결과

또한 본 포스팅의 첫 부분에서 timm 구현체를 통해 확인했던 바와 같이 동일하게 모델의 구조를 간략히 살펴보면 아래와 같습니다.

Swin Transformer 구조 비교. 직접구현(좌), timm 구현체(우)

timm 구현체에서는 backbone으로서 swin 구조를 구분하고, classifier를 붙인 구조입니다. 반명 우리가 직접 구현한 구현체는 하나씩 설명을 진행하며 이해하기 쉽도록 구현한 것이고, 굳이 timm 구현체와 같은 구조로 구현할 순 있지만 큰 의미가 있는 것은 아니기에 그렇게 진행하진 않았습니다. (내부 모델 구조는 동일합니다.)


구현한 swin tiny의 parameter 수

이렇게 구현된 모델의 파라미터 수는 약 28M입니다. 이는 본 논문에서 제시하는 29M과는 조금 차이가 있는 수인데요, 한참을 제가 무언갈 빼먹은 것이 있나 싶어 모델의 상세 구조 등을 샅샅이 검토해봤지만 찾을 수 없어 공식 구현체를 찾아보았습니다. 아래는 각각 본 논문에서 제시하는 파라미터 수와 공식 github에서 제시한 파라미터 수입니다.

논문과 공식 github에서 나온 Swin-T의 파라미터 수

본 논문에서와 달리 공식 github에서는 28M으로 돼있고, 좀 더 구체적으로 파라미터를 확인하기 위해 github의 학습 로그를 살펴보면 아래와 같이 약 2,829만 개의 파라미터를 가지고 있어 직접 구현한 것과 다소 차이는 있지만 거의 유사한 것을 확인할 수 있습니다.


공식 github에 올라와있는 학습로그

모델 구현은 이 정도로 마치고, 이제 구현한 모델을 통해 이전에 ViT를 통해 학습시켰던 sports 데이터셋에 학습시켜보도록 하겠습니다.



|Train

다음으로는 모델 학습입니다. 기본적으로 대부분의 세팅이나 파라미터는 이전에 포스팅한 ViT와 동일하게 진행했습니다. 데이터셋은 kaggle의 sports dataset을 사용했고, 이에 대한 데이터 증강도 동일하게 처리했습니다(본 논문에선 별도로 제시하는 증강방법이 있으나 데이터 증강에 따른 모델 성능 차이를 최소화하고자 했습니다). Sports Dataset에 대해서 간략히 다시 설명하면, Sports Dataset은 224 사이즈의 이미지를 통해 100종에 달하는 스포츠 중 어디에 해당하는지를 분류하는 것으로 총 13,493장으로 구성돼 있습니다. 이는 Transformer model을 학습시키기엔 충분하지 않은 숫자이지만, 대신 학습 결과를 빠르게 확인할 수 있고, 최소한의 성능이 나올 수 있도록 여러 하이퍼 파라미터 등을 조정해 학습시킬 수 있습니다.


|Train Option & Parameters


우선 전체적인 모델의 구조는 본 논문의 아래 부분을 참고하였습니다.

Detail architecture of Swins

이러한 파라미터는 아래 코드와 같이 기본값으로 지정해두었습니다.

swin-T Setting


|Opitimizer : AdamW

본 논문에서는 Adam을 개선한 AdamW Optimizer를 사용했습니다. 이는 Adam이 가진 weight decay가 의도와는 다르게 적용되는 문제 때문입니다. 구체적으로 Adam과 같이 적응형(Adaptive)으로 각 가중치에 대한 개별적인 학습률을 계산하는 알고리즘의 경우, weight decay를 적용할 때 학습률이 크게 조정되는 가중치는 덜 감소하고, 학습률이 작게 조정되는 가중치는 과도하게 감소되는 문제가 발생할 수 있습니다. 이는 일반화가 아닌 모델의 성능 자체의 하락을 야기할 수 있기에 이에 대한 대안으로 AdamW가 제시되었습니다.


AdamW는 이를 해결하기 위해 가중치 감소를 그라디언트 업데이트 단계와 분리해 처리합니다. 즉, Adam의 weight decay가 적용되는 방식이 각각의 가중치에 대해 (1-lr*weight decay)와 같은 방식이었다면, AdamW는 1-weight decay를 적용한 뒤에 학습율을 적용함으로써 각 weight 별 학습율의 영향을 배제하는 식으로 진행되는 것입니다. 

Adam vs AdamW, 하이퍼 파라미터 튜닝의 영향 시각화[4]

특히 이러한 AdamW의 장점은 lr가 유동적으로 변화하는 scheduler가 함께 적용되었을 때 장점을 발휘합니다. 위 그림은 Adam과 AdamW의 learning rate와 schduler별 성능(test loss)을 히트맵으로 시각화한 것으로, 동그라미는 가장 좋은 성능을 보인 lr와 weight decay 지점으로 해석할 수 있습니다. 위 그림을 통해 살펴볼 수 있는 첫 번째는 AdamW의 성능이 Adam보다 높다는 것입니다. 특히 cosine annealing을 적용했을 경우, Adam에선 등장하지 않았던 완연한 파란색(낮은 test loss)이 보여지고 있습니다.


두 번째로 확인할 수 있는 것은 최적의 성능을 내는 조합(동그라미 위치)들이 Adam의 경우엔 가로로 우하향하고 있지만, AdamW의 경우엔 다양한 조합으로 매트릭스를 그리듯 나타난다는 것입니다. 이러한 경향은 해당 논문에서 제시한 다른 그림을 살펴보면 더욱 뚜렷하게 확인할 수 있습니다.

Effect of Decoupled weight decay regularization[4]

해당 논문[4]에서 제시한 위 그림은 SGD와 Adam에 그들이 제시한 방법론인 'Decoupled Weight Decay Regularization'을 적용하기 전후의 결과를 시각화한 것입니다. 위 그림에서는 기존의 SGD와 Adam이 뚜렷하게 우하향하며 lr과 weight decay factor가 상호 영향을 주고 받는 것으로 보이지만, 새로운 방법을 적용했을 경우엔 상호 독립적인 모습을 보여주고 있습니다. 이는 기존의 lr가 업데이트되는 과정에서 weight decay가 함께 고려되었던 것을 개선했음을 의미하며, 동시에 이를 통해 더욱 높은 성능을 달설항 수 있음을 의미하는 것입니다.


최근에는 AdamP라고 하여, 네이버의 클로바AI 팀에서 제시한 새로운 옵티마이저를 사용하고 있기도 한데, 이에 대한 논문은 제가 아직 읽어보지 못했기에 추후에 별도로 포스팅하거나, 이를 사용해 학습한 논문이 있다면 그때 다뤄보도록 하겠습니다.


adamW 적용방법

본 논문에서는 AdamW를 사용하며 0.05의 weight decay를 적용했다고 합니다. 하지만 이전의 ViT 실험에서와 마찬가지로 우리가 가진 작은 데이터셋에서는 이러한 weight decay가 부정적으로 작용하는 경우가 많았습니다. 또한 이번에도 마찬가지로 weight decay를 0.05보다는 0.01로 세팅하는 것이 최종적인 성능이 더 좋았습니다. 이는 간단히 아래와 같이 한 줄의 코드로 적용할 수 있습니다. 참고로 해당 구현체의 기본 weight decay는 0.01입니다.

torch.adamW


|Scheduler : Cosine Warmup Scheduler

다음으로는 스케줄러인데요, 이는 이전에 구현했던 cosine warmup scheduler와 차이가 없습니다. 이미 transformer 구현과 ViT에서 직접 구현하며 그 그래프를 살펴보기도 했기에, 여기선 transformers 라이브러리를 통해 구현된 구현체를 활요하도록 하겠습니다.

cosine warmup scheduler code

위 코드에서 warmup steps는 초기학습(warmup)을 진행할 step이고, training_steps는 전체 학습 steps를 의미합니다.


|Gradient Clipping

gradient Clipping에 대한 적용

Gradient Clipping은 모델이 깊어질수록 역전파로 전달되는 파라미터의 기울기가 중첩돼 발산하고, 이로 인해 한 번의 업데이트 스탭이 너무 커져버리는 문제를 해결하기 위한 방법입니다. 이에 대한 수식은 아래와 같은데요, 간단히 요약하면 기울기가 hyper parameter로 전달받은 norm의 최대값(threshold)보다 크면, 기울기 벡터를 최대값보다 큰만큼의 비율로 나눠주는 것입니다. 이를 통해 기울기 벡터는 항상 역치보다 작은 값이 되며, 기울기 벡터의 방향은 유지하면서도 한 번에 너무 많은 step을 건너뛰지 않도록 제어할 수 있습니다.

Gradient Clipping에 대한 수식

본 논문에서는 threshold로 적용할 max_norm을 1로 두었습니다. 이를 코드로 적용하려면 아래와 같이 학습 코드를 작성해야 합니다. 이때 주의할 점은 optimizer.step() 이전에 적용해야 한다는 것입니다.

일반적인 gradient clip 적용 방법

다만 우리는 이번에도 이전의 ViT에서 사용했던 것과 같이 AMP(Automatic Mixed Precision)을 적용한 학습 코드를 적용할 것인데요, 이를 통해 모델의 성능은 해치지 않으면서도 학습 효율성을 높일 수 있었습니다. 이를 위한 코드는 아래와 같습니다.

gradient clipping + amp


참고로 테스트 결과 gradient clipping이 적용됐을 때, 레이어의 깊이가 12인 tiny 모델의 성능은 오히려 소폭 하락했습니다. 하지만 레이어의 깊이가 24인 small은 더 높았으며, 동일한 깊이인 base나 large에서는 긍정적인 영향을 줄 것으로 보입니다. (이는 학습시킨 데이터셋에 종속적인 분석일 수 있습니다)


|Training & FineTuning

여기서 구현된 Swin-Tiny의 파라미터 수는 약 28M 수준으로 86M 수준이었던 ViT-Base보다 작습니다. 동일한 12 Depth를 가졌음에도 이렇게 파라미터 수의 차이가 발생하는 것은 Patch Merging을 통한 Downsampling 덕분입니다. 이러한 이유로 Swin-Tiny를 학습시킬 때에는 에포크당 소요시간이 10초 정도(AMP 적용기준) 단축되었습니다. 이 부분에 대한 코드는 제 github를 참조하시면 됩니다.


|Augmentation

본 논문에서 적용하는 데이터 증강에 관한 내용

본 논문에서는 다양한 데이터 증강 기법을 활용합니다. 여기서 굳이 '다양하다'고 표현한 이유는 일반적으로 transforms를 통해 input으로 사용되는 이미지를 변형하는 것이지만, 본 논문에선 이에 한정되지 않기 때문입니다. 그나마 첫 번째로 적용되는 RandAugment를 제외하면, Mixup과 Cutmix는 학습 과정에서 배치에 들어간 이미지들끼리 섞는 것이고, Stochastic Depth는 우리가 이미 모델링 과정에서 적용한 drop path를 의미하는 것입니다. 


두 개의 원본 이미지와 증강된 이미지

Random Augmentation은 기존에 있던 RandomCropandResize와 Random Horizontal Flip 정도로만 적용하겠습니다. Drop Path는 이미 적용돼있고, Mixup과 Cutmix는 학습 코드에 추가로 적용해야 합니다. Mixup과 Cutmix에 대한 이해를 돕기 위한 그림을 살펴보면 위와 같습니다. 이를 코드로 적용하기 위해선 아래와 같이 작성할 수 있습니다.


mixup과 cutmix가 적용된 코드

참고로 이러한 증강방식이 적용됐을 때의 특징은 간단한데요, 바로 val_loss가 상당 기간 동안 train loss보다 낮은 상태로 학습이 진행된다는 것입니다. 즉, 일반화 성능이 높아진 상태로 학습이 진행됨을 의미합니다. 추가적으로 mixup_fn에서 사용되는 label smoothing은 적용된 mixup&cutmix된 이미지들 간의 label smoothing을 적용하는 것입니다. 이러한 방식이 Loss Function을 통해 적용되는 Label Smoothing보다 항상 좋지는 않지만 실험적으로는 이편이 성능이 더 좋았습니다.


|Train Result

본 논문에서는 별도의 데이터 증강방법에 대한 언급이 존재합니다. 랜덤 데이터 증강기법 외에는 cutmix와 mixup 그리고 stochastic depth가 그것들이며, stochastic depth는 사실 이전의 ViT에서도 이미 적용한 DropPath를 의미합니다. 때문에 여기서는 위에서 설명한 cutmix와 mixup을 추가로 적용한 결과를 확인해볼 것입니다. 그리고 일반 또한 직접 구현한 모델의 성능을 비교하기 위해 timm 라이브러리의 사전구현된 모델과 동일한 조건으로 학습시켜 성능을 비교했습니다.

timm 구현체로 학습한 결과
직접 구현한 모델로 학습한 결과

이전에 ViT로 학습했을 경우엔 직접 구현한 모델이 0.54, timm 구현체가 0.45의 f1-score를 보였었던 것과 대조적으로 높은 성능을 발휘하고 있습니다. 가장 큰 차이는 training loss는 엇비슷하지만 val loss와는 큰 차이를 보였던 ViT와는 달리 Swin은 그래도 비슷한 수준으로 수렴했다는 것인데요, swin에서 사용된 cyclic shift가 기존의 작은 데이터셋에 대해 효율적인 학습을 위해 설계된 목적을 긍정적으로 달성한 결과로 보입니다.

이러한 성능은 현재 학습시킨 Swin-T의 파라미터(28M)가 ViT-Base의 파라미터(86M)의 1/3 수준이라는 것을 고려하면 알고리즘적으로 우수하며, 파라미터 효율적이라는 평가를 내릴 수 있을 것입니다.



|FineTuning Result

두 번째로 살펴볼 것은 FineTuning입니다. 이 또한 ViT를 포스팅했던 때와 동일합니다. Sports Dataset에 학습시킨 모델을 CIFAR10 데이터셋에 파인튜닝해볼 것입니다. 32x32 사이즈의 CIFAR10 데이터셋을 224 사이즈로 늘리고, ①사전학습되지 않은 직접 구현체와 ②timm 구현체, ③사전학습된 직접 구현체와 ④timm 구현체의 성능을 비교하며 확인해보도록 하겠습니다. 이 부분의 전체 코드는 여기를 참조하시면 됩니다.

finetuning detail

우선 본 논문에서 제시한 파인튜닝 방법을 살펴보면 위와 같습니다. 기본적으로 본 논문에서 제시하는 파인튜닝은 보다 큰 사이즈로의 파인튜닝에 국한돼 있습니다. 때문에 여기선 cifar10 본래의 이미지 사이즈를 활용하는 파인튜닝은 제외하고, 224 사이즈로 키워서 적용하는 방식만을 테스트해볼 것입니다. 파인튜닝을 위한 optimizer는 여전히 adamW이지만, 학습율은 1e-5, weight decay 또한 1e-8로 아주 낮은 수준으로 진행하며, 학습을 위한 에포크는 30, 데이터 증강이나 규제는 동일하게 적용하지만 Drop Path(Stochastic Depth)를 0.1로 낮춥니다. 


이를 적용해 finetuning을 해보도록 하겠습니다. 첫 번째로 살펴볼 것은 사전학습되지 않은 모델의 성능입니다.

직접 구현체(上)와 timm 구현체(下)의 cifar10 학습 결과

두 모델 모두 성능 향상이 어느정도 이뤄지긴 했지만 충분한 수준은 아닙니다. 다만 ViT 모델로 학습시켰을 때와 비교하면 훨씬 성능이 높아진 것을 확인할 수 있는데요, 직접 구현체 기준 Val Loss가 1.27→1.23로 향상됐고, timm 구현체 기준으로도 1.88→1.23 성능이 향상됐습니다. 다음으로는 사전학습한 모델을 이용한 파인튜닝 결과입니다. 이를 위해 다음과 같이 직접 구현한 모델의 classifier를 교체하였습니다.

모델의 drop path 조절 및 classifier 교체

이렇게 학습시킨 모델들의 성능을 평가하면 아래와 같습니다.

직접 구현한 Swin 모델을 Sports 데이터셋에 학습시킨 뒤 cifar10에 파인튜닝한 결과
timm 구현체 Swin 모델을 Sports 데이터셋에 학습시킨 뒤 cifar10에 파인튜닝한 결과

ViT로 학습시켰을 때는 0.56 정도의 정확도와 0.55 정도의 F1 Score가 나왔었던 것을 생각하면 Swin을 사용함으로써 약 30%의 성능 향상이 이뤄졌습니다. 만일 처음부터 학습시킨 모델로 성능을 평가하면 약 0.58~0.6 정도의 F1 Score가 나오게 되며, 반대로 ImageNet과 같이 더 크고 다양한 이미지가 포함된 데이터셋에 학습시킨 모델로 파인튜닝할 경우엔 거의 1에 가까운 스코어 나오게 됩니다. 이는 아래 테스트 결과 이미지를 확인하거나 finetuning 코드를 참조하시면 됩니다. 


ImageNet에 학습시킨 Swin-T를 FineTuning했을 때의 결과

위의 왼쪽 이미지를 통해 알 수 있듯 거의 1에 가까운 스코어임에도 30에포크보다도 적은 10 에포크만에 달성한 성과입니다.



|Reference

[1] Ze LiuYutong LinYue CaoHan HuYixuan WeiZheng ZhangStephen LinBaining GuoSwin Transformer: Hierarchical Vision Transformer using Shifted Windows. https://arxiv.org/pdf/2103.14030.pdf

[2] 太阳花的小绿豆. Swin-Transformer网络结构详解. https://blog.csdn.net/qq_37541097/article/details/121119988

[3] apodx. swin-transformerhttps://aistudio.baidu.com/projectdetail/3735708?channelType=0&channel=0

[4] Ilya Loshchilov & Frank Hutter, Decoupled Weight Decay Regularization, https://ar5iv.labs.arxiv.org/html/1711.05101



|Log

2023.12.17 | Paper Reading

2023.12.18 | Blog Reading

2023.12.19 | Sample Code 작성 ① - timm research

2023.12.20 | Sample Code 작성 ② - torch, transformers modeling

2023.12.21 | Modeling 초고 작성 ①

2023.12.22 | Modeling 초고 작성 ②, 학습코드 구현 및 학습 진행

2023.12.26 | 학습 결과 확인 및 추가 수정사항 적용(gradient clipping)

2023.12.27 | 파인튜닝 및 1차 퇴고

2023.12.28 | 2차 퇴고, 코드 정리

2023.12.29 | 3차 퇴고, 코드 주소 추가

2023.12.30 | 최종 검토, 학습 및 파인튜닝 파이프라인 구축 및 실

2024.01.02 | 발행

2024.02.02 | RPB index 의미 추가

2024.02.05 | ViT와의 MSA 연산량 비교 상세 추가

2024.03.04 |계산복잡도 재계산


 

브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari