The Free Transformer

🏷️ 논문 딥러닝

F. Fleuret, "The Free Transformer", arXiv preprint arXiv:2510.17558, 2025.

요약

아키텍처: 표준 디코더 트랜스포머를 조건부 변분 오토인코더(VAE)로 확장했습니다. 중간 레이어에 랜덤 잠재 변수 \(Z\)를 주입하고, 인코더는 첫 번째 절반의 레이어와 비인과적 트랜스포머 블록 하나로 구성됩니다.

1-freetransformer.png

모델 크기: 1.5B 모델(28레이어)과 8B 모델(32레이어, Llama-3 구조)을 각각 47B, 200B, 1T 토큰으로 훈련했습니다. 인코더로 인한 오버헤드는 1.5B에서 3.6%, 8B에서 3.1%에 불과합니다.

핵심 메커니즘:

훈련 손실: 표준 교차 엔트로피 + 제어된 KL 발산 \[\mathcal{L} = \text{CE} + \frac{1}{T}\sum_{t=1}^T \max\left(0, D_{KL}(Q(Z_t|S) | P(Z_t)) - \kappa\right)\]

성능 향상 (8B, 1T 토큰):

평가 벤치마크: 코드/수학 생성(HumanEval+, MBPP, GSM8K), 다지선다 상식 추론(MMLU, CSQA, HellaSwag), 독해(RACE, BoolQ), 지식 검색(NQ, TriviaQA) 등 15개 태스크에서 평가했습니다.

논문 상세

1. Introduction

트랜스포머의 발명 이후 거의 10년이 지났지만, 자기회귀 모델링은 본질적으로 도전받지 않았습니다. 이 논문은 이 핵심 설계 측면을 재검토하여 더 풍부하고 자연스러운 밀도 모델이 나타날 수 있도록 합니다.

디코더 트랜스포머는 자기회귀 이산 밀도 근사기입니다. 토큰 시퀀스 \(S_1, \ldots, S_T\)를 모델링하여 각 토큰이 이전 토큰들이 주어졌을 때의 조건부 분포를 추정합니다. 이러한 모델이 구현하는 유일한 밀도 모델링과 샘플링은 생성된 토큰의 것입니다. 특히 디코더 트랜스포머는 생성할 토큰 스트림에 대한 추가적인 잠재 결정을 내리지 않습니다.

간단한 예를 살펴보겠습니다. \(Z \sim B(0.5)\)를 잠재 "동전 던지기"라고 하고, \(X_1, \ldots, X_T\)는 확률 \(\epsilon\)의 독립적인 플립으로 \(Z\)와 같다고 하겠습니다. \(X_t\)들은 \(Z\)가 주어졌을 때 조건부 독립이며:

\[P(X_{t+1} = 1 | Z = z) = \epsilon z + (1-\epsilon)(1-z)\]

하지만 \(Z\) 없이 자기회귀 모델로 표현하면:

\[P(X_{t+1} = 1 | X_1 = x_1, \ldots, X_t = x_t) = \frac{\left(\frac{\epsilon}{1-\epsilon}\right)^{\sum_{s=1}^t x_s}(1-\epsilon)^{t+1} + \left(\frac{1-\epsilon}{\epsilon}\right)^{\sum_{s=1}^t x_s}\epsilon^{t+1}}{\left(\frac{\epsilon}{1-\epsilon}\right)^{\sum_{s=1}^t x_s}(1-\epsilon)^t + \left(\frac{1-\epsilon}{\epsilon}\right)^{\sum_{s=1}^t x_s}\epsilon^t}\]

순수한 자기회귀 밀도 모델은 잠재적으로 여러 단점을 겪습니다:

2. Motivation

체인 룰로 인해 모든 밀도는 자기회귀로 모델링될 수 있습니다. 하지만 특히 "자연스러운" 구조가 잠재 변수에 대한 조건부를 포함할 때, 신호의 자기회귀 모델은 잠재 변수를 포함한 전체 결합 모델보다 훨씬 더 복잡할 수 있습니다.

