Discover how NAMMs streamline transformer models with cutting-edge memory management, unlocking unprecedented performance across tasks and modalities.
Qualitative examples comparing the responses produced by Llama 3 with and without the NAMM memory, together with GPT-4, on two prompts from the En.Sum task of InfiniteBench.
*Important notice: arXiv publishes preliminary scientific reports that are not peer-reviewed and, therefore, should not be regarded as definitive, used to guide development decisions, or treated as established information in the field of artificial intelligence research.
In an article recently submitted to the arXiv preprint server, researchers at Sakana AI, Japan, introduced neural attention memory models (NAMMs), a novel method to enhance the efficiency and performance of transformers.
NAMMs manage memory dynamically by focusing on relevant latent contexts within attention layers, reducing input context size while improving results.
Applicable to any self-attention model, NAMMs demonstrated strong performance across benchmarks and transferability across modalities, including vision and reinforcement learning.
Using a combination of spectrogram-based feature extraction and evolutionary optimization, NAMMs provide a scalable framework for optimizing pre-trained transformers.
NAMMs use evolution to optimize the performance of LMs by pruning their KV cache memory. Evolved NAMMs can be zero-shot transferred to other transformers, even across input modalities and task domains.
Background
Transformer architectures have become a cornerstone in deep learning, underpinning modern foundation models due to their exceptional scalability and performance.
These models rely on a context window of input tokens, which poses challenges in addressing long-range tasks efficiently. Extending this context window often increases computational costs, making transformers resource-intensive.
Existing approaches, such as heuristics for identifying and evicting less important tokens from the key-value (KV) cache, have achieved partial success in reducing memory size while limiting performance degradation. However, these hand-designed strategies inherently trade off efficiency for accuracy.
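As an illustration of what such a hand-designed eviction heuristic can look like, here is a minimal Python sketch that ranks cached tokens by the total attention they have received and keeps only the top fraction, loosely in the spirit of heavy-hitter methods such as H2O. The shapes, keep ratio, and scoring rule are illustrative assumptions, not the exact procedure of any published baseline.

```python
import numpy as np

def evict_by_attention_score(attn, keep_ratio=0.25):
    """Toy hand-designed KV-cache eviction heuristic.

    attn: (num_queries, num_kv_tokens) attention weights for one head.
    Returns the indices of KV tokens to keep, ranked by how much total
    attention each token has received so far (a heavy-hitter-style score).
    """
    scores = attn.sum(axis=0)                    # accumulated attention per cached token
    num_keep = max(1, int(keep_ratio * attn.shape[1]))
    keep = np.argsort(scores)[-num_keep:]        # retain the highest-scoring tokens
    return np.sort(keep)

# Example: keep 25% of 16 cached tokens for a random attention matrix.
rng = np.random.default_rng(0)
attn = rng.random((8, 16))
attn /= attn.sum(axis=1, keepdims=True)          # normalize rows like a softmax output
print(evict_by_attention_score(attn))
```

Heuristics of this kind are cheap to run, but the ranking rule is fixed in advance, which is precisely the trade-off NAMMs aim to remove by learning the eviction policy instead.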
The paper introduces NAMMs, a framework that redefines memory management in transformers. NAMMs employ the short-time Fourier transform (STFT) to extract spectrogram representations of attention matrices, making them applicable to any transformer architecture without altering base model parameters.
Unlike heuristic methods, NAMMs use the covariance matrix adaptation evolution strategy (CMA-ES) to optimize through otherwise non-differentiable memory operations, yielding significant performance gains across tasks.
By conditioning only on attention matrix features, NAMMs were universally applicable to various transformer architectures. The study demonstrated NAMMs' efficiency and performance benefits across diverse benchmarks and modalities, including vision and reinforcement learning, marking a substantial advancement in transformer optimization.
Optimizing Transformer Memory with NAMMs
NAMMs addressed the inefficiencies of transformer models, particularly the quadratic cost of computing attention matrices during autoregressive tasks.
NAMMs target the KV cache, a memory structure used for efficient processing of token sequences, through a model-agnostic feature extraction step: an STFT with a Hann window compresses each cached token's attention history into a spectrogram, which serves as the NAMM's input representation regardless of the underlying transformer.
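A minimal sketch of this feature-extraction step, under the assumption that each cached token's recent attention values form a one-dimensional signal: SciPy's STFT with a Hann window converts that signal into a spectrogram, which is then collapsed into a fixed-size feature vector. The plain mean used to collapse the time axis is purely an illustrative choice, not the paper's exact reduction.

```python
import numpy as np
from scipy.signal import stft

def attention_spectrogram_features(attn_history, nperseg=32):
    """Turn one cached token's attention history into a spectral feature vector.

    attn_history: 1-D array of attention values the token received across
    recent queries (its "attention signal" over time).
    """
    # STFT with a Hann window; Zxx has shape (freq_bins, time_frames).
    _, _, Zxx = stft(attn_history, window='hann', nperseg=nperseg)
    spectrogram = np.abs(Zxx)
    # Collapse the time axis to a fixed-size vector (illustrative mean reduction).
    return spectrogram.mean(axis=1)

# Example: a 256-step attention history -> a (nperseg // 2 + 1)-dimensional feature.
rng = np.random.default_rng(0)
features = attention_spectrogram_features(rng.random(256))
print(features.shape)  # (17,)
```

Because these features depend only on attention values, and not on the base model's weights or hidden states, the same trained NAMM can be attached to a different transformer without modification.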
The NAMM framework employs a backward attention memory (BAM) model to parameterize memory selection. BAM uses counter-causal masking, which purposefully distinguishes older from newer tokens so that relevant information is prioritized. Tokens in the KV cache are then selectively retained based on scalar scores generated by a lightweight neural network, optimizing memory usage without compromising task performance.
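The sketch below shows one way such a scorer could be wired up in PyTorch, purely as a hedged illustration: a single attention layer with a counter-causal mask (each cached token attends only to itself and to newer tokens) feeds a linear head that emits one scalar per token, and tokens with positive scores are retained. The class name, layer sizes, and positive-score retention rule are assumptions of this sketch, not the paper's exact BAM architecture.

```python
import torch
import torch.nn as nn

class ToyBAMScorer(nn.Module):
    """Illustrative backward-attention scorer (not the paper's exact BAM).

    Each cached token attends only to itself and to tokens that arrived
    after it (a counter-causal mask); a linear head then emits one scalar
    score per token.
    """

    def __init__(self, feat_dim, num_heads=1):
        super().__init__()
        self.attn = nn.MultiheadAttention(feat_dim, num_heads, batch_first=True)
        self.head = nn.Linear(feat_dim, 1)

    def forward(self, token_feats):                        # (1, num_tokens, feat_dim)
        n = token_feats.shape[1]
        # Counter-causal mask: True entries are blocked, so position i may
        # only attend to positions j >= i (itself and newer tokens).
        mask = torch.tril(torch.ones(n, n, dtype=torch.bool), diagonal=-1)
        out, _ = self.attn(token_feats, token_feats, token_feats, attn_mask=mask)
        return self.head(out).squeeze(-1)                  # (1, num_tokens) scores

scorer = ToyBAMScorer(feat_dim=17)                 # e.g., one spectrogram feature per token
feats = torch.randn(1, 10, 17)
scores = scorer(feats)
keep = (scores[0] > 0).nonzero(as_tuple=True)[0]   # indices of tokens to retain
print(scores[0].detach(), keep)
```

Evicted entries are simply dropped from the KV cache, so subsequent attention computations run over a shorter sequence.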
NAMMs are trained through incremental evolution with CMA-ES, which eliminates the need for backpropagation through the non-differentiable eviction step. This approach also facilitates efficient zero-shot transfer to larger models and varied applications. Empirical results demonstrated that NAMMs enhance both performance and efficiency across tasks, significantly reducing KV cache size while outperforming baseline models.
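To make the gradient-free training loop described above concrete, here is a hedged sketch using the ask/tell interface of the open-source `cma` package. The `evaluate_namm` function is a hypothetical stand-in for running the frozen transformer with a candidate NAMM's parameters and returning a task score; a toy objective is substituted so the sketch runs on its own.

```python
import numpy as np
import cma

def evaluate_namm(params):
    """Hypothetical fitness function: load `params` into the NAMM, run the
    frozen transformer on long-context tasks, and return the average score
    (higher is better). A toy objective stands in here."""
    return -np.sum((params - 0.5) ** 2)

num_params = 64                                           # size of the lightweight NAMM network
es = cma.CMAEvolutionStrategy(np.zeros(num_params), 0.3)  # initial mean and step size

for generation in range(20):
    candidates = es.ask()                              # sample a population of candidate NAMMs
    # CMA-ES minimizes, so negate the scores we want to maximize.
    fitnesses = [-evaluate_namm(np.asarray(c)) for c in candidates]
    es.tell(candidates, fitnesses)                     # update the search distribution

best_params = es.result.xbest                          # best NAMM parameters found
print("best task score:", -es.result.fbest)
```

Because the search operates directly on the NAMM's small parameter vector, the base transformer stays frozen throughout, which is what allows the evolved memory model to be reused with other architectures.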
Experimental Results
The researchers evaluated NAMMs across various benchmarks against full-context transformers and recent hand-designed KV cache management methods—H2O, L2, and FastGen.
NAMMs demonstrated superior performance and efficiency by learning to discard unhelpful information, avoiding the lossy heuristics of prior methods.
In benchmarks like LongBench and InfiniteBench, NAMMs achieved significant improvements over the Llama 3 8B base transformer, with 11% and 15% performance gains, respectively, while reducing KV cache size by 75%.
Unlike hand-designed methods, NAMMs reduced KV cache sizes while maintaining accuracy, scaling effectively to long contexts by forgetting redundant data. NAMMs also adapted memory allocation to layer depth, retaining more tokens in early and middle layers and discarding redundant content in later stages.
In reinforcement learning, NAMMs improved decision transformers across eight tasks by discarding context that could propagate errors. NAMMs trained with the BAM architecture outperformed simpler multilayer perceptron (MLP)-based NAMMs.
Incremental evolution enhanced task-specific adaptations, including learning to prioritize critical tokens for complex tasks such as code completion while pruning redundant elements like whitespace and boilerplate code.
The findings underscored NAMMs' potential to advance long-context memory systems through end-to-end optimization, overcoming the limitations of manually designed approaches.
Conclusion
In conclusion, NAMMs offer a novel framework to enhance transformer efficiency and performance while reducing memory usage.
By dynamically managing latent memory, NAMMs optimized KV cache usage, surpassing traditional hand-designed methods. Their evolutionary approach enabled NAMMs to outperform baselines across diverse benchmarks, including LongBench and InfiniteBench, and transfer effectively across architectures and modalities.
The approach introduced innovative memory selection strategies, leveraging evolutionary training to prioritize relevant tokens, significantly improving long-context tasks without full-context conditioning. NAMMs' ability to scale efficiently to tasks with extended context windows highlights their transformative potential for future advancements in transformer optimization.
Journal reference:
- Preliminary scientific report. Cetin, E., Sun, Q., Zhao, T., & Tang, Y. (2024). An Evolved Universal Transformer Memory. arXiv. https://arxiv.org/abs/2410.13166