# BFloat16 Fused Optimizer
A mixed-precision optimizer to solve the [stale weights](https://storage.googleapis.com/deepmind-media/research/language-research/Training%20Gopher.pdf) problem of bfloat16 training.
When training models with a pure `bfloat16` optimizer, an update is often cancelled when it is small relative to the weight's magnitude, leading to the stale weights problem, which significantly hurts performance.
Exploiting the fact that rounding a `float32` toward zero (RTZ) to `bfloat16` keeps exactly its high 16 bits, this optimizer stores an extra 16 bits of weight mantissa, acting as a 16+16 optimizer. This is mathematically equivalent to keeping a full 32-bit master weight, solving the stale weights problem while costing only 25% more memory.
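The split can be illustrated in a few lines of plain PyTorch (a minimal sketch of the idea, not the package's CUDA kernel; tensor names are illustrative):
```python
import torch

# Illustrative split of a float32 master weight into its two 16-bit halves.
master = torch.randn(4, dtype=torch.float32)
bits = master.view(torch.int32)

high16 = bits & ~0xFFFF   # bf16 part: low 16 mantissa bits zeroed (RTZ)
low16 = bits & 0xFFFF     # the extra "mantissa" state (fits in 16 bits)

param = high16.view(torch.float32).to(torch.bfloat16)   # what the model sees
restored = (high16 | low16).view(torch.float32)         # exact float32 master weight

assert torch.equal(restored, master)
# bf16 <-> float32 conversion of the truncated value is exact, so the
# parameter really is the high half of the master weight.
assert torch.equal(param.to(torch.float32).view(torch.int32), high16)
```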
## Usage
Drop-in replacement for `torch.optim.AdamW`. All parameters need to be in `bfloat16`.
- Doesn't support the `foreach` and `fused` arguments, as the optimizer is already fused
- Doesn't support the `amsgrad`, `maximize`, `capturable`, and `differentiable` arguments yet
```bash
pip install bf16_fused_adam
```
```python
from bf16_fused_adam import BF16FusedAdamW
# All supported arguments are listed below
optim = BF16FusedAdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.1,
    betas=(0.9, 0.95),
    eps=1e-5,
)
```
## Details
AdamW Reference States (PyTorch FusedAdamW):
- param (bf16)
- grad (bf16)
- exp_avg (bf16)
- exp_avg_sq (bf16)
16+16 Optimizer States (BF16FusedAdamW):
- param (bf16, high 16 bits of master fp32 weights)
- mantissa (uint16, low 16 bits of master fp32 weights)
- grad (bf16)
- exp_avg (bf16)
- exp_avg_sq (bf16)
```
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 16) = 32bit
[ param 16 ] [mantissa 16] = 32bit
```
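For reference, here is a simplified, non-fused sketch of one 16+16 AdamW step on a single tensor. The actual implementation is a fused CUDA kernel; the function and variable names below are illustrative, the mantissa is held in an `int32` tensor for simplicity rather than packed `uint16`, and the moments are promoted to `float32` inside the sketch even though the real states are stored as bf16.
```python
import torch

def adamw_step_16plus16(param_bf16, mantissa_i32, grad_bf16, exp_avg, exp_avg_sq,
                        step, lr=1e-3, betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1):
    """One AdamW step over the 16+16 states for a single tensor (sketch only)."""
    beta1, beta2 = betas

    # 1. Rebuild the float32 master weight from the two 16-bit halves.
    high = param_bf16.to(torch.float32).view(torch.int32)
    master = (high | mantissa_i32).view(torch.float32)

    # 2. Standard AdamW update, computed in float32.
    grad = grad_bf16.to(torch.float32)
    m = exp_avg.to(torch.float32).mul_(beta1).add_(grad, alpha=1 - beta1)
    v = exp_avg_sq.to(torch.float32).mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    master = master * (1 - lr * weight_decay)             # decoupled weight decay
    denom = (v / (1 - beta2 ** step)).sqrt().add_(eps)
    master = master - lr * (m / (1 - beta1 ** step)) / denom

    # 3. Split the updated master weight back: truncation (RTZ) keeps the high
    #    16 bits as the bf16 parameter and the low 16 bits as the mantissa,
    #    so recombining them next step recovers the master weight exactly.
    bits = master.view(torch.int32)
    new_param = (bits & ~0xFFFF).view(torch.float32).to(torch.bfloat16)
    new_mantissa = bits & 0xFFFF
    return new_param, new_mantissa, m.to(torch.bfloat16), v.to(torch.bfloat16)
```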
## TODO
- [ ] Stochastic rounding (trading precision for memory; a sketch follows below)
- [ ] 16+8 optimizer (saving more memory)
```
Master weight: (sign 1) (exponent 8) (mantissa 7) (mantissa 8) (mantissa 8) = 32bit
[ param 16 ] [mantissa 8] [dropped 8] = 24bit
```
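A common bit-level formulation of the stochastic rounding idea (a hedged sketch of the planned TODO item, not code in the released package; the function name is hypothetical):
```python
import torch

def stochastic_round_to_bf16(x_fp32: torch.Tensor) -> torch.Tensor:
    """Round float32 to bfloat16 stochastically: the discarded low 16 bits act
    as the probability of rounding away from zero, so tiny updates survive in
    expectation without keeping an extra mantissa state."""
    bits = x_fp32.view(torch.int32)
    noise = torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32,
                          device=bits.device)
    rounded = (bits + noise) & ~0xFFFF   # random carry into bit 16 rounds up
    return rounded.view(torch.float32).to(torch.bfloat16)
```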
## Consistency Tests
We test consistency against the reference AdamW implementation. To run the tests, clone this repository and run pytest:
```bash
pip install -e .
pytest
```
### Passed
- [x] H100
- [x] A100
- [ ] RTX 4090 [TBD]
- [ ] RTX 3090 [TBD]
## References
16+16 optimizer:
- https://arxiv.org/pdf/2309.12381.pdf
PyTorch AdamW:
- https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/fused_adam_utils.cuh
- https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/FusedAdamWKernel.cu
Gopher:
- https://storage.googleapis.com/deepmind-media/research/language-research/Training%20Gopher.pdf