Researchers at NVIDIA unveil the groundbreaking nGPT architecture, which normalizes every vector in the transformer network, enabling faster, more efficient AI training that outperforms traditional models across multiple tasks.
Research: nGPT: Normalized Transformer with Representation Learning on the Hypersphere
*Important notice: arXiv publishes preliminary scientific reports that are not peer-reviewed and, therefore, should not be regarded as definitive, used to guide development decisions, or treated as established information in the field of artificial intelligence research.
In an article submitted to the arXiv preprint* server, researchers at NVIDIA proposed a novel neural network architecture called the normalized generative pre-trained transformer (nGPT), with representation learning on the hypersphere. In nGPT, all vectors, including the embeddings, the multi-layer perceptron (MLP) and attention matrices, and the hidden states, are normalized to unit norm.
The input tokens moved across the surface of a hypersphere, with each layer contributing a displacement toward output predictions. Experiments showed that nGPT accelerated learning, reducing training steps by a factor of 4 to 20, depending on sequence length.
A key innovation in nGPT is the normalization of embeddings, which improves the conditioning of matrices and ensures more stable model performance.
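To make the core operation concrete, here is a minimal PyTorch sketch of projecting vectors onto the unit hypersphere along the embedding dimension; the tensor names and shapes are illustrative, not taken from the paper:

```python
import torch

def normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Project vectors onto the unit hypersphere along `dim`."""
    return x / x.norm(dim=dim, keepdim=True).clamp_min(eps)

# Illustrative shapes: a batch of hidden states and an embedding table.
hidden = torch.randn(2, 16, 64)        # (batch, sequence, d_model)
embeddings = torch.randn(50_000, 64)   # (vocab, d_model)

hidden = normalize(hidden)             # every token vector now has unit norm
embeddings = normalize(embeddings)     # every embedding row lies on the sphere
```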
Related Work
Past work provided an overview of representation learning on the hypersphere, highlighting its association with stable training in variational autoencoders and face verification.
Another study unified classification and regression through prototype embeddings on a hypersphere, achieving separation with large margins. However, challenges remain in ensuring consistent performance across varying data distributions and embedding configurations.
Normalized Transformer Evolution
The evolution from GPT to nGPT involves modifying the baseline transformer to create a normalized version, which is illustrated with a focus on the transformer decoder and self-attention. These changes can be extended to the encoder-decoder and cross-attention settings. The key modification involves constraining all embedding vectors to have unit norms and normalizing weights during training.
This constraint ensures that predictions are computed as dot products between normalized vectors, yielding better-behaved similarity estimates. Token embeddings and output logits are handled through learnable matrices, with prediction confidence adjusted by a trainable scaling parameter.
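A minimal sketch of such a scaled readout, assuming the logits are cosine similarities rescaled by a trainable per-token parameter (the paper names a similar parameter s_z; the initialization here is purely illustrative):

```python
import torch
import torch.nn.functional as F

d_model, vocab = 64, 50_000
E_out = F.normalize(torch.randn(vocab, d_model), dim=-1)  # unit-norm output embeddings
s_z = torch.ones(vocab, requires_grad=True)               # trainable logit scaling (illustrative init)

h = F.normalize(torch.randn(2, d_model), dim=-1)          # normalized final hidden states

# Dot products of unit vectors are cosine similarities in [-1, 1];
# s_z rescales them so the softmax confidence is not artificially capped.
logits = s_z * (h @ E_out.T)
probs = logits.softmax(dim=-1)
```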
In the baseline transformer, alternating self-attention and MLP blocks update the hidden states, and the norms of those states fluctuate as block outputs are added, which motivates the added normalization steps. To control movement along the hypersphere, nGPT frames each layer's update as spherical linear interpolation (SLERP) between the current hidden state and the block's output, approximated in practice by linear interpolation (LERP).
The model treats each block's output as a gradient-like correction to the hidden state, with the step size controlled through learnable eigen learning rates for the attention and MLP blocks. These rates are applied to the normalized block outputs, and each update is followed by renormalization, keeping the hidden states on the hypersphere.
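The update rule described above can be sketched as follows; the block output and the eigen-learning-rate values are placeholders, and the paper's exact parameterization may differ:

```python
import torch
import torch.nn.functional as F

def lerp_update(h, block_out, alpha):
    """Move h toward the (normalized) block output by a learned fraction
    alpha, then renormalize so the result stays on the hypersphere.
    For small alpha this LERP step approximates SLERP."""
    target = F.normalize(block_out, dim=-1)
    return F.normalize(h + alpha * (target - h), dim=-1)

d_model = 64
h = F.normalize(torch.randn(2, 16, d_model), dim=-1)
alpha_attn = torch.full((d_model,), 0.05)  # per-dimension eigen learning rate (illustrative)
attn_out = torch.randn_like(h)             # stand-in for an attention block's output
h = lerp_update(h, attn_out, alpha_attn)
```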
Self-attention in the baseline transformer uses unconstrained query, key, and value vectors, allowing unbounded values. nGPT normalizes these vectors along their embedding dimensions, so query-key dot products become cosine similarities bounded in [-1, 1]. Rotary position embeddings (RoPE) inject positional information, and trainable scaling factors applied to the normalized queries and keys maintain variance control in the softmax attention mechanism.
Because the normalized dot products no longer have the variance the standard softmax scaling assumes, that scaling factor is adjusted accordingly. Attention is computed independently in each head, and all attention projection matrices are normalized for consistency.
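A single-head sketch of this normalized attention, with RoPE and causal masking omitted for brevity; the sqrt(d_k) softmax scale follows our reading of the paper's adjusted scaling, and the s_qk initialization is illustrative:

```python
import math
import torch
import torch.nn.functional as F

def normalized_attention(q, k, v, s_qk):
    """Single-head attention with unit-norm queries and keys
    (RoPE and causal masking omitted for brevity)."""
    q = F.normalize(q, dim=-1) * s_qk
    k = F.normalize(k, dim=-1) * s_qk
    d_k = q.shape[-1]
    # With unit-norm q and k, each dot product is a cosine similarity in
    # [-1, 1], so the softmax scale becomes sqrt(d_k) instead of 1/sqrt(d_k).
    scores = (q @ k.transpose(-2, -1)) * math.sqrt(d_k)
    return scores.softmax(dim=-1) @ v

q, k, v = (torch.randn(2, 16, 64) for _ in range(3))
out = normalized_attention(q, k, v, s_qk=torch.ones(64))
```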
In nGPT, MLP block projections are normalized with scaling factors to control their impact, while the output remains invariant to scaling changes.
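A sketch of such a normalized MLP block, assuming a SwiGLU-style gate; the weight shapes, the sqrt(d_model) rescaling of the gate input, and the scaling-factor initializations are all illustrative rather than the authors' exact configuration:

```python
import math
import torch
import torch.nn.functional as F

def normalized_mlp(h, W_u, W_v, W_o, s_u, s_v):
    """SwiGLU-style MLP with unit-norm projection weights and trainable
    scaling factors (a sketch; shapes and initializations are illustrative)."""
    u = (h @ F.normalize(W_u, dim=-1).T) * s_u   # gated branch
    v = (h @ F.normalize(W_v, dim=-1).T) * s_v   # gate input
    # Rescaling the gate input by sqrt(d_model) keeps SiLU in a useful range.
    gated = u * F.silu(v * math.sqrt(h.shape[-1]))
    return gated @ F.normalize(W_o, dim=-1)      # rows of W_o are unit embedding vectors

d_model, d_ff = 64, 256
h = F.normalize(torch.randn(2, 16, d_model), dim=-1)
out = normalized_mlp(
    h,
    W_u=torch.randn(d_ff, d_model), W_v=torch.randn(d_ff, d_model),
    W_o=torch.randn(d_ff, d_model),
    s_u=torch.ones(d_ff), s_v=torch.ones(d_ff),
)
```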
The Adam optimizer adjusts learning rates for the normalized parameters, ensuring effective updates. The key changes relative to the baseline are removing the standard normalization layers, re-normalizing all matrices after each optimizer step, and adjusting the softmax scaling.
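A hypothetical training-loop fragment showing where such a re-normalization step might sit; the helper below is illustrative, not the authors' code:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(64, 64, bias=False)  # stand-in for any nGPT weight matrix
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def renormalize_weights(module):
    """Project every 2-D weight's rows back onto the unit hypersphere.
    In nGPT this step takes over the role of the removed normalization layers."""
    with torch.no_grad():
        for p in module.parameters():
            if p.dim() == 2:
                p.copy_(F.normalize(p, dim=-1))

x = torch.randn(8, 64)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
renormalize_weights(model)  # the step moves weights off the sphere; project back
```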
nGPT Outperforms GPT
The experiments compared the performance of GPT and nGPT on the OpenWebText dataset, evaluating both on a set of downstream tasks.
The team trained models with 0.5 billion and 1 billion parameters, reporting results using the best initial learning rate settings. The focus was on examining how nGPT accelerates training compared to GPT.
Notably, nGPT achieved a 10x speedup in terms of iterations and tokens used, demonstrating that it reached the same validation loss as GPT much faster.
Further analysis confirmed that nGPT significantly reduced the number of tokens required for training across various context lengths (1k, 4k, 8k) and model sizes.
The figures illustrate that nGPT reached the same loss using roughly 4x, 10x, and 20x fewer tokens at the 1k, 4k, and 8k context lengths, respectively. In addition to faster training, nGPT also showed superior performance across a range of downstream tasks.
Inspection of the trained networks revealed that nGPT maintained better-conditioned embeddings and attention matrices, while GPT's embeddings exhibited significant variation and higher condition numbers, potentially leading to computational inefficiencies.
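For readers unfamiliar with the metric, the condition number of a matrix is the ratio of its largest to smallest singular value; a quick illustrative check (not the paper's analysis code) might look like this:

```python
import torch
import torch.nn.functional as F

def condition_number(W: torch.Tensor) -> float:
    """Ratio of the largest to the smallest singular value; large values
    signal a badly conditioned, nearly degenerate matrix."""
    s = torch.linalg.svdvals(W)
    return (s.max() / s.min()).item()

# Random unit-norm rows give a modest condition number; a matrix with
# highly imbalanced norms or near-collinear rows gives a much larger one.
E = F.normalize(torch.randn(1000, 64), dim=-1)
print(condition_number(E))
```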
Decoupling the prediction step from the raw attention and MLP outputs, along with the use of eigen learning rates, allows nGPT to take more modest steps when updating hidden states, further contributing to its stability. Additionally, the scaling factors applied to the MLP and attention blocks compensate for the normalization, reinforcing the overall stability and performance improvements.
Conclusion
In summary, the experiments compared GPT and nGPT on the OpenWebText dataset, revealing that nGPT achieved a 10x speedup in iterations and tokens, reaching the same validation loss much faster. nGPT also outperformed GPT in token efficiency and downstream tasks, showing faster training across context lengths and model sizes.
Additionally, nGPT's embeddings and attention matrices were better conditioned, improving computational efficiency. Eigen-learning rates and scaling factors, combined with the novel hyperspherical representation, contributed to its stability and overall performance enhancements.
Journal reference:
- Preliminary scientific report.
Loshchilov, I., Hsieh, C., Sun, S., & Ginsburg, B. (2024). nGPT: Normalized Transformer with Representation Learning on the Hypersphere. arXiv. https://arxiv.org/abs/2410.01131