Long Context에 대응하는 학습 및 추론 전략

챗봇과 같은 RAG 시스템에서 Long Context에 대한 니즈

by Dr Jenna

LLM 활용도가 올라가고 Reasoning이 추가되고 RAG를 기반으로 검색 결과 통합을 하게 되면서 점점 더 긴 토큰 시퀀스가 필요하게 되었다.


이러한 Long Context Task에서 LLM 모델을 Training 및 Inference 할 때 필수로 고려해야 하는 요소들을 정리 해 보았다.



1. 항목: 메모리 사용(특히 KV/Activation)



1. 학습(Training)

Gradient/Optimizer/Activation까지 올라와 전체 메모리 압박 극대화 **하단에 상세 내용 추가 설명

Activation checkpointing(=recompute), ZeRO/옵티마 분산, bf16/fp16, (가능하면) fp8 훈련

길이 버킷팅과 패킹으로 패딩 최소화


2. 추론(Inference)

Gradient/Optimizer 없음 → KV Cache가 메모리의 핵심

KV Cache fp16/bf16 기본, 필요 시 int8/fp8로 추가 절감

Prefix/Paged KV, 슬라이딩 윈도우로 메모리 선형 증가 제어


3. 실무 팁


학습: “토큰/디바이스”를 먼저 고정하고 이후 배치·축적 스텝으로 Throughput 맞추기

추론: max_model_len·max_num_seqs·max_num_batched_tokens 세 축으로 OOM 경계 잡기


────────────────

2. 항목: 어텐션 연산 최적화(복잡도 O(N^2))



1. 학습(Training)


FlashAttention(Backward 지원 버전), fused kernels 필수

(모델 설계가 허용하면) block/로컬/스파스 패턴으로 학습부터 일관 적용


2. 추론(Inference)


FlashAttention, GQA/MQA(=KV head 축소), Windowed/Sliding Attention 적극 활용

Speculative decoding로 디코드 병목 완화


3. 실무 팁


학습·추론 모두 “모델이 학습한 패턴”과 일치해야 품질 유지(학습 때 full, 추론 때만 sparse는 위험)


────────────────

3. 항목: 위치 인코딩/Long Context 일반화


1. 학습(Training)

RoPE scaling/ALiBi/YaRN 등 “확장 방식”을 학습 초기에 반영

Curriculum(짧은→긴 시퀀스)로 안정화


2. 추론(Inference)

학습 시 채택한 방식과 동일하게 로드(rope_theta/rope_scaling 일치)

학습 길이 초과 사용 시 품질 저하 가능 → 스케일링/가드 필요


3. 실무 팁

모델 config와 런타임 스케일링 값의 불일치 여부를 배포 자동 점검 항목에 포함


────────────────

4. 항목: Prefill 단계 비용


1. 학습(Training) **하단에 상세 내용 추가 설명

긴 컨텍스트 Forward가 그대로 비용 → 길이 버킷팅, 다중 패킹(패딩 제로화), 시퀀스 병렬/텐서 병렬

Activation checkpointing으로 메모리 절감, I/O 파이프라인 튜닝


2. 추론(Inference)

Chunked prefill로 디코드와 인터리브(Interleave), Continuous batching으로 GPU 유휴 최소화

Prefix KV로 공통 프리필 스킵


3. 실무 팁

학습은 “패딩을 없애고 재계산 선택적으로”, 추론은 “청크·배칭·재사용”이 키워드


────────────────

5. 항목: I/O·스케줄링


1. 학습(Training)

메모리맵/샤딩된 데이터, 프리토큰화, prefetch로 GPU 놀지 않게

통신(InfiniBand/NVLink)·동기화 최적화


2. 추론(Inference)

Continuous batching, Paged KV, (vLLM) 스케줄러로 프리필/디코드 공정 조정

레이턴시 목표면 eager 실행, 처리량 목표면 배치 상한 증대


3. 실무 팁

“GPU가 기다리지 않게”가 공통 목표, 구현 수단은 단계별로 다름


────────────────

6. 항목: 품질 검증/리스크


1. 학습(Training)

긴 시퀀스용 검증 세트, 길이 커리큘럼, 롱-레인저 과제 포함

스파스/윈도우 패턴 학습 시 다운스트림 영향 점검


2. 추론(Inference)

KV 양자화/int8·fp8, 윈도우 크기(W) 조정 시 태스크별 A/B 필수

롱 컨텍스트 RAG 파이프라인과 함께 종단 품질 평가


3. 실무 팁

“추가 절감(양자화/윈도우)”은 반드시 태스크 기준으로 승인


※ 부록1: 학습 단계에서 Activation 비용 증가


Q. 왜 토큰 수가 늘어나면 Activation 메모리가 커지는가?

A.
Transformer의 각 레이어는 Self-Attention과 FFN(Feed Forward Network) 등으로 구성되며,

