# [Paper Implementation] AN EVOLVED UNIVERSAL TRANSFORMER MEMORY
[![Join our Discord](https://img.shields.io/badge/Discord-Join%20our%20server-5865F2?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/agora-999382051935506503) [![Subscribe on YouTube](https://img.shields.io/badge/YouTube-Subscribe-red?style=for-the-badge&logo=youtube&logoColor=white)](https://www.youtube.com/@kyegomez3242) [![Connect on LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/kye-g-38759a207/) [![Follow on X.com](https://img.shields.io/badge/X.com-Follow-1DA1F2?style=for-the-badge&logo=x&logoColor=white)](https://x.com/kyegomezb)
An open source implementation of the paper: "AN EVOLVED UNIVERSAL TRANSFORMER MEMORY"
Abstract:
Prior methods propose to offset the escalating costs of modern foundation models by dropping specific parts of their contexts with hand-designed rules, while attempting to preserve their original performance. We overcome this trade-off with Neural Attention Memory Models (NAMMs), introducing a learned network for memory management that improves both the performance and efficiency of transformers. We evolve NAMMs atop pre-trained transformers to provide different latent contexts focusing on the most relevant information for individual layers and attention heads. NAMMs are universally applicable to any model using self-attention as they condition exclusively on the values in the produced attention matrices. Learning NAMMs on a small set of problems, we achieve substantial performance improvements across multiple long-context benchmarks while cutting the model's input contexts down to a fraction of their original sizes. We show that the generality of our conditioning enables zero-shot transfer of NAMMs trained only on language to entirely new transformer architectures, even across input modalities, with their benefits carrying over to vision and reinforcement learning.
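
The core idea is that each token in the KV cache receives a retention score computed solely from its recent attention values, and low-scoring tokens are evicted. Below is a minimal, illustrative sketch of that eviction step only. It is not the paper's model: the learned NAMM, its feature extraction, and the `score_network` / `evict_tokens` names here are stand-ins for intuition.

```python
import torch
import torch.nn as nn


def evict_tokens(kv_cache, attention_matrix, score_network, threshold=0.0):
    """Drop cached tokens whose retention score falls below `threshold`.

    kv_cache: {"key", "value"} tensors of shape (batch, seq_len, d_model)
    attention_matrix: recent attention scores of shape (batch, seq_len, n_queries)
    """
    # A real NAMM builds richer features from the attention values; here we
    # feed the raw attention rows straight into a scorer for illustration.
    scores = score_network(attention_matrix).squeeze(-1)  # (batch, seq_len)
    keep = scores.mean(dim=0) > threshold                  # shared mask across the batch
    return {name: t[:, keep, :] for name, t in kv_cache.items()}


# Toy usage with random inputs and a linear scorer standing in for the learned NAMM.
B, T, Q, D = 2, 1024, 512, 256
cache = {"key": torch.randn(B, T, D), "value": torch.randn(B, T, D)}
attn = torch.softmax(torch.randn(B, T, Q), dim=1)
scorer = nn.Linear(Q, 1)
smaller_cache = evict_tokens(cache, attn, scorer)
print(smaller_cache["key"].shape)  # (2, kept_tokens, 256)
```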
## Install
```bash
$ pip3 install -U open-namm
```
## Usage
```python
import torch
from loguru import logger

# NAMMConfig and create_namm are provided by this package; the import path
# below assumes the top-level module matches the distribution name.
from open_namm import NAMMConfig, create_namm


def create_sample_inputs(
    batch_size: int = 2,
    seq_len: int = 1024,
    n_queries: int = 512,
    d_model: int = 256,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """Create sample inputs for NAMM testing.

    Args:
        batch_size: Batch size
        seq_len: Sequence length (number of tokens in the KV cache)
        n_queries: Number of recent queries
        d_model: Model dimension
        device: Device to create tensors on

    Returns:
        Tuple of (kv_cache, attention_matrix)
    """
    logger.info(f"Creating sample inputs on device: {device}")

    # Sample KV cache; in practice these would be the key and value tensors
    # from the transformer's layers.
    kv_cache = {
        "key": torch.randn(batch_size, seq_len, d_model, device=device),
        "value": torch.randn(batch_size, seq_len, d_model, device=device),
    }

    # Sample attention matrix; in practice these would be the recent attention
    # scores from the transformer's layers.
    attention_matrix = torch.randn(batch_size, seq_len, n_queries, device=device)

    # Apply softmax so the values resemble real attention scores.
    attention_matrix = torch.softmax(attention_matrix, dim=1)

    logger.info(
        f"Created inputs - KV cache size: {kv_cache['key'].shape}, "
        f"attention matrix size: {attention_matrix.shape}"
    )
    return kv_cache, attention_matrix


def main():
    """Main function demonstrating NAMM usage."""
    # Set up logging
    logger.remove()
    logger.add(lambda msg: print(msg, flush=True), colorize=True, level="INFO")

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create a NAMM instance with a custom config
    config = NAMMConfig(
        update_interval=256,  # more frequent updates for demonstration
        stride_size=16,
        window_size=64,
        d_model=256,
        n_head=4,
        gamma=0.95,
        dropout=0.1,
    )

    namm = create_namm(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    namm = namm.to(device)
    logger.info(f"Created NAMM model on device: {device}")

    # Create sample inputs
    kv_cache, attention_matrix = create_sample_inputs(
        batch_size=2,
        seq_len=1024,
        n_queries=512,
        d_model=config.d_model,
        device=device,
    )

    # Simulate multiple steps of processing
    n_steps = 1000
    retention_stats = []
    logger.info(f"Starting simulation for {n_steps} steps")

    for step in range(n_steps):
        # Process the KV cache
        updated_cache, _ = namm(kv_cache, attention_matrix)

        # Every few steps, evaluate retention
        if step % 100 == 0:
            stats = namm.evaluate_retention(kv_cache, attention_matrix)
            if stats:  # stats are only produced every `update_interval` steps
                retention_stats.append(stats)
                logger.info(
                    f"Step {step}: retention rate = {stats['retention_rate']:.2%}, "
                    f"mean score = {stats['mean_score']:.3f}"
                )

        # Update the KV cache and attention matrix for the next step
        if updated_cache:  # the NAMM pruned the cache
            kv_cache = updated_cache
            # Create a new attention matrix for the reduced sequence length
            _, new_seq_len, _ = kv_cache["key"].shape
            attention_matrix = torch.randn(2, new_seq_len, 512, device=device)
            attention_matrix = torch.softmax(attention_matrix, dim=1)

    # Print final statistics
    if retention_stats:
        avg_retention = sum(s["retention_rate"] for s in retention_stats) / len(retention_stats)
        logger.info(f"Average retention rate over simulation: {avg_retention:.2%}")


if __name__ == "__main__":
    main()
```
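
In a real decoding loop, the KV cache and attention weights would come from the transformer itself rather than from `create_sample_inputs`. A rough sketch of how one layer could be wired up, assuming the same `namm(kv_cache, attention_matrix)` interface used above (the `prune_layer_cache` helper is hypothetical, not part of the package):

```python
def prune_layer_cache(namm, layer_kv_cache, recent_attention):
    """Swap in the pruned KV cache for one layer after a decoding step.

    layer_kv_cache: {"key", "value"} tensors of shape (batch, seq_len, d_model)
    recent_attention: this layer's attention weights, (batch, seq_len, n_recent_queries)
    """
    updated_cache, _ = namm(layer_kv_cache, recent_attention)
    # The NAMM only prunes every `update_interval` steps; keep the old cache otherwise.
    return updated_cache if updated_cache else layer_kv_cache
```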
# License
MIT