In an article recently submitted to the arXiv* preprint server, researchers proposed training large language models (LLMs) to simultaneously predict multiple future tokens using independent output heads. This approach improved sample efficiency and downstream capabilities, especially for larger models and generative tasks such as coding.
*Important notice: arXiv publishes preliminary scientific reports that are not peer-reviewed and, therefore, should not be regarded as definitive, used to guide development decisions, or treated as established information in the field of artificial intelligence research.
The approach showed significant performance gains, solving more problems on coding benchmarks and improving algorithmic reasoning and the formation of induction heads. Models trained with four-token prediction also achieved up to three times faster inference.
Background
LLMs have achieved impressive feats in capturing world knowledge and basic reasoning through the next-token prediction task. However, this method is inefficient: it focuses on local patterns and requires significantly more data than humans need to reach similar levels of fluency. Previous studies have explored multi-token prediction, which involves predicting multiple future tokens simultaneously. Despite its promise, this approach has not been extensively applied at scale.
The present research addressed this gap by proposing a straightforward multi-token prediction framework that did not increase training time or memory overhead. The authors provided experimental evidence demonstrating the benefits of this method, particularly for large models with up to 13 billion parameters, which solved approximately 15% more coding problems on average.
Additionally, multi-token prediction facilitated self-speculative decoding, enhancing inference speed by up to three times across various batch sizes. This work highlighted the potential of multi-token prediction to enhance LLM performance, coherence, and reasoning abilities beyond traditional next-token prediction.
Multi-Token Prediction Architecture and Efficient Training Methods
The proposed method generalized standard language modeling by implementing a multi-token prediction task. Instead of predicting only the next token in the sequence, the model predicted n future tokens simultaneously, minimizing a multi-token cross-entropy loss. The architecture comprised a shared transformer trunk that produced a latent representation of the observed context and n independent output heads that predicted each of the n future tokens in parallel. Conditioned on this shared latent representation, the multi-token cross-entropy loss factorized into a sum of per-head terms, enhancing the model's predictive capabilities.
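To make this concrete, the following PyTorch-style sketch shows a shared trunk feeding n independent output heads whose per-head cross-entropy losses are summed. It is an illustrative reconstruction of the description above rather than the authors' code: the MultiTokenPredictor class, the use of plain linear heads, and the tensor shapes are assumptions made for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTokenPredictor(nn.Module):
    """Shared trunk with n independent output heads (illustrative sketch)."""

    def __init__(self, trunk: nn.Module, d_model: int, vocab_size: int, n_future: int = 4):
        super().__init__()
        self.trunk = trunk          # any causal model mapping token IDs to hidden states
        self.n_future = n_future
        # One independent head per future-token offset; a real implementation
        # might use a transformer layer plus a shared unembedding matrix instead.
        self.heads = nn.ModuleList(
            nn.Linear(d_model, vocab_size, bias=False) for _ in range(n_future)
        )

    def forward(self, tokens: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # tokens:  (batch, seq_len) observed context
        # targets: (batch, seq_len + n_future) ground-truth continuation
        hidden = self.trunk(tokens)                 # (batch, seq_len, d_model) shared latent
        loss = 0.0
        for i, head in enumerate(self.heads):
            logits = head(hidden)                   # head i predicts the token i + 1 steps ahead
            shifted = targets[:, i + 1 : i + 1 + tokens.size(1)]
            loss = loss + F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), shifted.reshape(-1)
            )
        return loss                                 # sum of per-head cross-entropy terms
```

In this sketch, every head reads the same latent state for a position and is trained on the token a fixed number of steps ahead, so the summed loss is one instance of the factorized multi-token objective described above.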
To address the challenge of graphics processing unit (GPU) memory utilization when training multi-token predictors, the authors adapted the sequence of forward and backward operations. Instead of materializing all logits and their corresponding gradients at once, the method sequentially computed forward and backward passes for each independent output head, accumulating gradients at the trunk and freeing memory before moving to the next head. This reduced peak GPU memory usage without increasing runtime.
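Under the same assumptions as the previous sketch, the memory-saving schedule might look as follows. The detach-and-re-backward pattern is one plausible way to realize the described per-head sequencing in PyTorch; it is not taken from the paper's implementation.

```python
import torch.nn.functional as F

def multi_head_training_step(model, tokens, targets, optimizer):
    """One training step that avoids materializing all heads' logits at once (sketch)."""
    optimizer.zero_grad()

    hidden = model.trunk(tokens)                       # single shared trunk forward pass
    hidden_detached = hidden.detach().requires_grad_(True)

    for i, head in enumerate(model.heads):
        logits = head(hidden_detached)                 # forward for head i only
        shifted = targets[:, i + 1 : i + 1 + tokens.size(1)]
        loss_i = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), shifted.reshape(-1)
        )
        loss_i.backward()                              # gradients accumulate in head params
        del logits                                     # and in hidden_detached.grad; the
                                                       # head's logits are freed here

    # Single backward pass through the trunk with the accumulated gradient.
    hidden.backward(hidden_detached.grad)
    optimizer.step()
```

Because only one head's logits exist at any time, peak memory scales with a single vocabulary-sized tensor rather than n of them, which matches the motivation given above.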
During inference, the architecture could perform vanilla next-token autoregressive prediction using the next-token prediction head. It leveraged other output heads to speed up decoding through self-speculative decoding methods like blockwise parallel decoding and Medusa-like tree attention, enhancing inference efficiency.
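A simplified greedy variant of self-speculative decoding can be sketched as follows, reusing the hypothetical model from the earlier sketches. The acceptance rule shown verifies each drafted token against the next-token head and is a deliberate simplification of the blockwise parallel decoding and Medusa-like tree attention mentioned above.

```python
import torch

@torch.no_grad()
def self_speculative_generate(model, tokens, max_new_tokens=64):
    """Greedy self-speculative decoding (illustrative sketch)."""
    target_len = tokens.size(1) + max_new_tokens
    while tokens.size(1) < target_len:
        hidden = model.trunk(tokens)
        last = hidden[:, -1]                       # latent state at the final position
        # Draft: head i proposes the token (i + 1) steps ahead of the context.
        draft = torch.stack(
            [head(last).argmax(dim=-1) for head in model.heads], dim=1
        )                                          # (batch, n_future)
        candidate = torch.cat([tokens, draft], dim=1)

        # Verify: rerun the trunk on the candidate and compare each drafted
        # token with the next-token head's greedy choice one position earlier.
        preds = model.heads[0](model.trunk(candidate)).argmax(dim=-1)
        n_accepted = 0
        for i in range(draft.size(1)):
            pos = tokens.size(1) - 1 + i           # position whose next token is draft[:, i]
            if bool((preds[:, pos] == draft[:, i]).all()):
                n_accepted += 1
            else:
                break
        if n_accepted > 0:
            accepted = draft[:, :n_accepted]
        else:
            # Fall back to the verified next-token prediction so decoding always advances.
            accepted = preds[:, tokens.size(1) - 1 : tokens.size(1)]
        tokens = torch.cat([tokens, accepted], dim=1)
    return tokens[:, :target_len]
```

Each verification pass costs roughly one ordinary forward pass but can commit several tokens at once, which is where the reported inference speedups come from.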
Experimental Results on Multi-Token Prediction Models
The researchers conducted seven large-scale experiments to evaluate the efficacy of multi-token prediction models. The findings indicated that multi-token prediction became increasingly beneficial as model size grew, significantly enhancing performance on code and natural language tasks.
- Model Size Scaling: Multi-token prediction models outperformed next-token models at larger scales, demonstrating better results on code benchmarks such as Mostly Basic Python Programming (MBPP) and HumanEval.
- Faster Inference: Using speculative decoding, multi-token prediction models achieved up to three times faster inference speeds on code and text.
- Global Pattern Learning: Multi-token prediction models excelled at learning long-term patterns, particularly with byte-level tokenization, showing a 67% improvement on MBPP pass@1.
- Optimal Token Prediction: Training with four future tokens consistently outperformed other configurations across various benchmarks.
- Multi-Epoch Training: The advantages of multi-token prediction persisted across multiple training epochs, maintaining an edge over next-token models.
- Finetuning: Pretrained multi-token models exhibited superior performance when finetuned on challenging tasks like CodeContests.
- Natural Language: While multi-token prediction models showed modest improvements on some natural language tasks, larger datasets might be necessary for significant gains.
Overall, multi-token prediction enhanced model capabilities, sped up inference, and offered robust performance across diverse tasks.
Speculation on Why Multi-Token Prediction Works
Multi-token prediction improved performance by reducing the gap between training-time teacher forcing and inference-time autoregressive generation. It assigned higher implicit weights to critical "choice point" tokens that influenced subsequent text, ensuring better decision-making at these junctures.
An information-theoretic perspective revealed that multi-token prediction emphasized the mutual information between successive tokens, enhancing the model's ability to predict tokens that were crucial for the continuation of coherent and relevant text. This approach led to more accurate and effective language models, particularly for tasks requiring longer-term dependencies.
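One hedged way to make this argument concrete, using only standard information-theoretic identities rather than the paper's exact derivation, is to compare the terms optimized for two successive tokens X and Y:

```latex
% Standard identities: H(X) = H(X | Y) + I(X; Y) and H(Y) = H(Y | X) + I(X; Y).
% Next-token training at a position targets only H(X); a two-token objective targets H(X) + H(Y).
\begin{align*}
\underbrace{H(X)}_{\text{next-token term}} + \underbrace{H(Y)}_{\text{extra-head term}}
  &= H(X \mid Y) + H(Y \mid X) + 2\, I(X; Y).
\end{align*}
```

Compared with the single next-token term, the two-token objective places twice the weight on the mutual information I(X; Y) between successive tokens, which is one way to read the emphasis on inter-token dependencies described above.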
Conclusion
In conclusion, multi-token prediction presented a substantial advancement over next-token methods for training large language models, particularly enhancing performance in generative and reasoning tasks. By minimizing the gap between the training and inference phases, it optimized decision-making at critical points in text generation.
This approach, supported by efficient speculative decoding, significantly accelerated inference speeds. Future research aims to automatically determine the optimal number of future tokens to predict and to refine vocabulary sizes, potentially further improving model efficiency and effectiveness across diverse applications.
Journal reference:
- Preliminary scientific report. Gloeckle, F., Idrissi, B. Y., Rozière, B., Lopez-Paz, D., & Synnaeve, G. (2024). Better & Faster Large Language Models via Multi-token Prediction. arXiv. DOI: 10.48550/arXiv.2404.19737, https://arxiv.org/abs/2404.19737