학습 시에는 Backpropagation을 위해 각 토큰의 **중간 계산 결과(Activation)**를 모두 저장해야 한다.

Activation이란, 모델의 모든 Layer 내부의 중간 출력값을 통칭하는 개념이다.
ㄴ FFN의 중간 출력
ㄴ Self-Attention의 Q/K/V 텐서, Attention Score, Context Vector 등

Long Context에서 메모리 부담이 폭발하는 이유는 FFN 자체보다는 **Attention Score가 O(seq_len²)**로 증가하기 때문이다.

Memory Complexity
ㄴ FFN Activation 메모리: O(seq_len) (토큰 수에 선형 비례)
ㄴ Attention Activation 메모리: O(seq_len²) (토큰 수의 제곱에 비례)

최적화 방안
ㄴ Activation Checkpointing (= recompute): 메모리를 절약하고, 필요한 시점에만 중간값을 다시 계산
ㄴ 분산 학습(ZeRO, Optimizer Sharding): GPU 간 메모리 분산
ㄴ 저정밀도 사용(bf16 / fp16): Activation 저장 시 메모리 절감


※ 부록2: 학습 단계에서 Prefill 비용 증가


학습에서 “Prefill”은 곧 긴 시퀀스를 한 번에 Forward하는 과정 전체를 뜻한다.

Long Context 학습의 병목이 바로 여기서 생기므로, 아래를 순서대로 챙기는 게 좋다.

1. 길이 버킷팅(length bucketing)과 다중 패킹(multi-packing)

문제: 서로 다른 길이 샘플을 섞으면 패딩이 폭증 → 쓸모없는 계산/메모리 낭비

해법 A(버킷팅): 길이가 비슷한 샘플끼리 배치 구성 → 패딩 최소화

해법 B(패킹): 여러 짧은 시퀀스를 한 시퀀스로 이어붙여 패딩을 0에 가깝게; 서로 다른 샘플 간 어텐션이 섞이지 않도록 attention mask/segment id로 차단

실무 팁: “Loss는 각 샘플의 유효 구간에만” 계산(라벨 시프팅/마스킹)


2. FlashAttention(Backward 지원) + 커널 퓨전

Forward/Backward 모두에서 메모리 I/O를 줄여 O(N^2)의 벽을 완화

Fused RMSNorm/SiLU, fused softmax 등 커널 퓨전으로 메모리 왕복 최소화

실무 팁: 커널/드라이버 호환 매트릭스 확인(환경이 미세하게 다르면 속도/안정성 편차 큼)


3. Activation checkpointing(=recompute)와 선택적 적용

모든 중간 활성값(activations)을 저장하지 않고, Backward 때 재계산

메모리는 크게 줄지만 시간은 증가 → “어텐션/MLP에만” 선택 적용 등으로 타협

실무 팁: 레이어별 메모리 프로파일링 후 상위 기여 구간에 우선 적용


4. 분산·병렬 전략: TP/PP/시퀀스 병렬(Sequence Parallelism)

TP: 거대한 행렬을 헤드/기둥 축으로 쪼개 병렬; Long Context에서도 기본축

PP: 레이어를 스테이지로 나눠 마이크로배치로 파이프라인 채우기(버블 최소화)

시퀀스 병렬: 시퀀스 차원 기준으로 분할(일부 스택에서 지원)해 메모리 분담

실무 팁: 통신 오버헤드 vs 메모리 이득을 벤치마크로 결정


5. Optimizer/메모리: ZeRO/8-bit 옵티마/오프로딩

ZeRO-1/2/3로 옵티마/그라디언트/파라미터 분산, 8-bit Adam/Adafactor로 상태 축소

CPU/NVMe 오프로딩은 “마지막 수단”; I/O 병목이 생길 수 있으니 균형 필요


6. 혼합 정밀 및(가능 시) FP8 훈련

bf16 기본, 환경이 허락하면 일부 경로의 fp8로 HBM 대역폭/메모리 절약

안정화: grad scaling, norm clipping, warmup 스케줄로 발산 방지


7. 데이터/로더 I/O 파이프라인

대용량 토큰화 결과를 메모리맵(mmap)/샤딩하고, 프리패치/다중 워커

시퀀스 샘플링 전략: sliding-window 샘플러(겹치기), 문서 경계 유지/랜덤 시프트

실무 팁: GPU 유휴(Idle) 시간 0% 목표로 로더 모니터링


8. Long Context 일반화

RoPE/ALiBi/YaRN 등 “확장 방식”을 학습부터 반영

Curriculum: 4k→8k→32k→… 점진 증가 스케줄

검증셋: 긴 문서 QA, 코드베이스 질의, 장문 요약 등 “실전형”으로 구성

keyword
작가의 이전글복잡하고 어려운 LLM Serving 핵심 요약 버전