# sbijax <img src="https://raw.githubusercontent.com/dirmeier/sbijax/main/docs/_static/sticker.png" align="right" width="160px"/>
[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml)
[![codecov](https://codecov.io/gh/dirmeier/sbijax/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/dirmeier/sbijax)
[![documentation](https://readthedocs.org/projects/sbijax/badge/?version=latest)](https://sbijax.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/)
> Simulation-based inference in JAX
## About
``Sbijax`` is a Python library for neural simulation-based inference and
approximate Bayesian computation using [JAX](https://github.com/google/jax).
It implements recent methods, such as *Simulated-annealing ABC*,
*Surjective Neural Likelihood Estimation*, *Neural Approximate Sufficient Statistics*
or *Consistency model posterior estimation*, as well as methods for computing model
diagnostics and visualizing posterior distributions.
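For readers new to approximate Bayesian computation, the core idea can be sketched with plain rejection ABC, the simplest variant: draw parameters from the prior, simulate data, and keep only the parameters whose simulations land within a tolerance `eps` of the observation. This is a minimal illustration in JAX, not sbijax's API and not its simulated-annealing algorithm; `rejection_abc`, `n_sims` and `eps` are hypothetical names, and the toy model (a 2D standard-normal prior with additive Gaussian noise of standard deviation 0.1) is assumed to match the example in the next section.

```python
import jax
import jax.numpy as jnp


def rejection_abc(key, y_obs, n_sims=200_000, eps=0.1):
    """Plain rejection ABC for a hypothetical 2D toy model (not sbijax's API)."""
    key_prior, key_sim = jax.random.split(key)
    # draw parameters from the prior theta ~ N(0, I)
    theta = jax.random.normal(key_prior, (n_sims, 2))
    # simulate data y = theta + noise, noise ~ N(0, 0.1^2 I)
    y_sim = theta + 0.1 * jax.random.normal(key_sim, (n_sims, 2))
    # accept parameters whose simulated data lie within eps of the observation
    dist = jnp.linalg.norm(y_sim - y_obs, axis=-1)
    return theta[dist < eps]  # boolean masking works outside of jax.jit


samples = rejection_abc(jax.random.PRNGKey(0), jnp.array([-1.0, 1.0]))
print(samples.shape, samples.mean(axis=0))
```

Only a small fraction of simulations is accepted here, which is why more refined schemes (sequential or annealed ABC, or neural estimators) are used in practice.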
> [!CAUTION]
> ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.
## Examples
`Sbijax` implements a slim object-oriented API with functional elements stemming from
JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm.
For example, you can define a neural likelihood estimation method and generate posterior samples like this:
```python
from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd


# prior: theta ~ N(0, I) in two dimensions
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior


# simulator: y = theta + Gaussian noise with standard deviation 0.1
def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


fns = prior_fn, simulator_fn
# neural likelihood estimation with a masked autoregressive flow as density estimator
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# simulate training data, fit the estimator, then sample the posterior given y_observed
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
```
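Because the toy model above is conjugate (Gaussian prior, Gaussian noise), its posterior is available in closed form, which gives a reference to compare the NLE samples against. With prior variance 1 and noise variance 0.1² = 0.01 per dimension, the posterior precision is 1 + 100 = 101 and the posterior mean is 100·y/101. A minimal sketch, assuming the prior and noise scales from the example; `gaussian_posterior` is a hypothetical helper, not part of sbijax:

```python
import jax.numpy as jnp


def gaussian_posterior(y_obs, prior_sd=1.0, noise_sd=0.1):
    """Exact posterior of theta for y = theta + noise with a N(0, prior_sd^2) prior."""
    prior_prec = 1.0 / prior_sd**2
    lik_prec = 1.0 / noise_sd**2
    post_var = 1.0 / (prior_prec + lik_prec)
    post_mean = post_var * lik_prec * y_obs
    return post_mean, jnp.sqrt(post_var)


mean, sd = gaussian_posterior(jnp.array([-1.0, 1.0]))
print(mean, sd)
```

For `y_observed = [-1.0, 1.0]` this gives a posterior mean of about ±0.990 and a standard deviation of about 0.0995 in each dimension, so well-fitted NLE samples for `theta` should concentrate around these values.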
More self-contained examples can be found in [examples](https://github.com/dirmeier/sbijax/tree/main/examples).
## Documentation
Documentation can be found [here](https://sbijax.readthedocs.io/en/latest/).
## Installation
Make sure to have a working `JAX` installation. Depending on whether you want to use CPU, GPU or TPU,
please follow [these instructions](https://github.com/google/jax#installation).
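Once JAX is installed, a quick way to confirm that it imports and to see which backend it picked up is shown below. `jax.default_backend()` and `jax.devices()` are standard JAX functions; the small dot product merely exercises the installation.

```python
import jax
import jax.numpy as jnp

# backend JAX selected by default, e.g. "cpu", "gpu" or "tpu"
print(jax.default_backend())
# list of devices visible to JAX
print(jax.devices())

# tiny computation to check that arrays can be created and used
x = jnp.arange(4.0)
print(jnp.dot(x, x))
```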
To install from PyPI, just call the following on the command line:
```bash
pip install sbijax
```
To install a release directly from GitHub, replace `<RELEASE>` with the desired release tag (for example, the latest one) and run:
```bash
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
```
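To confirm which version ended up in the environment (for example, after installing a specific release), the standard-library `importlib.metadata` module can be queried. `installed_version` is a hypothetical helper name used for illustration; it returns `None` if the package is not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="sbijax"):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version())
```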
## Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
[good first issue](https://github.com/dirmeier/sbijax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
In order to contribute:
1) Clone `sbijax` and install `hatch` via `pip install hatch`,
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) implement your contribution and ideally a test case,
4) run the tests, linters and formatter by calling `make tests`, `make lints` and `make format` on the (Unix) command line,
5) submit a PR 🙂
## Acknowledgements
> [!NOTE]
> 📝 The API of the package is heavily inspired by the excellent PyTorch-based [`sbi`](https://github.com/sbi-dev/sbi) package.
## Author
Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>
Raw data
{
"_id": null,
"home_page": null,
"name": "sbijax",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "abc, approximate Bayesian computation, sbi, simulation-based inference",
"author": null,
"author_email": "Simon Dirmeier <sfyrbnd@pm.me>",
"download_url": "https://files.pythonhosted.org/packages/77/f1/4e2fe6095a7f8026d6587a30279ac9980a0db685192c132c7fa4396d69d8/sbijax-0.3.3.post1.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": null,
"summary": "Simulation-based inference in JAX",
"version": "0.3.3.post1",
"project_urls": {
"Documentation": "https://sbijax.readthedocs.io",
"Homepage": "https://github.com/dirmeier/sbijax"
},
"split_keywords": [
"abc",
" approximate bayesian computation",
" sbi",
" simulation-based inference"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "3576e9a20d2ae91be78fcdce090169feed8138fe7c7cf81b56f7144c64a7925f",
"md5": "21949a902f5e9df93ba3699a3baf0402",
"sha256": "eaaf8c43b0d4859c045917ae2ca032f85de5f6ea9567fbfb69b0b7f8f42f41b7"
},
"downloads": -1,
"filename": "sbijax-0.3.3.post1-py3-none-any.whl",
"has_sig": false,
"md5_digest": "21949a902f5e9df93ba3699a3baf0402",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 64913,
"upload_time": "2025-01-15T09:34:20",
"upload_time_iso_8601": "2025-01-15T09:34:20.192336Z",
"url": "https://files.pythonhosted.org/packages/35/76/e9a20d2ae91be78fcdce090169feed8138fe7c7cf81b56f7144c64a7925f/sbijax-0.3.3.post1-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "77f14e2fe6095a7f8026d6587a30279ac9980a0db685192c132c7fa4396d69d8",
"md5": "d6f17ac449104845c02f757dca48932e",
"sha256": "4dd6987aa93deb7b7df49c96f3e317f2c30f5c4cb6095a4adc9699fc64fa48d2"
},
"downloads": -1,
"filename": "sbijax-0.3.3.post1.tar.gz",
"has_sig": false,
"md5_digest": "d6f17ac449104845c02f757dca48932e",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 39500,
"upload_time": "2025-01-15T09:34:22",
"upload_time_iso_8601": "2025-01-15T09:34:22.137248Z",
"url": "https://files.pythonhosted.org/packages/77/f1/4e2fe6095a7f8026d6587a30279ac9980a0db685192c132c7fa4396d69d8/sbijax-0.3.3.post1.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-01-15 09:34:22",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "dirmeier",
"github_project": "sbijax",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"lcname": "sbijax"
}