Megakernels
NVIDIA, AMD 같은 칩 제조사와 vLLM, Ollama 같은 추론 엔진들이 주목하는 기술이 바로 Megakernels입니다. 메가커널은 여러 개의 작은 GPU 커널 작업을 하나의 큰 커널로 합치는 방식인데, 이것이 배치 처리, 메모리 접근, 계산 오버헤드를 동시에 개선합니다.
문제 설정: 작은 커널들의 비효율성
전통적인 트랜스포머 추론은 여러 단계로 나뉩니다. 쿼리-키-값 변환, 어텐션 계산, 다음 토큰 샘플링 등이 각각 별개의 GPU 커널로 실행됩니다. 각 커널이 시작되고 종료될 때마다 GPU 메모리에서 중간 결과를 읽고 쓰고, 다음 커널로 데이터를 전달해야 합니다. 특히 생성 단계(decode phase)에서 배치 크기가 1~32 정도로 작을 때, 이런 오버헤드는 실제 계산 시간을 능가합니다.
예를 들어 bfloat16 추론에서 단일 토큰 생성이라면, 어텐션 계산에는 수 밀리초만 걸리지만 커널 호출, 메모리 전송, 동기화에는 그보다 더 오래 걸립니다. CUDA 프로파일러를 보면 "대기 시간"이 실제 계산 시간보다 많을 수 있습니다.
메가커널의 핵심 아이디어
메가커널은 이 문제를 "한 번에 처리"하는 방식으로 접근합니다. 쿼리 계산, 어텐션, 출력 투영을 하나의 커널로 융합(kernel fusion)하면, 중간 결과가 전부 GPU 로컬 메모리(shared memory)에 머물 수 있습니다. 메인 메모리 VRAM으로 왕복할 필요가 없어집니다.
이 구조의 장점:
- 메모리 대역폭 절약: 각 중간값이 VRAM을 왕복하지 않으므로 대역폭 경합이 줄어듭니다. prefill(긴 시퀀스 처리)에서는 10배 이상 대역폭 효율이 높아질 수 있습니다.
- 커널 호출 오버헤드 제거: CPU-GPU 동기화, 커널 시작 비용이 사라집니다. 디코드 단계 같은 작은 배치에서 20~30% 속도 향상.
- 레지스터·공유 메모리 활용: 여러 서브 작업이 같은 스레드 블록 내에서 데이터를 공유하므로, 캐시 히트율이 상승합니다.
실제 구현의 복잡성
메가커널은 간단해 보이지만, 실제 구현은 까다롭습니다.
첫째, 메모리 레이아웃을 맞춰야 합니다. 쿼리는 배치 × 차원, 키는 시퀀스 × 차원, 값도 마찬가지인데, 이들을 하나의 커널 내에서 효율적으로 접근하려면 메모리 배치를 정교하게 설계해야 합니다. Triton 같은 고수준 언어를 쓰면 조금 쉬워지지만, CUDA로 직접 짜면 상당한 튜닝이 필요합니다.
둘째, 배치 크기가 다양합니다. prefill과 decode는 배치 크기 특성이 완전히 다릅니다. prefill은 배치 크기가 크지만(예: 32) 각 요청의 시퀀스 길이가 다양합니다. decode는 배치 크기가 작지만(예: 1) 시퀀스 길이가 길고 고정입니다. 두 경우를 모두 지원하는 메가커널을 짜려면 조건부 분기가 많아집니다.
셋째, 정확도를 보장해야 합니다. 여러 작은 커널을 합치면서 수치 안정성을 잃을 수 있습니다. 특히 어텐션에서 softmax를 계산할 때, 로컬 메모리에만 데이터를 두고 감마 정규화(gamma normalization)를 하려면 각 워프(warp) 간 축약(reduction)을 정확히 구현해야 합니다.
실제 사례: FlashAttention과 메가커널의 관계
FlashAttention은 메가커널의 대표 사례입니다. 어텐션 계산을 한 커널로 융합해서, 외부 메모리(HBM) 접근을 줄이고 내부 SRAM을 활용하는 방식입니다. 원래 어텐션은 Q × K^T(시간 복잡도 O(n^2))를 먼저 계산한 다음, softmax를 취하고, V와 곱합니다. FlashAttention은 이 세 단계를 "블록 단위"로 조금씩 계산합니다. 즉, 메모리 접근 패턴을 I/O에 최적화합니다.
FlashAttention의 speedup은 실제로는 메모리 대역폭이 병목인 상황에서 나타납니다. H100 같은 고사양 GPU에서 FP8 추론을 할 때, FlashAttention 2는 기존 구현 대비 2~3배 빠릅니다.
메가커널이 모든 것을 해결하지 못하는 이유
메가커널은 강력하지만 만능은 아닙니다.
먼저, 메모리 크기 제약이 있습니다. 어텐션 계산에서 Q와 K를 행렬곱할 때 중간 결과(attention weights)가 나오는데, 이것이 배치 × 시퀀스 × 시퀀스 크기입니다. 4K 이상의 긴 시퀀스라면 이 행렬이 메모리에 안 들어갈 수 있습니다. 따라서 메가커널도 "블록 단위"로 재귀 호출합니다.
둘째, 모든 연산이 병합 가능한 것은 아닙니다. 예를 들어 MoE(Mixture of Experts) 모델에서는 라우터 네트워크가 어떤 전문가를 선택할지 동적으로 결정합니다. 메가커널이 이를 포함하려면 분기가 많아져 실제로는 느려질 수 있습니다.
셋째, 커스터마이제이션 비용이 높습니다. 새로운 어텐션 메커니즘(예: GQA, MQA)이 나올 때마다 메가커널을 다시 짜야 합니다. 표준화된 인터페이스가 없어서 각 프레임워크가 독립적으로 구현하고 있습니다.
현황과 트렌드
2024~2025년 현재 메가커널은 표준에 가까워지고 있습니다. vLLM은 Paged Attention이라는 메모리 관리 기법과 메가커널을 결합해서 높은 처리량을 달성합니다. TensorRT-LLM은 NVIDIA가 직접 제공하는 메가커널들의 집합입니다. 오픈소스로도 FlashAttention, Flash-2 같은 메가커널 구현들이 PyTorch 커뮤니티에 통합되고 있습니다.
다만 표준화 문제가 남아 있습니다. 같은 기능을 여러 프레임워크가 다르게 구현하고 있어서, 성능 비교가 어렵습니다. 또한 새로운 GPU 아키텍처(예: Hopper, Blackwell)가 나올 때마다 메가커널을 재최적화해야 합니다.
메가커널이 해결하려는 근본 문제는 "작은 배치에서의 레이턴시"입니다. 배치 크기가 작을수록, 커널 호출 오버헤드가 전체 시간의 더 큰 부분을 차지합니다. 따라서 실시간 추론 시스템에서 메가커널의 중요성은 계속 높아질 것 같습니다.