adams-torch

Name: adams-torch
Version: 0.0.1
Summary: Adams optimizer: next-generation optimizer blending element-wise methods with matrix-aware regularization
Author: One
Homepage: https://github.com/imoneoi/Adams
Requires-Python: >=3.8
Keywords: pytorch, optimizer, deep learning, machine learning, spectral regularization
Upload time: 2025-09-13 07:51:57
# Adams Optimizer

Adams is a next-generation optimizer that blends the simplicity of element-wise methods with the stability benefits of matrix-aware regularization. It updates neural network parameters in both **1D (per-element)** and **2D (per-matrix)** ways, staying fast and easy to parallelize like Adam/AdamW while improving stability and generalization.

* **Stable:** no loss spikes observed; no gradient clipping required.
* **Fast & scalable:** element-wise updates + one rank-1 spectral decay step per matrix; easily parallelizable.
* **Simple:** no `epsilon` hyperparameter; truly scale-invariant per-parameter update.

## Definition 📝

![Adams pseudocode](./assets/adams_pseudocode.png)

## How Adams Works 🌟

### 1) Bounded, element-wise update (1D)

Small second-moment estimates are a major source of instability and loss spikes in Adam-like methods. Adams replaces the usual preconditioned step with a **bounded** update using `atan2`:

$$
\Delta \theta \propto \text{atan2}\big(\hat m_t,\sqrt{\hat n_t}\big),
$$

which:

* naturally bounds the step size,
* removes the need for the `epsilon` hyperparameter,
* yields true scale invariance of the update.
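
Below is a minimal sketch of one such bounded element-wise step, written to mirror the update above; it is not the packaged `Adams_ZeRO` implementation, and the function and buffer names (`bounded_step_`, `m`, `n`, `step`) are illustrative assumptions.

```python
import torch

# Sketch of one bounded element-wise update in the spirit of the atan2 rule above.
# Not the package's internal API; names and defaults are illustrative.
@torch.no_grad()
def bounded_step_(param, grad, m, n, lr, beta1=0.9, beta2=0.95, step=1):
    # Adam-style first/second moment estimates with bias correction.
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    n.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = m / (1 - beta1 ** step)
    n_hat = n / (1 - beta2 ** step)
    # atan2 bounds each element of the step to [-pi/2, pi/2] and needs no epsilon:
    # as n_hat -> 0 the update saturates instead of blowing up, and rescaling the
    # gradient rescales m_hat and sqrt(n_hat) together, so the step is scale-invariant.
    param.add_(torch.atan2(m_hat, n_hat.sqrt()), alpha=-lr)
```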

### 2) Spectral weight decay (2D)

For matrix parameters $W \in \mathbb{R}^{M \times N}$, the spectral norm better reflects the scale relevant to activations than the Frobenius norm does. Adams therefore applies **decoupled spectral weight decay** (akin to AdamW’s decoupling), replacing the usual $\tfrac{1}{2}\|W\|_F^2$ penalty with the squared spectral norm $\tfrac{1}{2}\sigma_1^2$:

* We compute a one-step **power iteration** with persistent state (same idea as PyTorch’s `spectral_norm`) to approximate the top singular triplet $(u_1, \sigma_1, v_1)$.
* The decay term is applied as $\sqrt{M} u_1 \sigma_1 v^\top_1$ (the gradient of $\tfrac{1}{2}\sigma_1^2$, scaled by $\sqrt{M}$ to match the RMS of $W$) per update step.
* This helps control activation scales and mitigates instabilities tied to large spectral norms.

**Efficiency:** the spectral step adds only two GEMV operations per matrix per update, comparable to a handful of extra element-wise ops. In typical FSDP/ZeRO setups the full weight matrix is available during forward/backward, so this integrates cleanly at scale.
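
The following sketch shows one way such a decay step could look, assuming a persistent power-iteration vector `v` as in `torch.nn.utils.spectral_norm`; the function and argument names are illustrative, not the package's internal API, and details may differ from the released implementation.

```python
import torch
import torch.nn.functional as F

# Sketch of a decoupled spectral decay step for one matrix parameter W (M x N),
# with a persistent right-singular-vector estimate v of shape (N,).
@torch.no_grad()
def spectral_decay_(W, v, lr, weight_decay):
    M, N = W.shape
    # One power-iteration step: exactly two GEMVs per update.
    u = F.normalize(W @ v, dim=0)      # left singular vector estimate, shape (M,)
    Wtu = W.t() @ u                    # shape (N,)
    sigma = Wtu.norm()                 # top singular value estimate, since W^T u ~ sigma * v
    v.copy_(F.normalize(Wtu, dim=0))   # persist the right singular vector for the next step
    # Decoupled decay along the top singular direction: the gradient of 0.5 * sigma^2
    # w.r.t. W is sigma * u v^T, scaled by sqrt(M) to match the RMS of W.
    W.sub_(lr * weight_decay * (M ** 0.5) * sigma * torch.outer(u, v))
    return sigma
```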

## Design Motivation 💡

Recent reports suggest that fully matrix-based optimizers (e.g., Muon) can be hard to implement/parallelize broadly and often show modest end-to-end benefits on large models (~1.1x or less), despite strong stability. Meanwhile, the dominant optimizer Adam is simple and fast but prone to instability and loss spikes.

**Adams** asks: *Can we keep Adam’s speed and simplicity while gaining matrix-level stability?*

## Installation

```bash
pip install adams-torch
```

## Quick Start 📈

You don’t need to manually broadcast parameters or all-reduce gradients—multi-GPU usage matches single-GPU usage. Fully compatible with `torch.compile`.

> FSDP is not supported yet. Contributions welcome.

```python
import os
import torch
import torch.distributed as dist
from adams import Adams_ZeRO  # main optimizer

def init():
    # Initialize distributed training if launched via torchrun/torch.distributed
    if "LOCAL_RANK" in os.environ:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model().to(device)

    # Spectral decay applies to matrix-shaped params.
    # scalar_vector_weight_decay applies standard decoupled L2 to 0D/1D params.
    optimizer = Adams_ZeRO(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,                 # spectral decay for matrices
        scalar_vector_weight_decay=0.1,   # L2 for scalars/vectors
        betas=(0.9, 0.95)
    )

    # Sync any internal buffers across ranks if required by your setup.
    optimizer.broadcast_buffers(model.buffers())

    return model, optimizer

@torch.compile  # Optional: works with torch.compile
def train_step(model, optimizer, batch):
    loss = model(batch)        # forward; compute your loss
    loss.backward()            # backward
    optimizer.step()           # no gradient clipping needed
    optimizer.zero_grad(set_to_none=True)
    return loss
```

## Notes ⚠️

As with other matrix-based optimizers (e.g., Muon), a couple of points need care:

1. **Non‑matrix parameters.** Disable the matrix‑based part (spectral decay) for parameters that are scalars, vectors, or collections of vectors (e.g. LayerNorm, Embedding, Output Head, etc.) by setting `param.use_spectral_decay = False`. Adams uses a separate decoupled L2 term, controlled by `scalar_vector_weight_decay` (default `0.1`).
2. **Batched matrices.** Parameters that are conceptually multiple matrices concatenated along leading dimensions (e.g., attention QKV projections) should be expressed with shape `(B, M, N)`. Adams treats all dimensions except the last two as batch dimensions. (In our experiments, each attention head's q, k, and v are treated as separate projection matrices; e.g., with 8 MHA heads the QKV projection contains 24 matrices (8 heads × 3).) A short setup sketch follows this list.
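
The sketch below illustrates both notes; the module and its shapes are hypothetical, but `use_spectral_decay` is the per-parameter flag described in note 1.

```python
import torch
import torch.nn as nn

# Hypothetical shapes for illustration only.
num_heads, head_dim, hidden = 8, 64, 512

class AttentionBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(hidden)
        # Note 2: fused QKV stored as a batch of per-head matrices,
        # shape (8 heads * 3, head_dim, hidden) = 24 independent matrices for Adams.
        self.qkv_weight = nn.Parameter(torch.randn(num_heads * 3, head_dim, hidden) * 0.02)
        self.out_proj = nn.Linear(num_heads * head_dim, hidden, bias=False)

block = AttentionBlock()

# Note 1: disable spectral decay on scalar/vector parameters (LayerNorm weights, biases, ...);
# Adams then applies the decoupled L2 controlled by scalar_vector_weight_decay.
for p in block.parameters():
    if p.dim() < 2:
        p.use_spectral_decay = False
```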

## Practical Tips ✏️

* **Hyperparameters:** start with AdamW-like settings; the bounded update removes `epsilon`. Adams tolerates much larger weight decay (e.g. `1.0`), which can improve generalization; see the sketch after this list.
* **Stability:** the bounded step and spectral decay together target sources of spikes linked to tiny second moments and large spectral norms.
* **Generalization & adversarial robustness:** spectral regularization is widely observed to improve both, and Adams adopts a lightweight decoupled form.
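
As a hedged example of the larger-decay setting mentioned above, the constructor arguments from Quick Start could be adjusted like this (assuming `model` is defined as in that example):

```python
from adams import Adams_ZeRO

# Same constructor as in Quick Start, with a more aggressive spectral decay.
optimizer = Adams_ZeRO(
    model.parameters(),
    lr=3e-4,
    weight_decay=1.0,                 # larger spectral decay for matrices
    scalar_vector_weight_decay=0.1,   # default L2 for scalars/vectors
    betas=(0.9, 0.95),
)
```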

## References

1. [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/pdf/2407.05872)
2. [Adaptive Preconditioners Trigger Loss Spikes in Adam](https://arxiv.org/pdf/2506.04805)
3. [Muon: An optimizer for the hidden layers of neural networks](https://github.com/KellerJordan/Muon)
4. [Spectral Norm Regularization for Improving the Generalizability of Deep Learning](https://arxiv.org/pdf/1705.10941)
5. [Thinking from spectral norm gradient to new weight decay](https://kexue.fm/archives/10648)

## License

Apache-2.0

            
