sbgm


Name: sbgm
Version: 0.0.21
Summary: Score-based Diffusion models in JAX.
Upload time: 2025-01-09 17:59:54
Requires Python: ~=3.12
License: MIT License Copyright (c) [year] [fullname] Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Keywords: deep-learning, diffusion-models, generative-models, jax, score-based-diffusion

<h1 align='center'>sbgm</h1>
<h2 align='center'>Score-Based Diffusion Models in JAX</h2>

Implementation and extension of [Score-Based Generative Modeling through Stochastic Differential Equations (Song++20)](https://arxiv.org/abs/2011.13456) and [Maximum Likelihood Training of Score-Based Diffusion Models (Song++21)](https://arxiv.org/abs/2101.09258) in `jax` and `equinox`. 

This repository provides a lightweight library of models, samplers and likelihood routines, suitable for likelihood-free or emulation-based approaches. The code is tested and typed to support reliable and benchmarkable training and inference.

> [!WARNING]
> :building_construction: Note that this repository is under construction; expect changes. :building_construction:

### Score-based diffusion models

Diffusion models are deep hierarchical models for data that use neural networks to model the reverse of a diffusion process that adds a sequence of noise perturbations to the data. 

Modern cutting-edge diffusion models (see citations) express both the forward and reverse diffusion processes as stochastic differential equations (SDEs).

-----

<p align="center">
  <img src="https://github.com/homerjed/sbgm/blob/main/assets/sde_ode.png" />
</p>

*A diagram showing how to map data to a noise distribution (the prior) with an SDE, and reverse this SDE for generative modeling. One can also reverse the associated probability flow ODE, which yields a deterministic process that samples from the same distribution as the SDE. Both the reverse-time SDE and probability flow ODE can be obtained by estimating the score.* 
<!-- $\nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x}_t)$ -->

-----

For any SDE of the form 

$$
\text{d}\boldsymbol{x} = f(\boldsymbol{x}, t)\text{d}t + g(t)\text{d}\boldsymbol{w},
$$

the reverse of the SDE from noise to data is given by 

$$
\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t + g(t)\text{d}\boldsymbol{w}.
$$
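
As a minimal sketch of the reverse SDE above, the following plain-Python example (standard library only, not `sbgm` or JAX code) integrates it with Euler-Maruyama steps on a toy 1-D problem. The Gaussian data distribution, the constant-$\beta$ VP SDE, and the names `marginal`, `score` and `reverse_sde_sample` are all assumptions made for illustration; the exact score of the Gaussian marginal stands in for a trained network.

```python
import math
import random

# Assumed toy setup (not from sbgm): 1-D Gaussian data N(MU, S^2) and a
# VP SDE with constant beta, f(x, t) = -0.5 * BETA * x, g(t) = sqrt(BETA).
MU, S, BETA, T = 1.0, 0.5, 10.0, 1.0


def marginal(t):
    """Mean and variance of p_t(x) for Gaussian data under this VP SDE."""
    alpha = math.exp(-0.5 * BETA * t)
    return MU * alpha, (S * alpha) ** 2 + 1.0 - alpha ** 2


def score(x, t):
    """Exact score of the Gaussian marginal; a trained network would replace this."""
    mean, var = marginal(t)
    return -(x - mean) / var


def reverse_sde_sample(rng, n_steps=400):
    """Euler-Maruyama integration of the reverse SDE from t=T down to t=0."""
    dt = T / n_steps
    x = rng.gauss(0.0, 1.0)  # draw from the N(0, 1) prior
    for i in range(n_steps):
        t = T - i * dt
        drift = -0.5 * BETA * x - BETA * score(x, t)  # f - g^2 * score
        # Time runs backwards, so the drift enters with a minus sign.
        x = x - drift * dt + math.sqrt(BETA * dt) * rng.gauss(0.0, 1.0)
    return x


rng = random.Random(0)
samples = [reverse_sde_sample(rng) for _ in range(1000)]
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((x - sample_mean) ** 2 for x in samples) / len(samples))
print(f"mean={sample_mean:.3f} (target {MU}), std={sample_std:.3f} (target {S})")
```

Starting from prior noise, the samples should recover the mean and spread of the assumed data distribution up to discretisation and Monte Carlo error.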

For every SDE there exists an associated ordinary differential equation (ODE)

$$
\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t,
$$

where the trajectories of the SDE and ODE have the same marginal PDFs $p_t(\boldsymbol{x})$.
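
To illustrate the deterministic counterpart, here is a small sketch that integrates the probability flow ODE backwards in time with plain Euler steps. It uses only the Python standard library rather than `sbgm` or JAX, and the toy Gaussian data, constant-$\beta$ VP SDE and function names (`marginal`, `score`, `ode_drift`, `ode_sample`) are assumptions for illustration. For Gaussian marginals the flow keeps the standardised value $(x - \text{mean}) / \text{std}$ fixed along a trajectory, which gives a closed form to compare against.

```python
import math

# Assumed toy setup (not from sbgm): 1-D Gaussian data N(MU, S^2) and a
# VP SDE with constant beta, f(x, t) = -0.5 * BETA * x, g(t) = sqrt(BETA).
MU, S, BETA, T = 1.0, 0.5, 10.0, 1.0


def marginal(t):
    """Mean and variance of p_t(x) for Gaussian data under this VP SDE."""
    alpha = math.exp(-0.5 * BETA * t)
    return MU * alpha, (S * alpha) ** 2 + 1.0 - alpha ** 2


def score(x, t):
    """Exact score of the Gaussian marginal; a trained network would replace this."""
    mean, var = marginal(t)
    return -(x - mean) / var


def ode_drift(x, t):
    """Probability flow ODE drift: f(x, t) - 0.5 * g(t)^2 * score(x, t)."""
    return -0.5 * BETA * x - 0.5 * BETA * score(x, t)


def ode_sample(x_T, n_steps=2000):
    """Euler integration of the ODE from t=T down to t=0 (no noise is added)."""
    dt = T / n_steps
    x = x_T
    for i in range(n_steps):
        t = T - i * dt
        x = x - ode_drift(x, t) * dt
    return x


x_T = 0.6  # a fixed point from the prior; the same input always gives the same output
x_0 = ode_sample(x_T)

# Closed-form solution for Gaussian marginals, used only as a check.
mean_T, var_T = marginal(T)
x_0_exact = MU + S * (x_T - mean_T) / math.sqrt(var_T)
print(f"numerical x(0)={x_0:.4f}, closed form={x_0_exact:.4f}")
```

Unlike the SDE sampler, each prior draw maps to exactly one data-space point, which is what makes likelihood evaluation along the same trajectory possible.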

The Stein score of the marginal probability distributions over $t$ is approximated with a neural network $\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\approx s_{\theta}(\boldsymbol{x}(t), t)$. The parameters of the neural network are fit by minimising the score-matching loss.
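
One common form of this objective is the denoising score-matching loss of Song et al., sketched below in plain Python (standard library only). This is an illustrative assumption, not the `sbgm` training code: the toy Gaussian data, the constant-$\beta$ VP SDE, the weighting $\lambda(t) = \sigma_t^2$, and the names `dsm_loss`, `true_score` and `zero_score` are all hypothetical. The loss compares a candidate score function against the score of the Gaussian perturbation kernel $p(\boldsymbol{x}_t | \boldsymbol{x}_0)$.

```python
import math
import random

# Assumed toy setup (not from sbgm): 1-D Gaussian data N(MU, S^2) and a
# VP SDE with constant beta, so p(x_t | x_0) = N(alpha_t * x_0, sigma_t^2).
MU, S, BETA, T = 1.0, 0.5, 10.0, 1.0


def dsm_loss(score_fn, rng, n_samples=5000, t_min=1e-3):
    """Monte Carlo denoising score-matching loss with weighting sigma_t^2."""
    total = 0.0
    for _ in range(n_samples):
        x0 = rng.gauss(MU, S)              # data sample
        t = rng.uniform(t_min, T)          # random diffusion time
        alpha = math.exp(-0.5 * BETA * t)
        sigma = math.sqrt(1.0 - alpha ** 2)
        eps = rng.gauss(0.0, 1.0)
        xt = alpha * x0 + sigma * eps      # noised sample
        # The kernel score is -eps / sigma; weighting the squared error by
        # sigma^2 gives the residual (sigma * s + eps)^2.
        total += (sigma * score_fn(xt, t) + eps) ** 2
    return total / n_samples


def true_score(x, t):
    """Exact marginal score for the Gaussian toy data."""
    alpha = math.exp(-0.5 * BETA * t)
    var = (S * alpha) ** 2 + 1.0 - alpha ** 2
    return -(x - MU * alpha) / var


def zero_score(x, t):
    """A deliberately poor baseline that ignores its inputs."""
    return 0.0


# The same seed gives both candidates identical data, times and noise draws.
loss_true = dsm_loss(true_score, random.Random(0))
loss_zero = dsm_loss(zero_score, random.Random(0))
print(f"loss(true score)={loss_true:.3f}, loss(zero score)={loss_zero:.3f}")
```

The exact marginal score attains a lower loss than the baseline, though not zero, because the target is the score of the noisy kernel rather than of the marginal itself.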

### Computing log-likelihoods with diffusion models

For each SDE there exists a deterministic ODE whose marginal densities $p_t(\boldsymbol{x})$ match those of the SDE at every time $t$:

$$
\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t = f'(\boldsymbol{x}(t), t)\text{d}t.
$$

The continuous normalizing flow formalism gives the rate of change of the log-density along a trajectory of the ODE as

$$
\frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\nabla_{\boldsymbol{x}} \cdot f'(\boldsymbol{x}(t), t),
$$

which gives the log-likelihood of a datapoint $\boldsymbol{x}$ as 

$$
\log p(\boldsymbol{x}(0)) = \log p(\boldsymbol{x}(T)) + \int_{t=0}^{t=T}\text{d}t \; \nabla_{\boldsymbol{x}}\cdot f'(\boldsymbol{x}, t).
$$
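
The integral above can be checked numerically on a toy problem where the answer is known. The sketch below uses only the Python standard library (not `sbgm` or JAX) and assumes 1-D Gaussian data with a constant-$\beta$ VP SDE; the names `marginal`, `ode_drift_and_divergence` and `log_likelihood` are hypothetical. In one dimension the divergence is an ordinary derivative and is written out by hand here; in higher dimensions it would typically be computed with automatic differentiation or a stochastic trace estimator.

```python
import math

# Assumed toy setup (not from sbgm): 1-D Gaussian data N(MU, S^2) and a
# VP SDE with constant beta, f(x, t) = -0.5 * BETA * x, g(t) = sqrt(BETA).
MU, S, BETA, T = 1.0, 0.5, 10.0, 1.0


def marginal(t):
    """Mean and variance of p_t(x) for Gaussian data under this VP SDE."""
    alpha = math.exp(-0.5 * BETA * t)
    return MU * alpha, (S * alpha) ** 2 + 1.0 - alpha ** 2


def ode_drift_and_divergence(x, t):
    """f'(x, t) and its divergence, using the exact Gaussian score."""
    mean, var = marginal(t)
    drift = -0.5 * BETA * x + 0.5 * BETA * (x - mean) / var
    divergence = -0.5 * BETA + 0.5 * BETA / var  # d(drift)/dx
    return drift, divergence


def log_likelihood(x0, n_steps=10000):
    """log p(x(0)) = log p(x(T)) + integral of the divergence over [0, T]."""
    dt = T / n_steps
    x, integral = x0, 0.0
    for i in range(n_steps):
        drift, divergence = ode_drift_and_divergence(x, i * dt)
        x += drift * dt              # move the datapoint towards the prior
        integral += divergence * dt  # accumulate the change in log-density
    log_prior = -0.5 * x ** 2 - 0.5 * math.log(2.0 * math.pi)  # N(0, 1) at t=T
    return log_prior + integral


x0 = 1.3
estimate = log_likelihood(x0)
exact = -0.5 * ((x0 - MU) / S) ** 2 - math.log(S) - 0.5 * math.log(2.0 * math.pi)
print(f"ODE estimate={estimate:.4f}, exact Gaussian log-density={exact:.4f}")
```

The small remaining gap comes from the Euler discretisation and from using $N(0, 1)$ in place of the exact marginal at $t = T$.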

Note that direct maximum-likelihood training is prohibitively expensive for SDE-based diffusion models, since each likelihood evaluation requires numerically solving the ODE above.

### Usage

Install via

```
pip install sbgm
```

See [examples](https://github.com/homerjed/sbgm/tree/main/examples).

To run on the `cifar10` image dataset, try something like

```python
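import jax.random as jr

import sbgm
# `data` and `configs` are imported as top-level modules; they are assumed
# to be the helper modules shipped with the repository (see the examples)
import data
import configs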

datasets_path = "."
root_dir = "."

config = configs.cifar10_config()

key = jr.key(config.seed)
data_key, model_key, train_key = jr.split(key, 3)

dataset = data.cifar10(datasets_path, data_key)

sharding = sbgm.shard.get_sharding()
    
# Diffusion model 
model = sbgm.models.get_model(
    model_key, 
    config.model.model_type, 
    dataset.data_shape, 
    dataset.context_shape, 
    dataset.parameter_dim,
    config
)

# Stochastic differential equation (SDE)
sde = sbgm.sde.get_sde(config.sde)

# Fit model to dataset
model = sbgm.train.train(
    train_key,
    model,
    sde,
    dataset,
    config,
    sharding=sharding,
    save_dir=root_dir
)
```

### Features

* Parallelised exact and approximate log-likelihood calculations,
* UNet and transformer score network implementations,
* VP, SubVP and VE SDEs (neural network $\beta(t)$ and $\sigma(t)$ functions are on the list!),
* Multi-modal conditioning (basically just optional parameter and image conditioning methods),
* Checkpointing optimiser and model,
* Multi-device training and sampling.

### Samples

> [!NOTE]
> I haven't optimised any training/architecture hyperparameters or trained for long enough here; you could do a lot better.

<h4 align='left'>Flowers</h4>

Euler-Maruyama sampling
![Flowers Euler-Maruyama sampling](assets/flowers_eu.png?raw=true)

ODE sampling
![Flowers ODE sampling](assets/flowers_ode.png?raw=true)

<h4 align='left'>CIFAR10</h4>

Euler-Maruyama sampling
![CIFAR10 Euler-Maruyama sampling](assets/cifar10_eu.png?raw=true)

ODE sampling
![CIFAR10 ODE sampling](assets/cifar10_ode.png?raw=true)

<!-- ![alt text](assets/flowers_ode.png?raw=true) -->

### SDEs 
![alt text](assets/sdes.png?raw=true)

### Citations
```bibtex
@misc{song2021scorebasedgenerativemodelingstochastic,
      title={Score-Based Generative Modeling through Stochastic Differential Equations}, 
      author={Yang Song and Jascha Sohl-Dickstein and Diederik P. Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
      year={2021},
      eprint={2011.13456},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2011.13456}, 
}
```

```bibtex
@misc{song2021maximumlikelihoodtrainingscorebased,
      title={Maximum Likelihood Training of Score-Based Diffusion Models}, 
      author={Yang Song and Conor Durkan and Iain Murray and Stefano Ermon},
      year={2021},
      eprint={2101.09258},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/2101.09258}, 
}
```
            

Raw data

{
    "_id": null,
    "home_page": null,
    "name": "sbgm",
    "maintainer": null,
    "docs_url": null,
    "requires_python": "~=3.12",
    "maintainer_email": null,
    "keywords": "deep-learning, diffusion-models, generative-models, jax, score-based-diffusion",
    "author": null,
    "author_email": "Jed Homer <jedhmr@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/e8/c0/6461ee2b29dff384e9d5a483b92692f061934935b4e70af9618dba5402e0/sbgm-0.0.21.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT License  Copyright (c) [year] [fullname]  Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:  The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.  THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.",
    "summary": "Score-based Diffusion models in JAX.",
    "version": "0.0.21",
    "project_urls": {
        "repository": "https://github.com/homerjed/sbgm"
    },
    "split_keywords": [
        "deep-learning",
        " diffusion-models",
        " generative-models",
        " jax",
        " score-based-diffusion"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "04d2755b7ae5bab88b608bb60755227ee7268f2e734dd9a3b37c157bfc6dd5b9",
                "md5": "e8f20ecc17821b48b4943dbb7aa93905",
                "sha256": "63256c55bb3e22c7710b97467e7c07151534a5c2daa239851c10f0905f96d013"
            },
            "downloads": -1,
            "filename": "sbgm-0.0.21-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "e8f20ecc17821b48b4943dbb7aa93905",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": "~=3.12",
            "size": 35678,
            "upload_time": "2025-01-09T17:59:53",
            "upload_time_iso_8601": "2025-01-09T17:59:53.297610Z",
            "url": "https://files.pythonhosted.org/packages/04/d2/755b7ae5bab88b608bb60755227ee7268f2e734dd9a3b37c157bfc6dd5b9/sbgm-0.0.21-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "e8c06461ee2b29dff384e9d5a483b92692f061934935b4e70af9618dba5402e0",
                "md5": "36172861c49c1d4a7f46a9185b84af00",
                "sha256": "5045ca3f69b2cca4f91b83600ea18643ae1dd3aba19cde2a6a012966ecad1242"
            },
            "downloads": -1,
            "filename": "sbgm-0.0.21.tar.gz",
            "has_sig": false,
            "md5_digest": "36172861c49c1d4a7f46a9185b84af00",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": "~=3.12",
            "size": 27156,
            "upload_time": "2025-01-09T17:59:54",
            "upload_time_iso_8601": "2025-01-09T17:59:54.552335Z",
            "url": "https://files.pythonhosted.org/packages/e8/c0/6461ee2b29dff384e9d5a483b92692f061934935b4e70af9618dba5402e0/sbgm-0.0.21.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-01-09 17:59:54",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "homerjed",
    "github_project": "sbgm",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "sbgm"
}
        