# surjectors
[Project status: active](https://www.repostatus.org/#active)
[CI](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml)
[PyPI](https://pypi.org/project/surjectors/)
[DOI: 10.21105/joss.06188](https://doi.org/10.21105/joss.06188)
> Surjection layers for density estimation with normalizing flows
## About
Surjectors is a lightweight library for density estimation using
inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality.
Surjectors makes use of
- [Haiku](https://github.com/deepmind/dm-haiku)'s module system for neural networks,
- [Distrax](https://github.com/deepmind/distrax) for probability distributions and some base bijectors,
- [Optax](https://github.com/deepmind/optax) for gradient-based optimization,
- [JAX](https://github.com/google/jax) for autodiff and XLA computation.
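To make the idea of a dimension-reducing flow concrete, here is a minimal sketch in plain JAX of how a slice surjection can score a data point. It assumes the first `n_keep` coordinates of `x` are kept as the latent `z` and the sliced-off coordinates `x_rest` are modeled by a conditional density, so that `log p(x) = log p_Z(z) + log p(x_rest | z)`. The helper `slice_log_prob`, the standard-normal base and `toy_decoder` are hypothetical illustrations, not the surjectors API; in the library this role is played by layers such as `Slice` together with a learned decoder.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def slice_log_prob(x, n_keep, base_log_prob, decoder):
    """Log-density of x under a slice surjection (hypothetical helper).

    The first `n_keep` coordinates are kept as the latent z; the remaining
    coordinates are scored by a conditional Gaussian decoder(z) -> (mean, scale).
    """
    z, x_rest = x[..., :n_keep], x[..., n_keep:]
    mean, scale = decoder(z)
    # likelihood contribution of the coordinates that were sliced off
    log_rest = norm.logpdf(x_rest, mean, scale).sum(-1)
    return base_log_prob(z) + log_rest


def standard_normal_log_prob(z):
    # base density p(z): independent standard normals
    return norm.logpdf(z).sum(-1)


def toy_decoder(z):
    # hypothetical decoder; assumes x_rest has the same size as z
    return 0.5 * z, jnp.ones_like(z)


x = jnp.array([1.0, 2.0, 3.0, 4.0])  # 4-dimensional data point
lp = slice_log_prob(x, 2, standard_normal_log_prob, toy_decoder)  # 2-dimensional latent
```

The data point lives in four dimensions while the latent has only two; the decoder term is what keeps the density of the dropped coordinates in the likelihood.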
## Examples
You can, for instance, construct a simple normalizing flow like this:
```python
import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp


def decoder_fn(n_dim):
    # conditional distribution of the dimensions removed by Slice,
    # parameterized by an MLP on the kept dimensions z
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
    return _fn


@hk.without_apply_rng
@hk.transform
def flow(x):
    # 5-dimensional standard normal base distribution
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    # Slice reduces the 10-dimensional data to 5 dimensions; LULinear acts on those 5
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)


# one 10-dimensional observation
x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)  # log-density of x under the flow
```
More self-contained examples can be found in [examples](https://github.com/dirmeier/surjectors/tree/main/examples).
## Documentation
Documentation can be found [here](https://surjectors.readthedocs.io/en/latest/).
## Installation
Make sure to have a working `JAX` installation. Depending on whether you want to use CPU/GPU/TPU,
please follow [these instructions](https://github.com/google/jax#installation).
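A quick way to check that JAX is installed and sees the intended hardware is to ask it for its default backend and visible devices. This is a minimal sketch; `jax_backend_summary` is just an illustrative helper name.

```python
import jax


def jax_backend_summary():
    """Return JAX's default backend name and the devices it can see."""
    return jax.default_backend(), jax.devices()


backend, devices = jax_backend_summary()
# e.g. "cpu" for a CPU-only install; an accelerator install should list its GPU/TPU devices
print(f"backend: {backend}, devices: {devices}")
```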
To install the package from PyPI, call:
```bash
pip install surjectors
```
To install a specific release from GitHub, replace `<RELEASE>` with the release tag and run the following on the command line:
```bash
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
```
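After installing, you can confirm which version of surjectors Python picks up with the standard library's `importlib.metadata`. This is a small sketch; the helper `installed_version` is hypothetical and returns `None` when the package is not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package="surjectors"):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print(installed_version())
```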
## Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
[good first issue](https://github.com/dirmeier/surjectors/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
In order to contribute:
1) Clone `Surjectors` and install `hatch` via `pip install hatch`,
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) implement your contribution and ideally a test case,
4) test it by calling `hatch run test` on the (Unix) command line,
5) submit a PR 🙂
## Citing Surjectors
If you find our work relevant to your research, please consider citing:
```bibtex
@article{dirmeier2024surjectors,
  author = {Simon Dirmeier},
  title = {Surjectors: surjection layers for density estimation with normalizing flows},
  year = {2024},
  journal = {Journal of Open Source Software},
  publisher = {The Open Journal},
  volume = {9},
  number = {94},
  pages = {6188},
  doi = {10.21105/joss.06188}
}
```
## Author
Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>
Raw data
{
"_id": null,
"home_page": null,
"name": "surjectors",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "density estimation, normalizing flows, surjections",
"author": null,
"author_email": "Simon Dirmeier <sfyrbnd@pm.me>",
"download_url": "https://files.pythonhosted.org/packages/77/af/5a8196bb718fc244c95551a39be0764c2e655f6721e58c64bd8d171fade8/surjectors-0.3.3.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": null,
"summary": "Surjection layers for density estimation with normalizing flows",
"version": "0.3.3",
"project_urls": {
"homepage": "https://github.com/dirmeier/surjectors"
},
"split_keywords": [
"density estimation",
" normalizing flows",
" surjections"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "8296b50bc6e5b0fcd611d768c553437edb43965c293ad2ee75478d75ac40681e",
"md5": "2c53fdfe573cc3cd3586833832b6562c",
"sha256": "844d60b46a82e23b410cd84c66928f0a59916ba9cd6f93e3effbd1d6444dfd52"
},
"downloads": -1,
"filename": "surjectors-0.3.3-py3-none-any.whl",
"has_sig": false,
"md5_digest": "2c53fdfe573cc3cd3586833832b6562c",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 49338,
"upload_time": "2024-08-17T21:35:49",
"upload_time_iso_8601": "2024-08-17T21:35:49.139504Z",
"url": "https://files.pythonhosted.org/packages/82/96/b50bc6e5b0fcd611d768c553437edb43965c293ad2ee75478d75ac40681e/surjectors-0.3.3-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "77af5a8196bb718fc244c95551a39be0764c2e655f6721e58c64bd8d171fade8",
"md5": "72ae99ec96d7678c34e3ef79024cceee",
"sha256": "21e7251328baf5be5c1f7eb5c5aefb0aa965877235c6bc4a76570d77a9bf3f1f"
},
"downloads": -1,
"filename": "surjectors-0.3.3.tar.gz",
"has_sig": false,
"md5_digest": "72ae99ec96d7678c34e3ef79024cceee",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 188977,
"upload_time": "2024-08-17T21:35:50",
"upload_time_iso_8601": "2024-08-17T21:35:50.970781Z",
"url": "https://files.pythonhosted.org/packages/77/af/5a8196bb718fc244c95551a39be0764c2e655f6721e58c64bd8d171fade8/surjectors-0.3.3.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-08-17 21:35:50",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "dirmeier",
"github_project": "surjectors",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"lcname": "surjectors"
}