# SGMCMCJax

- Version: 0.2.13 (uploaded to PyPI on 2023-08-07)
- Summary: SGMCMC samplers in JAX
- Author: Jeremie Coullon
- Home page: https://github.com/jeremiecoullon/SGMCMCJax
- License: LICENSE.txt

[**Quickstart**](#example-usage) | [**Samplers**](#samplers) | [**Documentation**](https://sgmcmcjax.readthedocs.io/en/latest/index.html)

SGMCMCJax is a lightweight library of stochastic gradient Markov chain Monte Carlo (SGMCMC) algorithms. The aim is to include both standard samplers (SGLD, SGHMC) and state-of-the-art samplers, while requiring only JAX to run.

The target audience for this library is researchers and practitioners: simply plug in your JAX model and easily obtain samples.

[![DOI](https://joss.theoj.org/papers/10.21105/joss.04113/status.svg)](https://doi.org/10.21105/joss.04113)
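
SGMCMC samplers replace the full-data gradient of the log-posterior with an unbiased estimate computed on a random minibatch and rescaled by N/n. As an illustration only, here is a minimal sketch of one SGLD update in plain JAX for the Gaussian model used in the example usage section. It assumes the common parameterisation θ ← θ + dt·∇log p̂(θ | X) + √(2·dt)·ξ with ξ ~ N(0, I); this is not necessarily the exact scaling used inside SGMCMCJax, and `estimate_grad_logpost` and `sgld_step` are hypothetical helpers, not part of the library's API.

```python
import jax
import jax.numpy as jnp
from jax import random


def loglikelihood(theta, x):
    # log-likelihood of a single data point x ~ N(theta, I), up to a constant
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    # Gaussian prior theta ~ N(0, 100 I), up to a constant
    return -0.5 * jnp.dot(theta, theta) * 0.01


def estimate_grad_logpost(theta, x_batch, N):
    """Unbiased minibatch estimate of the log-posterior gradient (hypothetical helper)."""
    n = x_batch.shape[0]
    # per-example log-likelihood gradients, shape (n, D)
    per_example = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))(theta, x_batch)
    # rescale the minibatch sum by N / n so its expectation is the full-data sum
    return jax.grad(logprior)(theta) + (N / n) * jnp.sum(per_example, axis=0)


def sgld_step(key, theta, X, dt, batch_size):
    """One SGLD update (hypothetical helper): gradient drift plus Gaussian noise."""
    N = X.shape[0]
    batch_key, noise_key = random.split(key)
    # draw a minibatch of indices without replacement
    idx = random.choice(batch_key, N, shape=(batch_size,), replace=False)
    grad = estimate_grad_logpost(theta, X[idx], N)
    noise = random.normal(noise_key, shape=theta.shape)
    # Euler-Maruyama step of the Langevin diffusion with step size dt
    return theta + dt * grad + jnp.sqrt(2 * dt) * noise


key = random.PRNGKey(0)
data_key, run_key = random.split(key)
X = random.normal(data_key, shape=(1_000, 5))
theta = jnp.zeros(5)
for step_key in random.split(run_key, 100):
    theta = sgld_step(step_key, theta, X, dt=1e-4, batch_size=100)
```

Splitting the key at every step, rather than reusing it, gives each minibatch draw and each noise draw independent randomness, as JAX's explicit PRNG design requires.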

## Example usage

The following example shows basic usage: estimating the mean of a D-dimensional Gaussian from data, with a Gaussian prior on the mean.

```python
import jax.numpy as jnp
from jax import random
from sgmcmcjax.samplers import build_sgld_sampler


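# define model in JAX
def loglikelihood(theta, x):
    # log-likelihood of a single data point x ~ N(theta, I), up to a constant
    return -0.5*jnp.dot(x-theta, x-theta)

def logprior(theta):
    # Gaussian prior theta ~ N(0, 100 I), up to a constant
    return -0.5*jnp.dot(theta, theta)*0.01

# generate dataset, using separate PRNG keys for the data and the sampler
N, D = 10_000, 100
key = random.PRNGKey(0)
data_key, sampler_key = random.split(key)
X_data = random.normal(data_key, shape=(N, D))

# build sampler: step size dt, minibatches of 10% of the data
batch_size = int(0.1*N)
dt = 1e-5
my_sampler = build_sgld_sampler(dt, loglikelihood, logprior, (X_data,), batch_size)

# run sampler to draw Nsamples samples, starting from theta = 0
Nsamples = 10_000
samples = my_sampler(sampler_key, Nsamples, jnp.zeros(D))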
```
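
Because this model is conjugate, the exact posterior is available and gives a useful sanity check for the samples. Setting the gradient of the log-posterior, Σᵢ(xᵢ − θ) − 0.01·θ, to zero gives the posterior mean Σᵢxᵢ / (N + 0.01), and the posterior covariance is I / (N + 0.01). The sketch below computes these quantities in plain JAX on a freshly generated dataset of the same shape; `gaussian_mean_posterior` is a hypothetical helper, and the commented comparison assumes `samples` is an array of shape `(Nsamples, D)`.

```python
import jax.numpy as jnp
from jax import random


def gaussian_mean_posterior(X, prior_precision=0.01):
    """Exact posterior of theta for x_i ~ N(theta, I), theta ~ N(0, I / prior_precision)."""
    N = X.shape[0]
    post_precision = N + prior_precision
    # posterior mean: data sum shrunk slightly towards the prior mean of zero
    post_mean = jnp.sum(X, axis=0) / post_precision
    # posterior standard deviation, identical in every dimension
    post_std = 1.0 / jnp.sqrt(post_precision)
    return post_mean, post_std


key = random.PRNGKey(0)
X_data = random.normal(key, shape=(10_000, 100))
exact_mean, exact_std = gaussian_mean_posterior(X_data)

# Assuming `samples` has shape (Nsamples, D), one could compare after discarding burn-in:
# err = jnp.max(jnp.abs(jnp.mean(samples[1_000:], axis=0) - exact_mean))
```

With N = 10,000 the posterior standard deviation is about 0.01, so sample means that differ from `exact_mean` by much more than that suggest the step size or chain length needs adjusting.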

## Ask a question or open an issue

Please open issues on the [GitHub issue tracker](https://github.com/jeremiecoullon/SGMCMCJax/issues), or ask a question in the [Discussions section](https://github.com/jeremiecoullon/SGMCMCJax/discussions) on GitHub.


## Samplers

The library includes several SGMCMC algorithms; their pros and cons are briefly discussed in the [documentation](https://sgmcmcjax.readthedocs.io/en/latest/all_samplers.html).

The current list of samplers is:

- SGLD
- SGLD-CV
- SVRG-Langevin
- SGHMC
- SGHMC-CV
- SVRG-SGHMC
- pSGLD
- SGLDAdam
- BAOAB
- SGNHT
- SGNHT-CV
- BADODAB
- BADODAB-CV
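
Several of the samplers above carry a `-CV` suffix, meaning they use control variates to reduce the variance of the minibatch gradient. The sketch below illustrates the general idea in plain JAX for the Gaussian model from the example usage section; it is not SGMCMCJax's implementation, and `per_example_grad`, `full_grad_logpost` and `cv_grad_estimate` are hypothetical helpers. The minibatch gradient is centred at a fixed value θ̂ (in practice typically an estimate of the posterior mode), whose full-data log-likelihood gradient is computed once up front.

```python
import jax
import jax.numpy as jnp
from jax import random


def loglikelihood(theta, x):
    # log-likelihood of a single data point x ~ N(theta, I), up to a constant
    return -0.5 * jnp.dot(x - theta, x - theta)


def logprior(theta):
    # Gaussian prior theta ~ N(0, 100 I), up to a constant
    return -0.5 * jnp.dot(theta, theta) * 0.01


# per-example log-likelihood gradients, vectorised over the data axis: shape (n, D)
per_example_grad = jax.vmap(jax.grad(loglikelihood), in_axes=(None, 0))


def full_grad_logpost(theta, X):
    """Exact log-posterior gradient using every data point (for comparison)."""
    return jax.grad(logprior)(theta) + jnp.sum(per_example_grad(theta, X), axis=0)


def cv_grad_estimate(theta, x_batch, N, theta_hat, full_lik_grad_hat):
    """Control-variate estimate of the log-posterior gradient at theta."""
    n = x_batch.shape[0]
    # minibatch estimate of the difference between gradients at theta and theta_hat
    diff = per_example_grad(theta, x_batch) - per_example_grad(theta_hat, x_batch)
    return jax.grad(logprior)(theta) + full_lik_grad_hat + (N / n) * jnp.sum(diff, axis=0)


key = random.PRNGKey(1)
data_key, batch_key = random.split(key)
X = random.normal(data_key, shape=(2_000, 3))
N = X.shape[0]

# centring value: for this conjugate model we use the exact posterior mode
theta_hat = jnp.sum(X, axis=0) / (N + 0.01)
# full-data log-likelihood gradient at theta_hat, computed once and reused
full_lik_grad_hat = jnp.sum(per_example_grad(theta_hat, X), axis=0)

theta = theta_hat + 0.1
idx = random.choice(batch_key, N, shape=(20,), replace=False)
g_cv = cv_grad_estimate(theta, X[idx], N, theta_hat, full_lik_grad_hat)
g_exact = full_grad_logpost(theta, X)
```

Because the per-example gradient xᵢ − θ is linear in θ for this model, the correction term θ̂ − θ does not depend on which points are in the minibatch, so here the control-variate estimate matches the exact gradient. For non-Gaussian models the estimate stays unbiased and typically has lower variance than plain minibatching when θ is close to θ̂.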


## Installation

Create a virtual environment, then either install the stable version with pip or install the development version from source.

### Stable version

To install the latest stable version, run:

```
pip install sgmcmcjax
```

### Development version

To install the development version, run:

```
git clone https://github.com/jeremiecoullon/SGMCMCJax.git
cd SGMCMCJax
python -m pip install -e .
```
Then install the development dependencies and run the tests with `pip install -r requirements-dev.txt; make`.

To run code style checks: `make lint`

## Citing SGMCMCJax

Please use the following BibTeX reference to cite this repository:

```
@article{Coullon2022,
  doi = {10.21105/joss.04113},
  url = {https://doi.org/10.21105/joss.04113},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {72},
  pages = {4113},
  author = {Jeremie Coullon and Christopher Nemeth},
  title = {SGMCMCJax: a lightweight JAX library for stochastic gradient Markov chain Monte Carlo algorithms},
  journal = {Journal of Open Source Software}
}
```



