bf16-fused-adam

Name: bf16-fused-adam
Version: 0.1
Summary: BFloat16 Fused Adam Optimizer
Home page: https://github.com/imoneoi/bf16_fused_adam
Author: One
Upload time: 2024-09-25 04:07:36
# BFloat16 Fused Optimizer

A mixed-precision optimizer to solve the [stale weights](https://storage.googleapis.com/deepmind-media/research/language-research/Training%20Gopher.pdf) problem of bfloat16 training.

When training models with a `bfloat16` optimizer, an update is often cancelled when it is small relative to the weight in magnitude, leading to the stale weights problem, which significantly hurts performance.

Exploiting the fact that the round-towards-zero (RTZ) conversion of a `float32` to `bfloat16` keeps exactly the high 16 bits, this optimizer stores the low 16 bits as an extra mantissa tensor, acting as a 16+16 optimizer. This is mathematically equivalent to keeping a full 32-bit master weight, solving the stale weights problem while costing only 25% more memory.
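As a concrete illustration of the bit-level trick (a minimal sketch, not the package's CUDA kernel), the two 16-bit halves of an fp32 value can be split and reassembled like this:

```python
import torch

# Sketch only: the high 16 bits of an fp32 value are exactly its
# round-towards-zero bfloat16 value, and keeping the low 16 bits as an
# extra mantissa lets the fp32 master weight be reconstructed losslessly.
x = torch.randn(4)                                          # fp32 master weights
bits = x.view(torch.int32)                                  # reinterpret as raw bits
param = (bits >> 16).to(torch.int16).view(torch.bfloat16)   # high 16 bits -> bf16 param (RTZ)
mantissa = bits & 0xFFFF                                    # low 16 bits -> extra mantissa state

restored = ((param.view(torch.int16).to(torch.int32) << 16) | mantissa).view(torch.float32)
assert torch.equal(restored, x)                             # exact fp32 round trip
```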

## Usage

A drop-in replacement for `torch.optim.AdamW`. All parameters must be in `bfloat16`.

 - Doesn't support the `foreach` and `fused` arguments, as the optimizer is already fused
 - Doesn't support the `amsgrad`, `maximize`, `capturable`, and `differentiable` arguments yet

```bash
pip install bf16_fused_adam
```

```python
from bf16_fused_adam import BF16FusedAdamW

# All supported arguments are listed below
optim = BF16FusedAdamW(model.parameters(),
    lr=1e-3,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-5,
)
```
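For context, a minimal end-to-end training step might look like the sketch below; the toy model, loss, and device are illustrative and not part of the package:

```python
import torch
import torch.nn as nn
from bf16_fused_adam import BF16FusedAdamW

# Toy setup: cast every parameter to bfloat16 before building the optimizer.
model = nn.Linear(1024, 1024).cuda().to(torch.bfloat16)
optim = BF16FusedAdamW(model.parameters(), lr=1e-3, weight_decay=0.1,
                       betas=(0.9, 0.95), eps=1e-5)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()   # reduce in fp32; an illustrative stand-in loss
loss.backward()                         # produces bfloat16 gradients
optim.step()
optim.zero_grad()
```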

## Details

AdamW Reference States (PyTorch FusedAdamW):

 - param (bf16)
 - grad (bf16)
 - exp_avg (bf16)
 - exp_avg_sq (bf16)

16+16 Optimizer States (BF16FusedAdamW):

 - param (bf16, high 16 bits of master fp32 weights)
 - mantissa (uint16, low 16 bits of master fp32 weights)
 - grad (bf16)
 - exp_avg (bf16)
 - exp_avg_sq (bf16)

```
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 16)   = 32bit
               [             param 16           ] [mantissa 16]   = 32bit
```
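Counting bytes per parameter from the two state lists above makes the 25% figure explicit:

```python
# Bytes of optimizer-related state per parameter (each 16-bit tensor = 2 bytes).
reference = 2 + 2 + 2 + 2      # param + grad + exp_avg + exp_avg_sq = 8 bytes
bf16_fused = reference + 2     # + uint16 mantissa                   = 10 bytes
print(bf16_fused / reference - 1)   # 0.25 -> the "25% more memory" above
```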

## TODO

 - [ ] Stochastic rounding (trading precision for memory)
 - [ ] 16+8 optimizer (saving more memory)

```
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 8) (mantissa 8)   = 32bit
               [             param 16           ] [mantissa 8] [dropped 8]    = 24bit
```
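A hypothetical sketch of that packing (nothing here is implemented yet): keep the bf16 high half plus the next 8 mantissa bits, and drop the lowest 8 bits.

```python
import torch

# Hypothetical 16+8 split; the lowest 8 mantissa bits are simply discarded.
def split_16_8(master_fp32: torch.Tensor):
    bits = master_fp32.view(torch.int32)
    param = (bits >> 16).to(torch.int16).view(torch.bfloat16)   # top 16 bits (bf16 param)
    mantissa8 = ((bits >> 8) & 0xFF).to(torch.uint8)            # next 8 mantissa bits
    return param, mantissa8                                     # remaining 8 bits dropped
```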

## Consistency Tests

We tested consistency against the reference AdamW implementation. To run the tests, clone this repository, install it in editable mode, and run pytest:

```bash
pip install -e .
pytest
```
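The repository's tests are the source of truth; purely as an illustration, a consistency check of this kind can be sketched as follows (sizes, step count, and tolerances are arbitrary choices here):

```python
import torch
from bf16_fused_adam import BF16FusedAdamW

# Illustrative sketch, not the repository's actual test: step an fp32 AdamW
# "master" and a bf16 BF16FusedAdamW copy with the same gradients, then check
# that the bf16 weights track the master weights.
torch.manual_seed(0)
master = torch.randn(256, device="cuda", requires_grad=True)    # fp32 reference
param = master.detach().to(torch.bfloat16).requires_grad_()     # bf16 copy

ref = torch.optim.AdamW([master], lr=1e-3, betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1)
opt = BF16FusedAdamW([param], lr=1e-3, betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1)

for _ in range(100):
    g = torch.randn_like(master)
    master.grad = g
    param.grad = g.to(torch.bfloat16)
    ref.step()
    opt.step()

torch.testing.assert_close(param.float(), master.detach(), atol=1e-2, rtol=1e-2)
```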

### Passed

 - [x] H100
 - [x] A100
 - [ ] RTX 4090 [TBD]
 - [ ] RTX 3090 [TBD]

## References

16+16 optimizer:

 - https://arxiv.org/pdf/2309.12381.pdf

PyTorch AdamW:
 - https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/fused_adam_utils.cuh
 - https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/FusedAdamWKernel.cu

Gopher:
 - https://storage.googleapis.com/deepmind-media/research/language-research/Training%20Gopher.pdf

            