이 연구의 주요 목표는 훈련 예제에 의해 부과되지 않는 잠재 랜덤 양에 자기회귀 프로세스를 조건화할 자유를 모델에 제공하여 이러한 문제를 해결하는 것입니다.

3. Method

조건부 변분 오토인코더: 랜덤 변수 \(Z\)에 의존하는 모델로 처음부터 전체 시퀀스를 생성하는 것은 간단합니다. \(Z \sim P(Z)\)를 샘플링한 다음 표준 자기회귀 프로세스를 실행하면 됩니다.

그러나 모델을 훈련하는 것은 훨씬 더 복잡합니다. 훈련 샘플 \(S\)가 주어지면 목표는 다음을 최대화하는 것입니다:

\[P(S) = \int_z P(S | Z=z)P(Z=z)dz\]

VAE의 인코더 역할은 "좋은" 분포 \(Q(Z|S)\)에서 샘플링하여 샘플링된 \(Z\)가 디코더를 변조하여 \(S\)를 생성하도록 하는 것입니다.

모델 구조: Free Transformer는 중간 레이어에 노이즈 \(Z\)가 주입된 표준 디코더입니다. 이를 통해 트랜스포머 블록의 절반을 인코더와 공유하여 인코더에 특정하게 계산해야 하는 단일 트랜스포머 블록만 있으면 되므로 계산 오버헤드를 대폭 줄입니다.

\(1024 \times 1024\) 이미지를 입력한다고 가정하면 DeepEncoder는 이를 \(1024/16 \times 1024/16 = 4096\) 패치 토큰으로 분할합니다. 첫 번째 절반의 인코더가 윈도우 어텐션이 지배적이고 80M만 사용하므로 활성화가 허용 가능합니다. 글로벌 어텐션에 들어가기 전에 4096개의 토큰이 압축 모듈을 거쳐 \(4096/16 = 256\)개가 되므로 전체 활성화 메모리가 제어 가능합니다.

표준 디코더 트랜스포머로서 Free Transformer는 임베딩 테이블로 토큰 시퀀스를 인코딩하여 \(T \times D\) 형태의 텐서 \(X_0\)를 생성합니다. 그런 다음 첫 번째 \(L/2\) 트랜스포머 블록을 순차적으로 평가하여 동일한 형태의 \(X_{L/2}\)를 얻습니다.

이 시점에서 원-핫 벡터의 시퀀스 \(Z = (Z_1, \ldots, Z_t) \in {0,1}^{T \times C}\)를 샘플링합니다. 생성 중에는 각 \(Z_t\)에 대해 인덱스 \(c\)\({0, \ldots, C-1}\)에서 균일하게 샘플링한 다음 차원 \(C\)의 원-핫 벡터로 인코딩합니다.

인코더와 손실: 훈련 또는 KV 캐시 사전 채우기 중에 텐서 \(Z\)는 인코더로 샘플링됩니다. Free Transformer는 비인과적인 인코더 전용 트랜스포머 블록 하나를 가지고 있습니다. 이는 디코더의 조건화가 장거리 효과를 가질 수 있어 적절한 잠재 조건부 분포를 얻기 위해 전체 시퀀스를 고려해야 하기 때문에 필요합니다.

선형 판독은 인코더 블록의 출력에서 모든 토큰에 대해 \(H=16\) 차원의 벡터를 계산합니다. 이러한 구성 요소는 개별 비트의 로짓으로 해석되어 \({0, \ldots, 2^H - 1}\)에서 값을 샘플링하는 데 사용됩니다.

KL 발산은 개별 \(Z_t\)의 KL 발산을 임계값 \(\kappa\) 이상인 것만 합산하고 나머지는 무시하는 토큰별 free bits 방법으로 제어됩니다:

\[\frac{1}{T}\sum_{t=1}^T \max\left(0, D_{KL}(Q(Z_t|S_1, \ldots, S_T) | P(Z_t)) - \kappa\right)\]

Binary Mapper: 인코더의 마지막 선형 레이어는 처리 중인 시퀀스의 모든 인덱스 \(t\)에 대해 벡터 \(L_t = (L_{t,1}, \ldots, L_{t,H}) \in \mathbb{R}^H\)를 계산합니다. 이 구성 요소는 이진 인코딩의 개별 비트의 로짓으로 해석됩니다.

