| Name | nitrous-ema |
| Version | 0.0.1 |
| home_page | None |
| Summary | Fast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls. |
| upload_time | 2024-09-05 22:47:47 |
| maintainer | None |
| docs_url | None |
| author | None |
| requires_python | >=3.9 |
| license | None |
| keywords | |
| VCS | |
| bugtrack_url | |
| requirements | No requirements were recorded. |
| Travis-CI | No Travis. |
| coveralls test coverage | No coveralls. |
# nitrous-ema
Fast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls.
A fork of https://github.com/lucidrains/ema-pytorch
Features added:
- No more `.item()` calls during updates, which would force a device synchronization and slow things down. `initted` and `step` are now stored as plain Python values on the CPU. They are still saved to and restored from the state dict via `get_extra_state` and `set_extra_state`.
- Added a `step_size_correction` parameter that scales the weighting term (using a geometric mean) when `update_every` is larger than 1. Without it, the effective update rate would be too slow.
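The first feature can be illustrated with a minimal sketch. Calling `.item()` on a GPU tensor copies the value to the host, so the CPU has to wait for pending GPU work. Keeping counters as plain Python values avoids that, and PyTorch's `get_extra_state` / `set_extra_state` hooks still let them travel with the state dict. The `StepCounter` class below is hypothetical and is not the library's actual implementation; it only shows the mechanism.

```python
import torch
import torch.nn as nn


class StepCounter(nn.Module):
    """Hypothetical module (not part of nitrous-ema) that tracks progress
    with plain Python values instead of tensors."""

    def __init__(self):
        super().__init__()
        self.step = 0          # Python int: reading it never touches the GPU
        self.initted = False   # Python bool, same reasoning

    def update(self):
        # No tensor reads here, so no device synchronization is triggered.
        self.initted = True
        self.step += 1

    def get_extra_state(self):
        # Called by state_dict(); the result is stored under "_extra_state".
        return {"step": self.step, "initted": self.initted}

    def set_extra_state(self, state):
        # Called by load_state_dict() with the object saved above.
        self.step = state["step"]
        self.initted = state["initted"]


counter = StepCounter()
for _ in range(3):
    counter.update()

saved = counter.state_dict()

restored = StepCounter()
restored.load_state_dict(saved)
print(restored.step, restored.initted)
```

The counters survive a save and load round trip even though they are not registered as buffers or parameters.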
Starter script:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from nitrous_ema import PostHocEMA

# simple EMA application
data = torch.randn(512, 128)
target = torch.randn(512, 1)
net = nn.Linear(128, 1)
optimizer = optim.SGD(net.parameters(), lr=0.01)
ema = PostHocEMA(net,
                 sigma_rels=[0.05, 0.1],
                 checkpoint_every_num_steps=100,
                 update_every=10,
                 step_size_correction=True)

for _ in range(1000):
    optimizer.zero_grad()
    sample_idx = torch.randint(0, 512, (32, ))
    loss = (net(data[sample_idx]) - target[sample_idx]).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.update()

# Evaluate the model on the test data
with torch.no_grad():
    loss = (net(data) - target).pow(2).mean()
    print("Loss: ", loss.item())

# Evaluate the EMA model on the test data
with torch.no_grad():
    ema_model = ema.synthesize_ema_model(sigma_rel=0.08, device='cpu')
    loss = (ema_model(data) - target).pow(2).mean()
    print("EMA Loss: ", loss.item())
```
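To see why `update_every=10` calls for a correction, here is a simplified sketch using a constant decay `beta`. This is an assumption for illustration only: the library's actual schedule and correction formula are not documented in this README. With a constant decay, one update every `k` steps retains far more of the old average than `k` per-step updates would, and raising the decay to the power `k` restores the per-step behaviour. For a time-varying decay, the product of the `k` skipped factors equals their geometric mean raised to the power `k`, which is presumably where the geometric mean comes in.

```python
def retained_weight(beta: float, num_steps: int,
                    update_every: int = 1, corrected: bool = False) -> float:
    """Weight still assigned to the initial average after `num_steps`
    training steps, for a constant-decay EMA (illustrative helper only)."""
    # One EMA update happens every `update_every` training steps.
    num_updates = num_steps // update_every
    # Corrected: fold the skipped steps into a single, stronger decay.
    per_update_decay = beta ** update_every if corrected else beta
    return per_update_decay ** num_updates


every_step = retained_weight(0.99, 1000)
sparse = retained_weight(0.99, 1000, update_every=10)
sparse_corrected = retained_weight(0.99, 1000, update_every=10, corrected=True)

print(f"update every step:          {every_step:.3e}")
print(f"every 10 steps, uncorrected: {sparse:.3e}")
print(f"every 10 steps, corrected:   {sparse_corrected:.3e}")
```

Without the correction, the sparse schedule forgets old weights much more slowly than the per-step schedule. With it, the two retain the same weight up to floating-point error.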
Raw data
{
"_id": null,
"home_page": null,
"name": "nitrous-ema",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": null,
"author": null,
"author_email": "Rex Cheng <hkchengrex@gmail.com>",
"download_url": "https://files.pythonhosted.org/packages/04/8e/5aa3664e21b6379015abaac2333dea00aee4887e3f12846d4bff7fd270c8/nitrous_ema-0.0.1.tar.gz",
"platform": null,
"description": "# nitrous-ema\nFast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls.\n\nA fork of https://github.com/lucidrains/ema-pytorch\n\nFeatures added:\n- No more `.item()` calls during update which would force a device synchronization and slows things down. `initted` and `step` are now stored as Python types on CPUs. They are still put into the state dict via `set_extra_state` and `get_extra_state`. \n- Added a `step_size_correction` parameter to scale the weighting term (with geometric mean) when `update_every` is larger than 1. Otherwise the effective update rate would be too slow\n\nStarter script:\n```python\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom nitrous_ema import PostHocEMA\n\n# simple EMA application\ndata = torch.randn(512, 128)\ntarget = torch.randn(512, 1)\nnet = nn.Linear(128, 1)\noptimizer = optim.SGD(net.parameters(), lr=0.01)\nema = PostHocEMA(net,\n sigma_rels=[0.05, 0.1],\n checkpoint_every_num_steps=100,\n update_every=10,\n step_size_correction=True)\n\nfor _ in range(1000):\n optimizer.zero_grad()\n sample_idx = torch.randint(0, 512, (32, ))\n loss = (net(data[sample_idx]) - target[sample_idx]).pow(2).mean()\n loss.backward()\n optimizer.step()\n ema.update()\n\n# Evaluate the model on the test data\nwith torch.no_grad():\n loss = (net(data) - target).pow(2).mean()\n print(\"Loss: \", loss.item())\n\n# Evaluate the EMA model on the test data\nwith torch.no_grad():\n ema_model = ema.synthesize_ema_model(sigma_rel=0.08, device='cpu')\n loss = (ema_model(data) - target).pow(2).mean()\n print(\"EMA Loss: \", loss.item())\n\n```",
"bugtrack_url": null,
"license": null,
"summary": "Fast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls.",
"version": "0.0.1",
"project_urls": {
"Homepage": "https://github.com/hkchengrex/nitrous-ema",
"Issues": "https://github.com/hkchengrex/nitrous-ema/issues"
},
"split_keywords": [],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "dd9cae7670eee17bc0841c66288eb04248519cff0d9984483db0747dd8d00d6b",
"md5": "790cdc16786b87e70ad26e28d2335eab",
"sha256": "74c9a8a7309e308ce27b984ae06da3d68c81e289b3682f794d8d7aa1e8dff90e"
},
"downloads": -1,
"filename": "nitrous_ema-0.0.1-py3-none-any.whl",
"has_sig": false,
"md5_digest": "790cdc16786b87e70ad26e28d2335eab",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 6704,
"upload_time": "2024-09-05T22:47:46",
"upload_time_iso_8601": "2024-09-05T22:47:46.062097Z",
"url": "https://files.pythonhosted.org/packages/dd/9c/ae7670eee17bc0841c66288eb04248519cff0d9984483db0747dd8d00d6b/nitrous_ema-0.0.1-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "048e5aa3664e21b6379015abaac2333dea00aee4887e3f12846d4bff7fd270c8",
"md5": "4faec861807065c5b949ff57b6e6fbea",
"sha256": "f13812c09c3e9499581d1adcdbf62e51f60814800a24cc930e752f04a4d7ef75"
},
"downloads": -1,
"filename": "nitrous_ema-0.0.1.tar.gz",
"has_sig": false,
"md5_digest": "4faec861807065c5b949ff57b6e6fbea",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 7913,
"upload_time": "2024-09-05T22:47:47",
"upload_time_iso_8601": "2024-09-05T22:47:47.052375Z",
"url": "https://files.pythonhosted.org/packages/04/8e/5aa3664e21b6379015abaac2333dea00aee4887e3f12846d4bff7fd270c8/nitrous_ema-0.0.1.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-09-05 22:47:47",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "hkchengrex",
"github_project": "nitrous-ema",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "nitrous-ema"
}