nitrous-ema


Name: nitrous-ema
Version: 0.0.1
Summary: Fast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls.
Author email: Rex Cheng <hkchengrex@gmail.com>
Homepage: https://github.com/hkchengrex/nitrous-ema
Requires Python: >=3.9
Upload time: 2024-09-05 22:47:47
Requirements: none recorded
# nitrous-ema
Fast and simple post-hoc EMA (Karras et al., 2023) with minimal `.item()` calls.

A fork of https://github.com/lucidrains/ema-pytorch

Features added:
- No more `.item()` calls during the update, which would force a device synchronization and slow things down. `initted` and `step` are now stored as plain Python values on the CPU. They are still persisted in the state dict via `set_extra_state` and `get_extra_state` (see the sketch after this list).
- Added a `step_size_correction` parameter that scales the weighting term (via a geometric-mean correction) when `update_every` is larger than 1; otherwise the effective update rate would be too slow (a numeric illustration follows the starter script).
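
PyTorch's `get_extra_state` / `set_extra_state` hooks are what make the first point possible: counters can live as plain Python values (so reading them never synchronizes with the GPU) while still round-tripping through `state_dict`. A minimal sketch of the pattern, not the package's exact code:

```python
import torch.nn as nn

class CounterModule(nn.Module):
    """Keeps `initted` and `step` as plain Python attributes rather than
    buffers, so reading them never forces a device synchronization, while
    still persisting them through state_dict()."""

    def __init__(self):
        super().__init__()
        self.initted = False
        self.step = 0

    def get_extra_state(self):
        # Called by state_dict(); may return any picklable object.
        return {'initted': self.initted, 'step': self.step}

    def set_extra_state(self, state):
        # Called by load_state_dict() with the object returned above.
        self.initted = state['initted']
        self.step = state['step']

m = CounterModule()
m.initted, m.step = True, 42
restored = CounterModule()
restored.load_state_dict(m.state_dict())
assert restored.initted and restored.step == 42
```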

Starter script:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from nitrous_ema import PostHocEMA

# simple EMA application
data = torch.randn(512, 128)
target = torch.randn(512, 1)
net = nn.Linear(128, 1)
optimizer = optim.SGD(net.parameters(), lr=0.01)
# keep EMA snapshots at two relative widths; write a checkpoint every 100
# steps and apply the (corrected) EMA update every 10 calls to `update()`
ema = PostHocEMA(net,
                 sigma_rels=[0.05, 0.1],
                 checkpoint_every_num_steps=100,
                 update_every=10,
                 step_size_correction=True)

for _ in range(1000):
    optimizer.zero_grad()
    sample_idx = torch.randint(0, 512, (32, ))
    loss = (net(data[sample_idx]) - target[sample_idx]).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.update()  # internally applies the EMA update only every `update_every` calls

# Evaluate the model on the test data
with torch.no_grad():
    loss = (net(data) - target).pow(2).mean()
    print("Loss: ", loss.item())

# Evaluate the EMA model on the test data
with torch.no_grad():
    # post-hoc: combine stored checkpoints into an EMA profile not tracked during training
    ema_model = ema.synthesize_ema_model(sigma_rel=0.08, device='cpu')
    loss = (ema_model(data) - target).pow(2).mean()
    print("EMA Loss: ", loss.item())

```
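
The `step_size_correction=True` flag used above compensates for only applying the EMA update every 10th step. For a plain EMA with per-step decay `beta`, folding `k` skipped steps into a single update means the per-update decay should be `beta ** k` (equivalently, `beta` is the geometric mean of the per-update decay over the `k` steps); without that correction the averaged weights lag behind. A small numeric check of this general principle (the package's actual weighting follows the post-hoc EMA profile, so this is only an illustration of the idea):

```python
import torch

def ema_every_step(xs, beta):
    # Reference: classic EMA updated on every sample.
    avg = xs[0].clone()
    for x in xs[1:]:
        avg = beta * avg + (1 - beta) * x
    return avg

def ema_every_k(xs, beta, k, correct):
    # EMA updated only every k-th sample; with correction, the k skipped
    # steps are folded into one update by using beta ** k as the decay.
    decay = beta ** k if correct else beta
    avg = xs[0].clone()
    for i, x in enumerate(xs[1:], start=1):
        if i % k == 0:
            avg = decay * avg + (1 - decay) * x
    return avg

torch.manual_seed(0)
xs = torch.randn(2000, 4)
ref = ema_every_step(xs, beta=0.99)
corrected = ema_every_k(xs, beta=0.99, k=10, correct=True)
naive = ema_every_k(xs, beta=0.99, k=10, correct=False)
print('corrected vs every-step:', (corrected - ref).abs().max().item())
print('naive vs every-step:    ', (naive - ref).abs().max().item())
```

The corrected variant matches the timescale of the every-step EMA (it just sees fewer samples), while the uncorrected one averages over a window roughly `k` times too long.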
            
