brunch

You can make anything
by writing

C.S.Lewis

by Chris송호연 Feb 23. 2020

ML 최적화 1. JIT & google JAX

머신러닝 성능 최적화 시리즈 1

안녕하세요- Chris 입니다 :)

번역 글만 작성하다가 오랜만에 제 글을 작성하네요.


이번에 시리즈로 소개해드릴 주제는 머신러닝 성능 최적화입니다.

관련된 토픽을 한번 나열해보자면 이 정도가 되겠네요. 


- Google JAX(XLA & AutoGrad) - JIT

- Google JAX - 분산 컴퓨팅 환경에서의 선형 대수학 연산

- tf.data를 활용한 Data pipeline 최적화

- Data / Model parallelism

- Multi-GPU 전략(All-Reduce, Async Parameter Server)

- Python 비동기 프로그래밍(Multi-Thread / Multi-Process)

- Distributed ML Computing(Kubernetes, Slurm, Cloud Solutions)


첫 번째로 소개해드릴 주제는 바로 JAX의 핵심 기능인 JIT(Just-In-Time) Compilation입니다.

https://github.com/google/jax


1. Interpreter & Compiler


Python은 Interpreter 언어입니다. C언어 처럼 미리 컴파일을 하고 실행하는 게 아니라, 실행될 때 한 줄씩 명령이 실행됩니다. 우선 컴파일러와 인터프리터의 차이를 살펴보겠습니다.


Interpreter

- 컴파일 시간이 필요 없음

- 런타임 실행 속도가 느림

- 필요한 변수를 그때 그때 로드해서 메모리 효율적임

- OS, 빌드 환경에 종속적이지 않음

- 한 statement 씩 실행시킴


Compiler

- 전체 코드를 컴파일해서 컴파일 시간이 오래 걸림

- 런타임 실행 속도가 빠름

- 런타임 실행 시에 관련된 라이브러리와 변수를 모두 메모리에 올려서 메모리가 많이 필요함

- (대게 어셈블리어로 변환시키면서) 프로그램이 OS 및 빌드 환경에 종속됨


2. Python, Interpreter


여러분이 만약 프로젝트에서 Data와 Machine Learning을 사용하기로 했다면, Python은 피할 수 없는 선택입니다. 모든 Data 관련 업무가 Python을 중심으로 이루어지고 있습니다. Python은 개발 생산성 면에서 탁월한 반면에, 단점들도 많이 갖고 있습니다. 


우선 가장 먼저 떠오르는 문제점은 명확합니다.

Python은 느립니다


Python의 속도를 빠르게 하기 위해선 어떤 방법들이 있을까요?


- Cython: Python 코드를 C 언어로 컴파일하여 속도를 향상

예를 들어, sklearn와 numpy 같은 유명한 라이브러리들은 Cython으로 속도를 향상시켰습니다. 

- C: 그냥 C언어로 플랫폼 별로 바이너리를 만들고 Python에서 가져다 씁니다.

주로 성능이 중요한 대규모 라이브러리의 경우 이런 접근법을 많이 씁니다. Tensorflow의 경우엔 기능을 먼저 Python으로 구현한 후 해당 모듈을 C언어로 변경하는 식으로 개발 업무를 진행합니다. Tensorflow에선 bazel이라는 통합 빌드 시스템으로 Tensorflow의 C언어와 Python을 한꺼번에 빌드해줍니다. 


https://bazel.build/


- JIT: 런타임 실행 시 코드의 일부분을 컴파일하여 속도를 향상

제가 오늘 소개해드리려고 하는 방식입니다. 함수에다가 JIT 컴파일러를 달아서 런타임 시에 해당 코드를 컴파일하는 방식입니다. Tensorflow 2.0에선 @tf.function 데코레이터로 JIT을 수행합니다. JAX에서는 @jit 데코레이터로 해당 함수를 JIT 컴파일시킬 수 있습니다. 



3. JIT(Just In Time) Compilation


JIT Compilation은 런타임 시에 코드를 컴파일해서 프로그램을 실행하는 방식입니다. JIT 컴파일은 자바 프로그래밍 언어가 널리 보급되면서 많이 쓰이기 시작했는데요- 자바 초기 버전에선 아주 심플한 연산만 최적화를 해주었지만, 외부 모듈로 JIT연산이 가능하게끔 해주는 hook를 공개한 후로는 JIT 컴파일을 많이 사용하게 되었습니다. 


JIT 컴파일은 런타임 시에 특정 프로그램을 기계어로 변환해주는 역할을 합니다. 물론 빨라집니다. 하지만 JIT 컴파일은 공짜가 아닙니다. JIT으로 프로그램 코드를 기계어로 변환해주는 데는 오버헤드가 있습니다. 


JIT의 장점: 

- 특정 연산의 속도가 급격히 빨라집니다

- 인터프리터 언어의 장점을 살려 개발한 후 JIT을 적용하면 되니 개발이 편합니다


JIT의 단점:

- JIT 컴파일 오버헤드가 있습니다.

