## GrokAdamW: A PyTorch Optimizer for Accelerated Grokking
by Eric Hartford
**GrokAdamW** is an optimizer that combines Grokfast (a technique for accelerating "grokking" in deep learning models) with the AdamW optimizer. It is aimed at models that exhibit delayed generalization, where validation performance improves sharply only after a long period of overfitting to the training data.
**Update:** This optimizer was used to train the awesome tiny model [nisten/Biggie-SmoLlm-0.15B-Base](https://huggingface.co/nisten/Biggie-SmoLlm-0.15B-Base)
This implementation was inspired by the following papers:
- **Grokfast: Accelerated Grokking by Amplifying Slow Gradients.** Lee, J., Kang, B. G., Kim, K., & Lee, K. M. (2024). *arXiv:2405.20233 [cs.LG]*. [https://doi.org/10.48550/arXiv.2405.20233](https://doi.org/10.48550/arXiv.2405.20233)
- **Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets.** Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). *arXiv:2201.02177 [cs.LG]*. [https://doi.org/10.48550/arXiv.2201.02177](https://doi.org/10.48550/arXiv.2201.02177)
- **Decoupled Weight Decay Regularization.** Loshchilov, I., & Hutter, F. (2019). *arXiv:1711.05101 [cs.LG]*. [https://doi.org/10.48550/arXiv.1711.05101](https://doi.org/10.48550/arXiv.1711.05101)
## Table of Contents
1. [Overview](#overview)
2. [Theory](#theory)
3. [Mathematical Explanation](#mathematical-explanation)
4. [Installation](#installation)
5. [Usage](#usage)
6. [Configuration](#configuration)
7. [Common Pitfalls and Debugging Tips](#common-pitfalls-and-debugging-tips)
8. [Contribution](#contribution)
9. [License](#license)
## Overview
**Grokking** is a phenomenon in which a deep learning model suddenly generalizes after a long period of overfitting. The Grokfast authors (Lee et al., 2024) treat the sequence of gradients during training as a signal and argue that its slow-varying (low-frequency) component drives this delayed generalization. **Grokfast** accelerates grokking by amplifying that slow-varying component.
**GrokAdamW** builds on this idea by integrating Grokfast's slow-gradient amplification into the AdamW update rule. It adds the following features:
1. **Adaptive Momentum:** The momentum of the Grokfast component (which controls the emphasis on slow-varying gradients) is adjusted dynamically based on a "Grokking Signal" that reflects the model's generalization progress.
2. **Layer-Wise Momentum Decay:** On the premise that different layers learn at different rates, GrokAdamW decays the AdamW momentum parameter (β1) geometrically from earlier to later layers. The intent is to promote faster generalization in early layers while limiting overfitting in later ones.
3. **Multiple Grokking Signals:** The Grokking Signal can be built from several user-supplied signal functions, which are combined so that different aspects of generalization performance can be captured.
4. **Optional Gradient Clipping:** Gradients can optionally be clipped to improve training stability and guard against exploding gradients, a common issue in deep learning.
## Theory
### Mathematical Explanation
**Core AdamW Updates:**
For each layer *l*, parameter *p*, and training step *t*, where ĝ_t[l, p] is the Grokfast-amplified gradient defined under **Grokfast Integration** below:
* First Moment Estimate:
  * m_t[l, p] = β1_l * m_(t-1)[l, p] + (1 - β1_l) * ĝ_t[l, p]
  * Where β1_l = β1_init * (1 - γ)^l (layer-wise momentum decay; β1_init is the first entry of `betas` and γ is `gamma`)
* Second Moment Estimate:
  * v_t[l, p] = β2 * v_(t-1)[l, p] + (1 - β2) * ĝ_t[l, p]²
* Bias Correction:
  * m̂_t[l, p] = m_t[l, p] / (1 - β1^t)
  * v̂_t[l, p] = v_t[l, p] / (1 - β2^t)
* Parameter Update:
  * θ_t[l, p] = θ_(t-1)[l, p] - η * (m̂_t[l, p] / (sqrt(v̂_t[l, p]) + ε) + wd * θ_(t-1)[l, p])
  * Where η is the learning rate (`lr`), ε is `eps`, and wd is `weight_decay`
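For a single scalar parameter, the core update can be traced step by step. This is a minimal pure-Python sketch, not the package's tensor implementation. The hyperparameter defaults are illustrative. The first bias correction is written with β1 rather than β1_l, so the sketch reads that β1 as β1_init; treat this as an assumption.

```python
import math

def adamw_step(theta, g_hat, m, v, t, layer, lr=1e-3, beta1_init=0.9,
               beta2=0.999, eps=1e-8, wd=1e-2, gamma=0.1):
    """One scalar AdamW update with layer-wise beta1 (illustrative sketch).

    g_hat is the (already Grokfast-amplified) gradient; t starts at 1.
    Returns the updated (theta, m, v).
    """
    beta1_l = beta1_init * (1 - gamma) ** layer      # layer-wise momentum decay
    m = beta1_l * m + (1 - beta1_l) * g_hat          # first moment estimate
    v = beta2 * v + (1 - beta2) * g_hat ** 2         # second moment estimate
    # Assumption: the "β1" in the first bias correction is beta1_init
    m_hat = m / (1 - beta1_init ** t)
    v_hat = v / (1 - beta2 ** t)
    # Adam step plus decoupled weight decay on the previous parameter value
    return theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * theta), m, v
```

On the first step at layer 0, a unit gradient moves θ from 0 to about -1e-3 (that is, -η). With a zero gradient, only the decoupled weight decay term acts, and it shrinks θ by a factor of (1 - η * wd).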
**Grokfast Integration:**
* Grokking Signal:
  * GS_t = Combine(signal_1(t), signal_2(t), ..., signal_n(t)) (using the provided `grokking_signal_fns`)
* EMA Filter Momentum:
  * α_t = α_init * exp(-κ * GS_t)
  * Where α_init is `alpha_init` and κ is `grokking_signal_decay_rate`
* EMA Filter Update:
  * μ_t[l, p] = α_t * μ_(t-1)[l, p] + (1 - α_t) * g_t[l, p]
  * Where g_t[l, p] is the raw gradient
* Grokfast-Amplified Gradient:
  * ĝ_t[l, p] = g_t[l, p] + λ * μ_t[l, p]
  * Where λ is `lamb`
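The Grokfast steps can be traced for a single scalar gradient. This is a minimal pure-Python sketch, not the package's implementation. It assumes three things: each signal function takes no arguments, `Combine` is the mean of the signals that are not `None`, and α_t falls back to α_init when no signal is available. The defaults mirror the documented `alpha_init`, `grokking_signal_decay_rate` (κ), and `lamb` (λ).

```python
import math

def combine_signals(signal_fns):
    """Combine signal functions into one GS_t.

    Assumptions: each function takes no arguments, and Combine() is the
    mean of the returned values that are not None. Returns None if no
    signal is available.
    """
    signals = [s for s in (fn() for fn in signal_fns) if s is not None]
    return sum(signals) / len(signals) if signals else None

def grokfast_amplify(g, mu, signal, alpha_init=0.98, kappa=0.1, lamb=2.0):
    """Return (g_hat, mu) for one scalar gradient g and previous EMA mu."""
    # EMA filter momentum; assumed to fall back to alpha_init without a signal
    alpha = alpha_init if signal is None else alpha_init * math.exp(-kappa * signal)
    mu = alpha * mu + (1 - alpha) * g   # slow-gradient EMA filter
    g_hat = g + lamb * mu               # amplified gradient passed on to AdamW
    return g_hat, mu
```

With these defaults, a signal of 0 leaves α_t at 0.98. A positive signal lowers α_t, which shortens the filter's memory so that μ_t tracks recent gradients more closely.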
**Optional Gradient Clipping:**
* If `gradient_clipping` > 0:
  * `torch.nn.utils.clip_grad_norm_(parameters, gradient_clipping)`
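For intuition, the effect of `clip_grad_norm_` can be sketched in pure Python on a flat list of scalar gradients. The sketch computes the global L2 norm across all gradients and, if that norm exceeds the threshold, rescales every gradient by the same factor. `clip_by_global_norm` is a hypothetical name. The small `1e-6` added to the norm follows PyTorch's implementation. Treating a threshold of 0 or less as "disabled" follows the `gradient_clipping` convention described under Configuration, not PyTorch itself.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale scalar gradients so their global L2 norm is at most max_norm.

    max_norm <= 0 disables clipping (the `gradient_clipping` convention).
    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if max_norm <= 0 or total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / (total_norm + 1e-6)  # same scale factor for every gradient
    return [g * scale for g in grads], total_norm
```

For example, gradients `[3.0, 4.0]` have a norm of 5.0. With `max_norm=1.0`, they become approximately `[0.6, 0.8]`, which keeps their direction and caps their length.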
## Installation
You can easily install GrokAdamW using pip:
```bash
pip install grokadamw
```
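## Usage
The example below trains a toy model with random placeholder data. It assumes the optimizer calls each function in `grokking_signal_fns` with no arguments, so the signal function reads the most recent losses from an enclosing dict. Check the signature your installed version expects.
```python
import torch
import torch.nn as nn
from grokadamw import GrokAdamW

# Define your model, loss, and data (random tensors as placeholders)
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
x_train, y_train = torch.randn(64, 10), torch.randn(64, 1)
x_val, y_val = torch.randn(16, 10), torch.randn(16, 1)
num_epochs = 100

# Latest losses: updated by the training loop, read by the signal function
losses = {"train": None, "val": None}

# Define your grokking signal function(s): relative validation/training loss gap
def grokking_signal_fn() -> float:
    train, val = losses["train"], losses["val"]
    if train is None or val is None or train == 0:
        return 0.0  # no losses recorded yet, or avoid division by zero
    return (val - train) / train

# Initialize GrokAdamW optimizer
optimizer = GrokAdamW(model.parameters(), lr=1e-3, grokking_signal_fns=[grokking_signal_fn])

# Training loop
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    # Calculate validation loss so the next step's signal reflects current progress
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val)
    losses["train"], losses["val"] = loss.item(), val_loss.item()
```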
## Configuration
GrokAdamW accepts the standard AdamW parameters (`lr`, `betas`, `eps`, `weight_decay`) plus the following Grokfast-related parameters:
* `alpha_init`: Initial momentum for the EMA filter (default: 0.98)
* `lamb`: Amplification factor for the filtered gradients (default: 2.0)
* `gamma`: Layer-wise momentum decay rate (default: 0.1)
* `grokking_signal_fns`: A list of functions that each return a scalar grokking signal (optional)
* `grokking_signal_decay_rate`: Decay rate for adjusting alpha based on the grokking signal (default: 0.1)
* `gradient_clipping`: Maximum norm for gradient clipping (default: 1.0, set to 0 to disable)
## Common Pitfalls and Debugging Tips
1. **Grokking Signal Functions Not Providing Useful Signals:**
   - Ensure that the functions return meaningful values that reflect aspects such as the gap between validation and training loss.
   - Consider normalizing or bounding the output of signal functions, since α_t depends exponentially on the signal.
2. **Issues with Gradient Clipping:**
   - If gradients are clipped on most steps, the learning rate or other hyperparameters may need adjusting.
3. **Unexpected Behavior with Layer-wise Momentum Decay:**
   - Monitor the learning dynamics of different layers. If some layers learn too slowly or too quickly, adjust `gamma` or individual layer hyperparameters accordingly.
4. **Monitoring Grokking Signal and Alpha Values:**
   - Use tools like TensorBoard or custom logging to track the grokking signal, alpha values, and gradient norms. This helps in understanding the optimizer's behavior and in deciding which adjustments to make.
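The scale of the grokking signal matters because α_t = α_init * exp(-κ * GS_t). A large positive signal drives α_t toward 0. With the defaults (α_init = 0.98, κ = 0.1), any signal below ln(0.98) / 0.1 ≈ -0.2 pushes α_t above 1. At that point (1 - α_t) is negative and the EMA filter is no longer a weighted average. One way to normalize a loss-gap signal is to squash it into a fixed range. The sketch below clamps the gap at zero and applies `tanh`. Both choices, and the name `bounded_gap_signal`, are illustrative assumptions rather than part of GrokAdamW.

```python
import math

def bounded_gap_signal(train_loss, val_loss):
    """Hypothetical signal: relative validation/training loss gap squashed into [0, 1].

    Clamping at 0 keeps alpha_t <= alpha_init; tanh keeps large gaps
    from driving alpha_t toward 0.
    """
    if train_loss <= 0:
        return 0.0  # avoid division by zero
    gap = (val_loss - train_loss) / train_loss
    return math.tanh(max(gap, 0.0))
```

With the defaults, this signal keeps α_t between 0.98 * exp(-0.1) ≈ 0.887 and 0.98. To pass it to `grokking_signal_fns`, wrap it in a function of the form the optimizer expects, one that supplies your latest training and validation losses.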
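To judge whether `gamma` is too aggressive for a given depth, it helps to tabulate β1_l = β1_init * (1 - γ)^l from the Mathematical Explanation. This sketch assumes the layer index l counts from 0 and uses β1_init = 0.9 as an illustrative value. Check how your version of the optimizer assigns l to parameters.

```python
def layer_beta1(beta1_init, gamma, num_layers):
    """Per-layer first-moment momentum: beta1_l = beta1_init * (1 - gamma) ** l."""
    return [beta1_init * (1 - gamma) ** l for l in range(num_layers)]

if __name__ == "__main__":
    # Compare how quickly beta1 decays across a 24-layer model
    for gamma in (0.01, 0.05, 0.1):
        betas = layer_beta1(0.9, gamma, 24)
        print(f"gamma={gamma}: first layer {betas[0]:.3f}, last layer {betas[-1]:.3f}")
```

With the default γ = 0.1, β1 falls from 0.9 at the first layer to about 0.08 at the 24th (l = 23). In a deep model, the later layers would therefore keep very little first-moment momentum.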
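If your training code does not expose α_t directly, you can recompute it from the logged signal using the formula α_t = α_init * exp(-κ * GS_t). The sketch below writes one CSV row per step using only the standard library. Each row holds the signal, the recomputed α_t, a gradient norm you supply, and whether that norm exceeded the clipping threshold. The norm could be, for example, the total norm that `torch.nn.utils.clip_grad_norm_` returns. The clipped column makes it easy to see how often clipping kicks in. `GrokLogger` and `alpha_from_signal` are hypothetical names. The defaults mirror the documented `alpha_init`, `grokking_signal_decay_rate`, and `gradient_clipping`.

```python
import csv
import io
import math

def alpha_from_signal(signal, alpha_init=0.98, kappa=0.1):
    """Recompute the EMA momentum; None (no signal) is assumed to mean alpha_init."""
    if signal is None:
        return alpha_init
    return alpha_init * math.exp(-kappa * signal)

class GrokLogger:
    """Hypothetical helper that writes one CSV row per training step."""

    def __init__(self, stream, alpha_init=0.98, kappa=0.1, clip_norm=1.0):
        self.writer = csv.writer(stream)
        self.writer.writerow(["step", "signal", "alpha", "grad_norm", "clipped"])
        self.alpha_init, self.kappa, self.clip_norm = alpha_init, kappa, clip_norm

    def log(self, step, signal, grad_norm):
        alpha = alpha_from_signal(signal, self.alpha_init, self.kappa)
        clipped = self.clip_norm > 0 and grad_norm > self.clip_norm
        self.writer.writerow([step, signal, f"{alpha:.4f}", f"{grad_norm:.4f}", int(clipped)])
        return alpha, clipped

# Usage: log to an in-memory buffer (swap in open("grok_log.csv", "w", newline=""))
buffer = io.StringIO()
logger = GrokLogger(buffer)
logger.log(step=1, signal=0.5, grad_norm=2.3)
```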
## Contribution
GrokAdamW is an ongoing research project. Your feedback and contributions are welcome! Please feel free to submit issues, feature requests, or pull requests. For more details, see our [CONTRIBUTING.md](CONTRIBUTING.md) file.
## License
GrokAdamW is licensed under the Apache 2.0 License. See the LICENSE file for more details.