sbijax

- Name: sbijax
- Version: 0.3.4
- Summary: Simulation-based inference in JAX
- Upload time: 2025-02-02 22:28:52
- Requires Python: >=3.9
- Keywords: abc, approximate Bayesian computation, sbi, simulation-based inference
- License: not recorded
- Requirements: none recorded
# sbijax <img src="https://raw.githubusercontent.com/dirmeier/sbijax/main/docs/_static/sticker.png" align="right" width="160px"/>

[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/sbijax/actions/workflows/ci.yaml)
[![codecov](https://codecov.io/gh/dirmeier/sbijax/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/dirmeier/sbijax)
[![documentation](https://readthedocs.org/projects/sbijax/badge/?version=latest)](https://sbijax.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/sbijax.svg?colorB=black&style=flat)](https://pypi.org/project/sbijax/)

> Simulation-based inference in JAX

## About

``Sbijax`` is a Python library for neural simulation-based inference and
approximate Bayesian computation using [JAX](https://github.com/google/jax).
It implements recent methods, such as *Simulated-annealing ABC*,
*Surjective Neural Likelihood Estimation*, *Neural Approximate Sufficient Statistics*
or *Consistency model posterior estimation*, as well as methods for computing model
diagnostics and visualizing posterior distributions.
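For readers new to approximate Bayesian computation, a minimal sketch of plain rejection ABC may help: draw parameters from the prior, simulate data, and keep only the draws whose simulated data land within a distance `epsilon` of the observation. This is not one of the sbijax algorithms (its ABC methods, such as simulated-annealing ABC, are more sophisticated); it uses only the Python standard library, mirrors the toy prior and simulator from the example in the Examples section, and the names `rejection_abc`, `epsilon` and `n_simulations` as well as their values are hypothetical choices made for illustration.

```python
import math
import random


def simulate(theta, rng, noise_scale=0.1):
    # Toy simulator: y = theta + Normal(0, noise_scale) noise in each dimension
    return [t + rng.gauss(0.0, noise_scale) for t in theta]


def rejection_abc(y_observed, n_simulations=100_000, epsilon=0.2, seed=0):
    """Keep prior draws whose simulated data fall within epsilon of y_observed."""
    rng = random.Random(seed)
    accepted = []
    for _ in range(n_simulations):
        # Draw theta from an independent standard-normal prior
        theta = [rng.gauss(0.0, 1.0) for _ in y_observed]
        y = simulate(theta, rng)
        # Accept if the Euclidean distance to the observation is below epsilon
        if math.dist(y, y_observed) < epsilon:
            accepted.append(theta)
    return accepted


samples = rejection_abc([-1.0, 1.0])
mean = [sum(s[i] for s in samples) / len(samples) for i in range(2)]
print(len(samples), mean)
```

Shrinking `epsilon` makes the accepted draws a closer approximation of the true posterior but lowers the acceptance rate, which is the main cost that more advanced ABC and neural methods try to reduce.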

> [!CAUTION]
> ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

## Examples

`Sbijax` implements a slim object-oriented API with functional elements stemming from
JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm.
For example, you can define a neural likelihood estimation method and generate posterior samples like this:

```python
from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

# Prior: a two-dimensional standard-normal parameter named "theta"
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior

# Simulator: observations are theta plus Gaussian noise with scale 0.1
def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


fns = prior_fn, simulator_fn
# Neural likelihood estimation with a masked autoregressive flow over the 2-dimensional data
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# Simulate training data, fit the likelihood model, then sample the posterior given y_observed
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
```
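Because the toy model above is conjugate (independent standard-normal priors and Gaussian noise with scale 0.1), its exact posterior is available in closed form and can serve as a sanity check for the NLE samples. Per dimension, the posterior precision is 1/1² + 1/0.1² = 101, so the posterior mean is 100·y/101 and the posterior standard deviation is 1/√101 ≈ 0.0995. The sketch below uses only the Python standard library; `gaussian_posterior` is a hypothetical helper, not part of sbijax.

```python
import math


def gaussian_posterior(y_observed, prior_scale=1.0, noise_scale=0.1):
    """Closed-form posterior for theta ~ N(0, prior_scale^2), y | theta ~ N(theta, noise_scale^2)."""
    prior_precision = 1.0 / prior_scale**2
    noise_precision = 1.0 / noise_scale**2
    # Precisions add; with a zero prior mean only the data term contributes to the mean
    post_precision = prior_precision + noise_precision
    post_mean = [noise_precision * y / post_precision for y in y_observed]
    post_sd = 1.0 / math.sqrt(post_precision)
    return post_mean, post_sd


post_mean, post_sd = gaussian_posterior([-1.0, 1.0])
print(post_mean, post_sd)
```

Up to Monte Carlo and approximation error, the posterior draws for `theta` returned by sbijax should concentrate around roughly (-0.99, 0.99) with a spread of about 0.1 per coordinate.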

More self-contained examples can be found in [examples](https://github.com/dirmeier/sbijax/tree/main/examples).

## Documentation

Documentation can be found [here](https://sbijax.readthedocs.io/en/latest/).

## Installation

Make sure you have a working `JAX` installation. Depending on whether you want to use CPU/GPU/TPU,
please follow [these instructions](https://github.com/google/jax#installation).
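Once JAX is installed, `jax.default_backend()` reports which backend it will use by default. The helper below is a small hypothetical convenience function (not part of sbijax or JAX) that returns `None` instead of raising when JAX is not installed.

```python
import importlib.util


def jax_backend():
    """Return JAX's default backend name, or None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    # Typically "cpu", "gpu" or "tpu", depending on the installed jaxlib build
    return jax.default_backend()


print(jax_backend())
```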

To install from PyPI, just call the following on the command line:

```bash
pip install sbijax
```

To install a release directly from GitHub, replace `<RELEASE>` with its tag and run:

```bash
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
```

## Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
[good first issue](https://github.com/dirmeier/sbijax/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).

In order to contribute:

1) Clone `sbijax` and install `hatch` via `pip install hatch`,
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) implement your contribution and ideally a test case,
4) test it by calling `make tests`, `make lints` and `make format` on the (Unix) command line,
5) submit a PR 🙂

## Citing sbijax

If you find our work relevant to your research, please consider citing:

```
@article{dirmeier2024simulation,
  title={Simulation-based inference with the Python Package sbijax},
  author={Dirmeier, Simon and Ulzega, Simone and Mira, Antonietta and Albert, Carlo},
  journal={arXiv preprint arXiv:2409.19435},
  year={2024}
}
```

## Acknowledgements

> [!NOTE]
> 📝 The API of the package is heavily inspired by the excellent PyTorch-based [`sbi`](https://github.com/sbi-dev/sbi) package.

## Author

Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "sbijax",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": null,
    "keywords": "abc, approximate Bayesian computation, sbi, simulation-based inference",
    "author": null,
    "author_email": "Simon Dirmeier <sfyrbnd@pm.me>",
    "download_url": "https://files.pythonhosted.org/packages/2b/53/30e7f61c1151f3dea721a9b96ceb54bb1fcefb4902cbb0629f0ebe2d42ae/sbijax-0.3.4.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "Simulation-based inference in JAX",
    "version": "0.3.4",
    "project_urls": {
        "Documentation": "https://sbijax.readthedocs.io",
        "Homepage": "https://github.com/dirmeier/sbijax"
    },
    "split_keywords": [
        "abc",
        " approximate bayesian computation",
        " sbi",
        " simulation-based inference"
    ],
    "urls": [
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "8662cb4212f728859528b9913a1b177d2fc0bc4dfb7f01f6e8a44b019b857d70",
                "md5": "ebb3c8a8f941380bce2c593cf2216ae5",
                "sha256": "58751377fd9ee9f55928e0a6569375bf329adf3a244fc184ef521eeb6dfb9981"
            },
            "downloads": -1,
            "filename": "sbijax-0.3.4-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "ebb3c8a8f941380bce2c593cf2216ae5",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 81994,
            "upload_time": "2025-02-02T22:28:51",
            "upload_time_iso_8601": "2025-02-02T22:28:51.162874Z",
            "url": "https://files.pythonhosted.org/packages/86/62/cb4212f728859528b9913a1b177d2fc0bc4dfb7f01f6e8a44b019b857d70/sbijax-0.3.4-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "2b5330e7f61c1151f3dea721a9b96ceb54bb1fcefb4902cbb0629f0ebe2d42ae",
                "md5": "3af9d19735abed80d3b5e4364e83d570",
                "sha256": "fb2ac0d25ade8783fed1ccb6916ab336ecca9b159056b5ef7dcfa7c9b865765f"
            },
            "downloads": -1,
            "filename": "sbijax-0.3.4.tar.gz",
            "has_sig": false,
            "md5_digest": "3af9d19735abed80d3b5e4364e83d570",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 47551,
            "upload_time": "2025-02-02T22:28:52",
            "upload_time_iso_8601": "2025-02-02T22:28:52.744545Z",
            "url": "https://files.pythonhosted.org/packages/2b/53/30e7f61c1151f3dea721a9b96ceb54bb1fcefb4902cbb0629f0ebe2d42ae/sbijax-0.3.4.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-02-02 22:28:52",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "dirmeier",
    "github_project": "sbijax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "sbijax"
}