- JIT 컴파일을 적용할 코드가 복잡하면 성능향상이 제대로 이루어지지 않을 수 있습니다.


위 단점에서 말했 듯이, 코드가 복잡하면 JIT이 제대로 소화를 못할 수 있습니다. 그러니 JIT 적용하시려고 한다면, 코드가 명확하고 깔끔하게 정리되있어야 JIT의 성능 향상을 제대로 경험할 수 있을 것입니다.


Python 프로그램의 속도를 빠르게 만들기 위해 JIT은 상당히 유용합니다. 우선, 빌드 과정이 Cython에 비해 간편하며, 유저가 커스터마이징하기 쉽기 때문에 코딩도 상당히 직관적입니다. 특히나 지속적으로 새로운 구조의 아키텍쳐를 시험해보셔야 하는 연구자분들에게는 이런 확장성이 중요합니다.


4. Speed up numpy with JAX


간단하게 numpy 연산을 JAX numpy로 변경했을 때 속도의 차이를 한번 비교해보겠습니다. 

따라할 수 있는 튜토리얼 코드는 github + colab으로 만들어두었으니, 확인해보셔요!


github star를 누르면 여자친구가 생깁니다


github

https://github.com/chris-chris/jax-tutorial/blob/master/JAX_Tutorial.ipynb

Colab Notebook

https://colab.research.google.com/github/chris-chris/jax-tutorial/blob/master/JAX_Tutorial.ipynb


JAX 튜토리얼 1: numpy 성능 압살


첫 번 째로 비교해볼 연산입니다. cos(x) sin(y)를 더하는 연산입니다. 단순한 연산인데 그냥 numpy로 했을 때 401 ms 걸리던 연산이 2.15ms로 속도가 압도적으로 빨라졌습니다. 참고로 numpy도 Cython을 사용해서 C언어로 구현되었다고 볼 수 있지만, JAX의 경우에는 numpy가 최적화하지 못한 loop 등을 최적화하고 GPU 자원까지 활용해서 속도를 향상시켰습니다.

JAX 튜토리얼 2: JAX JIT & @tf.function JIT 역시 성능 압살


JAX 기능은 사실 Tensorflow에서 있던 XLA 기능을 빼와서 독립적인 모듈로 만들어낸 것입니다. 즉 Tensorflow에도 있긴 합니다. Tensorflow 2.0에선 단지 @tf.function 데코레이터만 함수에 붙여주면 JIT 연산이 가능해집니다. (정말 쉽죠?)


그래서 JAX JIT 돌려보는 김에 같이 해봤습니다. 

3가지 함수를 선언했습니다. 마찬가지로 sin(x) + cos(y) 연산입니다.

1. fn: 그냥 함수

2. fn_jit: JAX JIT 적용한 함수

3. fn_tf2: @tf.function 적용한 함수

벤치마크 결과를 비교해보겠습니다.

1. fn: 그냥 함수 - 780ms

2. fn_jit: JAX JIT 적용한 함수 - 2.12ms

3. fn_tf2: @tf.function 적용한 함수 - 3.36 ms


JIT이 제일 빠르고 그 다음 @tf.function이 빠르고 numpy는 답이 없게 느리네요.

위 연산에서 만큼은 JAX JIT이 numpy에 비해 약 367배 빨랐습니다.



5. 결론


JAX를 쓰면 Tensorflow나 Pytorch를 안써도 된다는 말이 아닙니다. 프로젝트의 목적 자체가 다릅니다. 만약에 여러분이 Tensorflow나 Pytorch가 아직 제대로 지원하지 않는 새로운 방식의 아키텍쳐를 구상하고 있으시거나 특별한 대규모의 머신러닝 시스템을 구축하고자 한다면 이 기법은 상당히 유용할 것이라 믿습니다. 특히나 아직 제대로 자리잡지 않은 Neural Architecture Search 같은 특별한 연구 분야에도 잘 어울릴 것이라 생각합니다.


중요한 사실은 JAX는 딥러닝 플랫폼에 독립적이라는 점입니다. 
Pytorch, Caffe 등 어디서나 가져다 쓰시면 됩니다. 


참고로, Deepmind는 최근에 전체 코드를 JAX를 기반으로 리팩토링해서 라이브러리들을 오픈소스로 공개했습니다.


https://github.com/deepmind/rlax

https://github.com/deepmind/dm-haiku

6. 참고 자료


https://www.tensorflow.org/xla


예전 Google Tensorflow 2017 컨퍼런스에서 XLA 기능이 소개되었습니다. 초기에는 Tensorflow  XLA 기능은 사용이 조금 어려웠지만, Tensorflow 2.0에 와서는 @tf.function 이라는 아주 직관적인 방식으로 XLA를 수행할 수 있게 되었습니다. 


https://github.com/google/jax

https://hub.packtpub.com/google-researchers-introduce-jax-a-tensorflow-like-framework-for-generating-high-performance-code-from-python-and-numpy-machine-learning-programs/

https://iaml.it/blog/jax-intro-english


브런치는 최신 브라우저에 최적화 되어있습니다. IE chrome safari