Paper Review 13
KAN(Kolmogorov Arnold Network)는 모델의 해석가능성을 높이는 측면에서 기존 MLP보다 뛰어나지만 결국 tree 알고리즘에 MLP를 섞은 결과물 정도로 보이며, 최초 논문[1]에서 제안한 바와 달리 MLP를 대체하기엔 어려워 보입니다. 이는 추가적인 개선이나 KAN 구조에 적합한 하드웨어의 등장 등이 기반되지 않는다면 아주 일부 영역에서만 모델의 해석가능성을 높이기 위해 사용될 것으로 예상됩니다.
Kolmogorov Arnold Network, 약칭 KAN은 Kolmogorov-Arnold representation theorem에 기반합니다. 이는 어떤 복잡한 함수를 간단한 함수들의 합으로 표현할 수 있다는 것으로 수식으로 나타내면 아래와 같습니다. 어떤 복잡한 함수 f(x,y,z)가 있을 때, 이 함수를 다음과 같이 표현할 수 있습니다[1].
위 수식에서 ϕq와 ψq는 간단한 함수를 의미하며, 위 수식의 의미는 복잡한 함수를 여러 개의 간단한 함수로 나누어 더하고, 합쳐서 만들 수 있다는 것입니다. 이를 조금 더 전문적으로 말하면 모든 복잡한 다변수 함수를 간단한 일변수 함수들의 합으로 표현할 수 있다라고 할 수 있으며, 위 수식도 아래와 같이 정리할 수 있습니다.
이를 간단히 구현해볼까요? x1, x2 두 개의 입력을 받는 sin 함수를 목표하고자 하는 함수로 정의하겠습니다. 이를 수식으로 나타내면 다음과 같습니다.
이에 대해서 x1, x2를 다음과 같이 정의하고 위 함수를 구현한 f_true에 통과시키는 코드와 최종적으로 생성된 Y true 값을 3차원에서 시각화하면 다음과 같습니다.
다음으로는 위 복잡한 함수에 근사할 수 있는 작은 함수들의 집합을 구현할 것인데요, 이를 통해 결과적으로 위의 오른쪽 그림에 근사하는 결과물을 만드는 것이 목표가 됩니다. 기본적으로 KAN 논문에서는 이를 구현하기 위해 spline 함수 등을 이용해 근사하도록 제시했지만, 해당 논문의 공식 코드 및 이후 KAN 구현체들에서 공통적으로 일반 선형 레이어로 대체해 사용하는 것을 고려해 간단한 선형층으로 구성된 모델을 구현해 보겠습니다.
369개의 파라미터를 가진 위 모델을 이용해 학습한 결과를 시각화하면 아래와 같습니다.
위 그림을 통해 알 수 있는 건, 간단한 함수들 간의 연산 등을 통해서 복잡한 함수에 근사할 수 있다는 것입니다. 다만 위 실험은 그저 Kolmogorov-Arnold representation theorem에 대한 실험일뿐, KAN이라고 부르긴 어려운데요. 차근차근 KAN이 되기 위한 구조를 살펴보도록 하겠습니다.
KAN은 기본적으로 MLP를 대체하고자 나온 개념입니다. 때문에 기존의 MLP와 유사하면서도 다릅니다. 이 때문에 KAN을 이해하기 위해선 MLP에서 시작하는 게 좋습니다. KAN 논문에서 제시된 그림을 살펴보면 다음과 같습니다.
위 그림에서 확인할 수 있듯 MLP와 KAN은 기본적으로 유사한 구조이지만, 내부 개념에 차이가 있습니다. 간단히 살펴보면 다음과 같습니다.
1. Theorum
MLP는 하나의 은닉층을 갖는 인공 신경망이 임의의 연속인 다변수 함수를 근사할 수 있다는 보편 근사 정리(Universal Approximation Theorem)에 기반하지만, KAN은 임의의 연속인 다변수 함수를 복수의 일변수 함수들의 조합으로 근사할 수 있다는 콜모고로프-아놀드 표현 정리(Kolmogorov-Arnold Representation Theorum)에 기반합니다.
2. Model
모델 관점에서는 크게 활성화 함수와 가중치, 비선형성의 처리에 대한 부분에서 차이가 있습니다. 우선 활성화 함수의 위치와 학습 가능성 관점에서 MLP에서는 각 노드(뉴런)에 고정된 활성화 함수(ReLU와 같은)를 사용하지만, KAN에서는 각 엣지(가중치)에 학습 가능한 활성화 함수(Spline)를 사용합니다.
이는 본 논문[1]에서는 유연한 학습을 가능하게 한다고 하지만, 이로 인해 탐색 범위(grid)와 몇 차(degree) 활성화 함수를 적용할지 등에 대한 추가적인 하이퍼 파라미터 설정과 탐색이 요구되며 이는 모델의 적절한 최적화가 매번 달라지고, 최적의 모델을 구현하기 어렵거나 많은 리소스를 요구하는 문제를 발생시킬 수도 있습니다.
모델링 관점에서 두 번째 차이는 가중치 표현 방식입니다. 이는 선형 가중치 행렬을 사용해 입력을 변환하는 MLP와 달리, 각 가중치를 스플라인으로 표현된 1차원 함수로 대체하는 것인데요. 이는 더 복잡한 변환이 가능하지만, 계산 비용이 증가할 수 있다는 단점이 있습니다.
마지막으로 비선형성의 처리 관점에서 각 노드에서 비선형 활성화 함수를 이용해 입력을 출력으로 변환하는 MLP와 달리, KAN에서는 노드 단에서는 비선형성을 적용하지 않고 단순히 신호를 합산합니다. 즉 KAN에서는 노드가 아닌 엣지에서의 1차원 스플라인 함수를 통해 비선형성을 확보하는 것입니다.
다만 이러한 비효율성에도 불구하고 KAN을 이용해 구현한 모델이 MLP를 이용해 구현한 모델보다 더 적은 연산 그래프를 지니기 때문에, 즉 더 적은 파라미터로 유사한 성능을 보일 수 있기 때문에 괜찮다라는 어필(위 그림에서 파란색 강조 부분)이 있습니다. 하지만 이에 대해서 다각도로 실험했을 때, 이를 뒷받침하는 근거를 얻을 수는 없었습니다 - 정확히는 100배 적은 파라미터를 적용하더라도 학습 비용이 일반 NN 모델에 비해 아주 높으며, 추론 과정에서는 본 논문에서 제시하는 이득을 확인하긴 어려웠습니다.
KAN 논문이나 여타 블로그 포스팅, 논문 리뷰를 종합적으로 살펴보고, 수차례 공식 코드를 검토하며 든 생각은 KAN은 단순히 어떤 구조의 모델이라기 보다는 전략적 개념에 가깝다는 것입니다. 본래의 의도는 지난 번에 리뷰한 FFN을 대체할 것으로 예상되는 MoE처럼 MLP를 특정 영역에서 대체할 수 있는 KAN을 익히고자 한 것이 본 포스팅의 시작이었는데, 실제로는 많은 문제가 있었습니다. 이를 알아보기 위해 논문의 내용에 기반해 KAN의 개념을 보다 자세히 살펴보며 진행해보도록 하겠습니다.
KAN에 대해 간략히 전체적인 구조를 살펴보겠습니다.
KAN은 기본적으로 B-Spline Basis Function으로 구성돼 있으며, 이들을 각각 미분가능하도록 배치하고, 학습 루프를 통해 최적화하며 표현 가능한 수식으로 구성된 레이어를 구축합니다. 이렇게 구현한 레이어를 여러 개의 층으로 쌓고, 학습시킨 뒤 수식에 기반에 중요 입력 요소들의 가중치를 평가합니다. 이러한 평가를 기반으로 가지치기(pruning) 혹은 각 요소들의 중요성 평가가 가능해지며, 해당 레이어를 수식으로 표현하는 것 또한 가능해지는 것입니다.
KAN의 시작이자 끝은 결국 B-spline 기저함수입니다. 이는 여러 점들을 부드럽게 이어주는 선을 만드는 방법으로, 기저함수를 기반으로 구현돼 있습니다. 정확히는 여러 점들을 몇 개의 구간으로 나누고, 이 구간 사이에 존재하는 점들을 기저함수에 기반한 곡선으로 표현한 뒤 이 곡선들이 연속되도록 잇는 방식입니다. 이 표현 방식에 대해 궁금하다면 Bézier curve를 참조하는 것을 추천합니다.
이러한 B-spline 기저함수로 주어진 점을 설명하기 위해서는 몇 가지 사전에 설정해야 하는 중요한 요소들이 있습니다. 이를 하이퍼 파라미터처럼 설정해야 하며, 바로 제어점(Control Points), 차수(Degree), 그리고 매듭 벡터(Knot Vector)입니다.
매듭 벡터(Knot Vector)는 각 구간의 경계점을 지정해주는 벡터로, 구간별로 기저함수에 기반한 곡선을 형성할 수 있게 해줍니다. 이때 곡선이 얼마나 매끄럽고 복잡하게 휘어질지를 결정하는 것이 차수(Degree)이며, 제어점(Control Points)은 곡선의 형태를 정의하는 기준이 되는 점들로, 곡선이 꼭 지나가지는 않지만 그 곡선의 모양을 결정하는 데 중요한 역할을 합니다.
이를 그림으로 살펴보면 아래와 같습니다. 가장 일반적인 degree=3, control points=6인 케이스를 살펴보겠습니다.
6개의 점을 차수가 3인 곡선으로 표현하고자 하면 위와 같이 표현됩니다. 설명하자면 차수가 3이기 때문에 6개의 점을 표현하기 위해 3개의 곡선으로 표시됩니다. 간단히 말해 n개의 점을 k차수로 표현할 경우 이를 표현하기 위한 곡선의 수는 n-k개가 되기 때문입니다(이는 1차 곡선=직선에는 두 개의 점이 필요하고, 2차 곡선에는 세 개의 점이 필요한 이유를 생각하면 이해하기 쉽습니다). 그렇다면 만약 k가 1일 때는 어떻게 되는지 살펴볼까요?
차수(k)가 1인 곡선은, 직선을 의미합니다. 두 점을 잇는 선이며, 그렇기 때문에 6개의 점을 잇는 5개의 선으로 구현되는 결과물은 위와 같은 식인 것입니다. KAN에서는 이러한 B-Spline 기저함수들을 조합해 아주 복잡한 수식도 표현할 수 있다고 주장합니다. 이러한 과정에서 요구되는 파라미터가 적기 때문에 학습 과정이 다소 비효율적이더라도 충분히 시도할만하며, 최종적으로 해당 task를 해결하기 위한 수식을 직접 확인할 수 있기 때문에 설명 가능성 측면에서 더욱 강점이 있다는 것입니다.
하지만 직접 구현했던 코드 및 공식 코드를 통해 테스트한 결과, 정확히는 적은 수의 파라미터로도 충분히 효율적이다라기 보다는 작지 않으면 제대로 성능을 보이지 못했습니다. 조금의 파라미터를 향상시킨 것만으로 더 많은 학습 epoch가 요구되거나, 오히려 작은 모델일 때보다 성능이 떨어진 것입니다. 아마 해당 논문에서 밝히지 않았거나, 아직 외면하고 있는 문제가 있는듯 해 보입니다.
참고로 위 작업을 직접 해보고 싶으시다면 여기로 가서 그림 하단의 확대 버튼을 눌러 진행해볼 수 있습니다.
KAN이 가진 가장 큰 장점은 단언컨데 수식화입니다. 일반적으로 기존 ML 모델은 기존에 사람이 찾아낸 알고리즘에 기반해 문제를 풀어내고, DL 모델은 문제를 해결할 수 있는 복잡한 수식을 만드는 것 자체도 알아서 하는 것을 지칭합니다. 하지만 Kolmogorov-Arnold Representation Theorem에서는 이 복잡한 수식도 결국 우리가 이미 알아낸 수식의 조합들로도 충분히 풀어낼 수 있음을 시사하기에, KAN에서는 지수와 로그를 비롯한 삼각함수 등을 통해 DL 모델을 통해 찾아낸 방법을 수식화해낼 수 있게 된 것입니다.
또한 이러한 과정에서 등장하는 개념이 바로 grid인데요. 간단히 말해 각 활성화 함수를 통해 한 번에 몇 개의 점들을 고려해 B-Spline 기저함수를 최적화할지 결정하는 것입니다. 당연히 처음부터 많은 지점을 살펴보게 되면 모델의 성능이 충분히 성숙되기도 어려울 뿐더러, 학습에 많은 시간과 비용이 소모됩니다. 때문에 본 논문에서는 학습 과정에서 epoch가 진행됨에 따라 grid를 조금씩 넓히는 전략을 취함으로써 모델을 자연스레 성숙시킴과 동시에 고도화하는 전략을 추천합니다.
즉, 학습 초기 단계에서는 각 활성화 함수가 한 번에 충분히 작은 수의 점만을 설명하도록 함으로써 모델을 빠르게 초기화시킨 뒤, epoch가 진행됨에 따라 grid의 범위를 넓혀가면서 세밀한 최적화를 수행하는 방식입니다. 또한 이 과정에서 지수와 로그, 삼각함수 등을 이용해 각 과정들을 설명해야 하다보니 충분히 깊지 않은 KAN 모델은 설명력이 떨어질 수 있으며, 이러한 과정 자체가 학습비용 증가와 직결된다는 문제가 있기도 합니다. (때문에 GNN 등과 같이 상품 추천이나 통계적 개념이 적용되는 영역과 결합함으로써 모델의 설명가능성이라는 장점을 극대화시키는 방향으로 개발하는 것이 나을 것 같기도 합니다)
또한 이 과정에서 중요성이 낮은 활성화 함수(B-Spline 기저함수)를 가지치기함으로써 모델을 보다 강건하게 만들 수 있는 효과를 가지기도 합니다. 이러한 과정을 확인하기 위해 공식 코드를 이용해 우리에게 익숙한 iris dataset을 학습시켜 그 과정과 결과를 살펴보도록 하겠습니다.
Iris dataset은 꽃잎의 길이와 너비, 꽃받침의 길이와 너비를 기반으로 setosa, versicolor, virginica의 3종으로 분류하는 데이터셋입니다. 본 테스트를 위해 사용한 KAN 모델은 이번에는 직접 구현에 실패해(이번 달 야근과 외부일정이 너무 많아 실패했습니다...특히 수식 구현하는 거 너무 복잡...), 공식 코드를 참조하였습니다.
공식 코드를 활용해 학습하고, 시각화하는 코드는 여기를 통해 확인할 수 있습니다. KAN 모델에 iris dataset을 적용해 초기화하면 초기 모습은 다음과 같이 시각화할 수 있습니다. 참고로 이때 모델의 shape은 (4,5,3)으로, 4개의 입력 변수(SL, SW, PL, PW)를 받아, 5개의 히든 노드를 통과해 최종적으로 3개(Set, Ver, Vir)의 iris 종으로 분류하는 아주 간단한 구조입니다. 2번째 층의 크기가 5인 이유는 본 논문[1]에서 이전에 있던 대부분의 Kolmogorov-Arnold 이론을 응용한 연구에서는 3개의 층으로 구성되며, 두 번째 층의 너비는 '2n+1로 한다'를 따른 결과입니다.
위 그림에서 각 점이 노드를 의미하며, 네모난 그림 안에 있는 선이 활성화 함수, 즉 초기화된 B-Spline 기저함수를 의미합니다. 아무래도 초기이다보니 간단한 형태입니다. 각 노드의 활성화 함수는 다음 노드의 숫자와 동일합니다.
참고로 본 논문[1]에서는 역전파에 의한 최적화가 가능하기 때문에 복수의 층으로 구성된 KAN에서는 굳이 이를 따를 필요가 없으며 더 적은 너비로도 길이가 길어지면(예컨데, 4>5>3 대신 4>2>2>3) 보다 높은 성능 달성도 가능하다고 하였으나 여기선 3개 층으로 구현하는게 더 시각화에 장점이 있어 이렇게 진행하였습니다.
학습 과정을 시각화하면 다음과 같습니다.
위 학습 과정에서 확인할 수 있듯 학습이 진행되며 각 B-Spline 기저함수가 최적화됨과 동시에 중요하지 않은 변수들에 대해서는 그 영향력을 줄여나가는 것을 확인할 수 있습니다. 최종적인 학습 결과를 시각화하면 다음과 같습니다.
이 다음 스탭으로는 두 가지를 선택할 수 있습니다. 하나는 grid를 넓혀 추가적인 학습을 하는 것이고, 다른 하나는 가지치기를 하는 것입니다. 하지만 이미 여기서 모델의 성능이 학습 데이터 정확도 98.33%와 테스트 데이터 정확도 100%를 달성했기에 추가적인 학습은 하지 않고, 가지치기를 진행해주도록 하겠습니다. 가지치기는 공식 코드 상의 prune function을 이용해 적용할 수 있으며, 그 코드와 결과는 아래와 같습니다.
이렇게 가지치기를 한 뒤에는 그대로 사용하는 것도 방법이지만, 본 논문에서는 간단히 파인튜닝함으로써 모델의 성능이 떨어지는 것을 최소화하는 것을 권장하고 있기에 그에 맞춰 진행해보겠습니다.
간단하게 50에포크만 학습했고, 모델이 더욱 가벼워진 덕분에 최초 학습 시에는 100에포크 학습을 위해 6분 정도가 걸렸던 것과 달리 거의 1초만에 파인튜닝이 완료됐으며, 성능은 가지치기를 하기 전과 동일한 수준입니다. 여기까지 진행한 모델의 구조는 다음과 같습니다.
파인튜닝의 과정에서 활성화 함수가 보다 최적화 됐고 각 가중치도 조금 조정된 것을 확인할 수 있습니다. 이러한 모델의 결과를 수식으로 확인하면 다음과 같습니다.
위 수식은 모델의 결과를 구현하기 위해, 즉 각 iris 품종으로 분류하기 위한 수식을 의미합니다. 예컨데 formula1은 첫 번째 품종인 Setosa일 가능성(값)이 49.7737206361179−14.2755109129876 * x3(PL)을 해야한다는 의미로, 직관적으로 해석하면 꽃받침의 길이(PL)가 길수록 Setosa일 확률은 약 14.275의 가중치만큼 감소한다 라고 할 수 있을 것 같습니다.
마지막으로 본 논문에서 제시한 바와 같이 MLP 모델과 비교해볼 것인데요, 위 그림에서와 같이 본 논문에서는 KAN이 학습에는 일반적으로 10배 정도 느리지만, 100배 더 파라미터 효율적이며, 100더 정확하다고 하였는데요. 과연 그럴지 직접 비교해보겠습니다.
우선 약 100배 큰 MLP를 간단히 구현합니다. (4, 5, 3)이었던 KAN을 고려해 (4, 500, 3)의 MLP를 정의하며, 이를 살펴보면 다음과 같습니다.
이 모델을 동일한 loss와 optimizer, 학습율로 100에포크 학습시킨 결과는 다음과 같습니다.
KAN 모델의 경우, 학습 데이터셋에 대한 성능은 약 95~99% 범위에 속하며 테스트 데이터셋은 대부분의 경우 100% 맞춥니다. 반면 MLP 모델의 경우, 학습 데이터셋에 대한 성능은 93~98%이며, 테스트 데이터셋에 대한 성능은 그보다 살짝 낮거나 비슷한 수준을 유지합니다.
때문에 평균적으로 KAN 모델의 성능이 더 높았습니다만 반드시 그런 것은 아니었습니다. 다만 KAN 모델의 100에포크 학습에는 약 360초가 소요됐지만, 100배 큰 mlp 모델은 2초 이내에 학습이 완료돼 약 180배의 속도 차이를 보였습니다. 또한 추론 속도에 있어서도 약 14배의 속도 차이를 보여, 실질적으로 수식이 제시됨으로써 모델의 투명성이 높아지는 것 외의 장점은 아직까진 없어 보입니다.
[1] Ziming Liu, et al. "KAN: Kolmogorov–Arnold Networks." https://arxiv.org/pdf/2404.19756.
[2] Runpeng Yu, Weihao Yu, Xinchao Wang. "KAN or MLP: A Fairer Comparison." https://arxiv.org/pdf/2407.16674.
[3] Van Duy Tran, et al. "Exploring the Limitations of Kolmogorov-Arnold Networks in Classification: Insights to Software Training and Hardware Implementation." https://arxiv.org/pdf/2407.17790v1.
[4] Daniel Bethell. "Demystifying Kolmogorov-Arnold Networks: A Beginner-Friendly Guide with Code." https://daniel-bethell.co.uk/posts/kan/.
[5] Fei Cheung. "Dissecting Kolmogorov-Arnold Network." https://feicheung2016.medium.com/dissecting-kolmogorov-arnold-network-f1bee719d949.