sbijax


- Name: sbijax
- Version: 0.3.0
- Summary: Simulation-based inference in JAX
- Upload time: 2024-08-18 12:47:36
- Requires Python: >=3.9
- Keywords: abc, approximate Bayesian computation, sbi, simulation-based inference

# sbijax <img src="https://raw.githubusercontent.com/dirmeier/sbijax/main/docs/_static/sticker.png" align="right" width="160px"/>

[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml)
[![codecov](https://codecov.io/gh/dirmeier/sbijax/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/dirmeier/sbijax)
[![documentation](https://readthedocs.org/projects/sbijax/badge/?version=latest)](https://sbijax.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/)

> Simulation-based inference in JAX

## About

`Sbijax` is a Python library for neural simulation-based inference and
approximate Bayesian computation using [JAX](https://github.com/google/jax).
It implements recent methods such as *Simulated-annealing ABC*,
*Surjective Neural Likelihood Estimation*, *Neural Approximate Sufficient Statistics*
and *Consistency model posterior estimation*, as well as methods for computing model
diagnostics and visualizing posterior distributions.
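To illustrate the basic idea behind approximate Bayesian computation, here is a minimal sketch of plain rejection ABC in pure Python. It is not how `sbijax` implements its ABC methods (simulated-annealing ABC, for instance, is considerably more elaborate). The toy model (a standard normal prior on a two-dimensional parameter and additive Gaussian noise with scale 0.1), the Euclidean distance and the tolerance `epsilon` are assumptions chosen for illustration, and `rejection_abc` is a hypothetical helper, not part of the library.

```python
import math
import random


def rejection_abc(y_obs, n_draws=20_000, epsilon=0.2, noise_scale=0.1, seed=0):
    """Plain rejection ABC for a hypothetical toy model (not the sbijax API).

    Prior:     theta_i ~ Normal(0, 1), independently for each coordinate.
    Simulator: y_i = theta_i + Normal(0, noise_scale) noise.
    A prior draw is accepted if its simulated data lie within Euclidean
    distance `epsilon` of the observed data `y_obs`.
    """
    rng = random.Random(seed)
    accepted = []
    for _ in range(n_draws):
        theta = [rng.gauss(0.0, 1.0) for _ in y_obs]              # sample from the prior
        y_sim = [t + rng.gauss(0.0, noise_scale) for t in theta]  # run the simulator
        if math.dist(y_sim, y_obs) < epsilon:                      # keep draws close to the data
            accepted.append(theta)
    return accepted


samples = rejection_abc([-1.0, 1.0])
print(len(samples), "accepted draws")
print("posterior mean estimate:",
      [sum(s[i] for s in samples) / len(samples) for i in range(2)])
```

Shrinking `epsilon` makes the accepted draws a better approximation of the true posterior but lowers the acceptance rate, which is the trade-off more refined ABC schemes try to manage.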

> [!CAUTION]
> ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

## Examples

`Sbijax` implements a slim object-oriented API with functional elements stemming from
JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm.
For example, you can define a neural likelihood estimation method and generate posterior samples like this:

```python
from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

# prior: a two-dimensional parameter "theta" with a standard normal prior
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior

# simulator: observations are theta plus Gaussian noise with scale 0.1
def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


# neural likelihood estimation with a masked autoregressive flow
fns = prior_fn, simulator_fn
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# simulate training data, fit the flow, then draw posterior samples given y_observed
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
```
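Because this toy model is conjugate (a zero-mean standard normal prior and additive Gaussian noise with scale 0.1, independently in each coordinate), its exact posterior is available in closed form: the posterior precision is the sum of the prior and noise precisions, and the posterior mean is the posterior variance times the noise precision times `y`. That makes it a convenient sanity check for samples returned by `sample_posterior`. The sketch below is plain Python; `gaussian_posterior` is a hypothetical helper and not part of `sbijax`.

```python
import math


def gaussian_posterior(y, prior_scale=1.0, noise_scale=0.1):
    """Exact posterior for theta ~ N(0, prior_scale^2), y | theta ~ N(theta, noise_scale^2).

    Applied coordinate-wise; returns (list of posterior means, posterior standard deviation).
    """
    prior_precision = 1.0 / prior_scale**2
    noise_precision = 1.0 / noise_scale**2
    post_var = 1.0 / (prior_precision + noise_precision)        # precisions add
    post_mean = [post_var * noise_precision * yi for yi in y]   # prior mean is zero
    return post_mean, math.sqrt(post_var)


mean, std = gaussian_posterior([-1.0, 1.0])
print(mean, std)  # roughly [-0.990, 0.990] and 0.0995
```

Comparing the sample mean and standard deviation of the posterior draws for `theta` against these values gives a quick check that the fitted estimator behaves sensibly; how the draws are stored in the returned object is described in the documentation.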

More self-contained examples can be found in [examples](https://github.com/dirmeier/sbijax/tree/main/examples).

## Documentation

Documentation can be found [here](https://sbijax.readthedocs.io/en/latest/).

## Installation

Make sure you have a working `JAX` installation. Depending on whether you want to use CPU, GPU or TPU,
please follow [these instructions](https://github.com/google/jax#installation).

To install from PyPI, just call the following on the command line:

```bash
pip install sbijax
```
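After installing, one way to confirm which versions of JAX and `sbijax` are present in the current environment is the standard library's `importlib.metadata`. This is a minimal sketch; `installed_version` is a hypothetical helper, and the checked names assume the PyPI distributions `jax`, `jaxlib` and `sbijax`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


for name in ("jax", "jaxlib", "sbijax"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

To see which backend JAX will use, `jax.devices()` lists the devices it has detected.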

To install a release directly from GitHub, replace `<RELEASE>` with the desired release tag:

```bash
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
```

## Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
[good first issue](https://github.com/dirmeier/sbijax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).

In order to contribute:

1) Clone `sbijax` and install `hatch` via `pip install hatch`,
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) implement your contribution and ideally a test case,
4) test, lint and format it by running `make tests`, `make lints` and `make format` on the (Unix) command line,
5) submit a PR 🙂

## Acknowledgements

> [!NOTE]
> 📝 The API of the package is heavily inspired by the excellent PyTorch-based [`sbi`](https://github.com/sbi-dev/sbi) package.

## Author

Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "sbijax",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": null,
    "keywords": "abc, approximate Bayesian computation, sbi, simulation-based inference",
    "author": null,
    "author_email": "Simon Dirmeier <sfyrbnd@pm.me>",
    "download_url": "https://files.pythonhosted.org/packages/4c/e5/b2212322c1d82bd8f31102d1462a4d0e5155a073e202d0b9060d94193cd6/sbijax-0.3.0.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "Simulation-based inference in JAX",
    "version": "0.3.0",
    "project_urls": {
        "Documentation": "https://sbijax.readthedocs.io",
        "Homepage": "https://github.com/dirmeier/sbijax"
    },
    "split_keywords": [
        "abc",
        " approximate bayesian computation",
        " sbi",
        " simulation-based inference"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "c10cf0cb57100bb3511cc83abd88f42566270720bda7bfd2f7a981257edaa68a",
                "md5": "041ac3ec0894459abd2d13dc90d067c6",
                "sha256": "6561755757f1ab67c318f161f063ae45e704de258b75967c104378baba4a98c1"
            },
            "downloads": -1,
            "filename": "sbijax-0.3.0-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "041ac3ec0894459abd2d13dc90d067c6",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 64254,
            "upload_time": "2024-08-18T12:47:34",
            "upload_time_iso_8601": "2024-08-18T12:47:34.299930Z",
            "url": "https://files.pythonhosted.org/packages/c1/0c/f0cb57100bb3511cc83abd88f42566270720bda7bfd2f7a981257edaa68a/sbijax-0.3.0-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "4ce5b2212322c1d82bd8f31102d1462a4d0e5155a073e202d0b9060d94193cd6",
                "md5": "71b58d31661b61ca0a3eac7d5dd48444",
                "sha256": "3bc95817cf1966c5e78b813d75e61c3ebea4872f6056b9cef9f0405233d37724"
            },
            "downloads": -1,
            "filename": "sbijax-0.3.0.tar.gz",
            "has_sig": false,
            "md5_digest": "71b58d31661b61ca0a3eac7d5dd48444",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 20162446,
            "upload_time": "2024-08-18T12:47:36",
            "upload_time_iso_8601": "2024-08-18T12:47:36.887940Z",
            "url": "https://files.pythonhosted.org/packages/4c/e5/b2212322c1d82bd8f31102d1462a4d0e5155a073e202d0b9060d94193cd6/sbijax-0.3.0.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-08-18 12:47:36",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "dirmeier",
    "github_project": "sbijax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "sbijax"
}
        