Binary Mapper는 다음과 같이 독립적으로 비트 \(B_{t,1}, \ldots, B_{t,H}\)를 샘플링합니다:

\[P(B_{t,h} = 1) = \frac{1}{1 + e^{-L_{t,h}}}\]

그리고 결과 값에 해당하는 \(2^H\) 차원의 원-핫 벡터 \(Y_t\)를 출력합니다.

4. Experiments

합성 데이터셋: Free Transformer가 실제로 \(Z\)를 사용하여 생성 프로세스를 조건화하는지 확인하기 위해 합성 데이터셋을 설계했습니다. 각 시퀀스는 64개의 밑줄로 시작하고, 대문자와 시퀀스의 위치를 무작위로 선택하여 선택한 문자가 8번 반복되는 "타겟"으로 밑줄을 교체합니다.

매우 낮은 KL 발산 값의 경우 모델은 바닐라 모델처럼 동작하며, 값이 증가하면 모델은 처음에 잠재 상태에 타겟의 위치만 인코딩하고, 그 다음 타겟 위치와 노이즈를 모두 인코딩하고, 마지막으로 전체 시퀀스를 인코딩하여 부정확한 생성을 초래합니다.

탐색적 결과: 1.5B 모델(47B 토큰)과 8B 모델(200B 토큰)을 다양한 KL 발산 임계값으로 훈련하여 여러 벤치마크에서 성능을 비교했습니다.

추론을 필요로 하는 벤치마크인 HumanEval+, MBPP, GSM8K에서 상당한 성능 향상을 관찰했습니다. 8B 모델의 경우 1/2 비트 KL 발산으로 다지선다 질문인 MMLU와 CSQA에서도 명확한 개선이 있었습니다.

1T 토큰 훈련 결과: 더 현실적인 설정에서 개선을 측정하기 위해 8B 모델을 1T 토큰으로 훈련했습니다. 200B 토큰 결과를 고려하여 토큰당 최대 절반 비트의 정보에 해당하는 값 \(\kappa = \log(2)/2\)를 선택했습니다.

핵심 결과는 HumanEval+, MBPP, GSM8K, MMLU, CSQA에서의 성능 향상이며, 이는 더 작은 설정에서 관찰한 것을 확인하고 다른 작업에서 더 큰 안정성을 보여줍니다.

5. Previous work

VAE와 디코더 트랜스포머를 결합하려는 여러 시도가 있었습니다. OPTIMUS 모델은 사전 훈련된 BERT를 텍스트 임베딩/인코더로, GPT-2를 디코더로 결합하여 VAE와 유사한 손실로 미세 조정합니다.

Fang 등의 CVAE는 두 개의 사전 훈련된 GPT-2를 결합하며, 하나는 인과적 마스킹 없이 인코더로 사용됩니다. AdaVAE는 유사하게 두 개의 사전 훈련된 GPT-2의 조합이며, 첫 번째는 인과적 마스킹 없이 인코더 역할을 합니다.

6. Conclusion

Free Transformer는 표준 디코더 트랜스포머의 직접적인 확장이며 조건부 VAE의 추상적 구조를 가지고 있습니다. 단일 추가 비인과적 트랜스포머 블록으로 구현되며 몇 퍼센트의 계산 및 메모리 사용 오버헤드가 필요합니다.

이 구조는 비지도 학습 잠재 랜덤 변수를 학습하고 생성 프로세스를 조건화할 수 있게 합니다. 어떤 면에서 이 접근법은 추론 모델이 토큰 공간에서 생각 체인과 RL 절차로 수행하는 것을 잠재 공간에서 오토인코더로 달성하는 것을 목표로 합니다.

최적화 하이퍼파라미터를 조정하지 않고도 여러 벤치마크와 두 가지 크기의 모델에서 성능 향상을 보인 것은 전체 접근법이 실제로 바닐라 트랜스포머의 귀납적 편향을 개선한다는 강력한 신호입니다.