surjectors


Name: surjectors
Version: 0.3.3
home_page: None
Summary: Surjection layers for density estimation with normalizing flows
upload_time: 2024-08-17 21:35:50
maintainer: None
docs_url: None
author: None
requires_python: >=3.9
license: None
keywords: density estimation, normalizing flows, surjections
requirements: No requirements were recorded.
Travis-CI: No Travis.
coveralls test coverage: No coveralls.
# surjectors

[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml)
[![version](https://img.shields.io/pypi/v/surjectors.svg?colorB=black&style=flat)](https://pypi.org/project/surjectors/)
[![doi](https://joss.theoj.org/papers/10.21105/joss.06188/status.svg)](https://doi.org/10.21105/joss.06188)

> Surjection layers for density estimation with normalizing flows

## About

Surjectors is a lightweight library for density estimation using
inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality.
Surjectors makes use of

- [Haiku](https://github.com/deepmind/dm-haiku)'s module system for neural networks,
- [Distrax](https://github.com/deepmind/distrax) for probability distributions and some base bijectors,
- [Optax](https://github.com/deepmind/optax) for gradient-based optimization,
- [JAX](https://github.com/google/jax) for autodiff and XLA computation.
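To make the dimensionality-reducing case concrete, here is a plain-Python sketch of how a *slice* inference surjection can score a data point. In this sketch the first `n_keep` coordinates become the latent `z` and the dropped coordinates are scored by a conditional density, so that `log p(x) = log p_Z(z) + log p(x_dropped | z)`. It is an illustration under stated assumptions, not the library's implementation: the helper names are hypothetical, the base density is a standard normal, and `toy_decoder` is a fixed stand-in for a learned decoder network.

```python
import math


def normal_log_pdf(x, mean, scale):
    """Log-density of a univariate Gaussian N(mean, scale**2) at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(scale) - 0.5 * ((x - mean) / scale) ** 2


def toy_decoder(z):
    """Hypothetical conditional density p(x_dropped | z).

    Stand-in for a learned network: every dropped coordinate gets
    mean = average of z and unit scale.
    """
    mean = sum(z) / len(z)
    return mean, 1.0


def slice_log_prob(x, n_keep):
    """Log-density of x under a slice surjection with a standard-normal base.

    The first `n_keep` coordinates become the latent z; the remaining
    coordinates are scored by the toy decoder conditioned on z.
    """
    z, dropped = x[:n_keep], x[n_keep:]
    # Base density p_Z(z): independent standard normals.
    log_pz = sum(normal_log_pdf(zi, 0.0, 1.0) for zi in z)
    # Conditional density of the dropped coordinates, p(x_dropped | z).
    mean, scale = toy_decoder(z)
    log_px_given_z = sum(normal_log_pdf(xi, mean, scale) for xi in dropped)
    return log_pz + log_px_given_z
```

In a trained flow the conditional density is a neural network rather than a fixed rule, and its parameters are learned jointly with the rest of the flow.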

## Examples

You can, for instance, construct a simple normalizing flow like this:

```python
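import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp

def decoder_fn(n_dim):
    # Conditional density of the dropped coordinates given the kept ones:
    # an MLP outputs a mean and a log-scale for each of the n_dim dimensions.
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        # Treat the last axis as the event dimension, as for the base distribution.
        return distrax.Independent(
            distrax.Normal(means, jnp.exp(log_scales)), 1
        )
    return _fn

@hk.without_apply_rng
@hk.transform
def flow(x):
    # 5-dimensional standard-normal base distribution
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    # Data-to-latent direction: Slice keeps 5 of the 10 input dimensions
    # (the decoder models the dropped ones), then LULinear acts on those 5.
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)

# One 10-dimensional data point
x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)  # log-density of x under the flow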
```

More self-contained examples can be found in [examples](https://github.com/dirmeier/surjectors/tree/main/examples).
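`LULinear` in the example above is a bijective linear layer. Assuming it follows the common LU parameterization of an invertible matrix, `W = P L U` with `L` unit lower-triangular and `U` upper-triangular (the scheme popularized by Glow), its log-absolute Jacobian determinant is simply the sum of `log|U_ii|`, so no general determinant computation is needed. The plain-Python sketch below uses hypothetical helper names, takes the permutation `P` as the identity, and checks the identity on a 3×3 example; the library's actual parameterization may differ.

```python
import math


def matmul(a, b):
    """Multiply two square matrices given as lists of rows."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def lu_weight(lower, upper):
    """Build W = L @ U (L unit lower-triangular, U upper-triangular; P = identity)."""
    return matmul(lower, upper)


def lu_log_abs_det(upper):
    """log|det W| = sum_i log|U_ii|, since det L = 1 and det U = prod_i U_ii."""
    return sum(math.log(abs(upper[i][i])) for i in range(len(upper)))


def det3(m):
    """Determinant of a 3x3 matrix by cofactor expansion (reference value)."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


def linear_forward(weight, x):
    """Apply the linear layer y = W x to a vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in weight]


# Example: det W = 2 * (-3) * 0.5 = -3, so log|det W| = log 3
LOWER = [[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [-1.0, 2.0, 1.0]]
UPPER = [[2.0, 1.0, 0.0], [0.0, -3.0, 4.0], [0.0, 0.0, 0.5]]
WEIGHT = lu_weight(LOWER, UPPER)
```

The triangular structure is what keeps the layer cheap: evaluating the log-determinant costs O(D) instead of the O(D³) of a general determinant.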

## Documentation

Documentation can be found [here](https://surjectors.readthedocs.io/en/latest/).

## Installation

Make sure to have a working `JAX` installation. Depending on whether you want to use CPU, GPU, or TPU,
please follow [these instructions](https://github.com/google/jax#installation).

To install the package from PyPI, call:

```bash
pip install surjectors
```

To install a specific release from GitHub, replace `<RELEASE>` with the release tag and call the following on the command line:

```bash
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
```
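After installation, one way to confirm which versions ended up in your environment is the standard library's `importlib.metadata`. This is a minimal sketch: the helper name `installed_version` is hypothetical, it returns `None` rather than raising for packages that are not installed, and the names in the loop are the PyPI distribution names of Surjectors and the libraries it builds on (Haiku is published as `dm-haiku`).

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Report the version of each package, or flag it as missing.
    for name in ("surjectors", "jax", "dm-haiku", "distrax", "optax"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```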

## Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
[good first issue](https://github.com/dirmeier/surjectors/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).

In order to contribute:

1) Clone `Surjectors` and install `hatch` via `pip install hatch`,
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) implement your contribution and ideally a test case,
4) test it by calling `hatch run test` on the (Unix) command line,
5) submit a PR 🙂


## Citing Surjectors

If you find our work relevant to your research, please consider citing:

```
@article{dirmeier2024surjectors,
    author = {Simon Dirmeier},
    title = {Surjectors: surjection layers for density estimation with normalizing flows},
    year = {2024},
    journal = {Journal of Open Source Software},
    publisher = {The Open Journal},
    volume = {9},
    number = {94},
    pages = {6188},
    doi = {10.21105/joss.06188}
}
```

## Author

Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>

            

Raw data

{
    "_id": null,
    "home_page": null,
    "name": "surjectors",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": null,
    "keywords": "density estimation, normalizing flows, surjections",
    "author": null,
    "author_email": "Simon Dirmeier <sfyrbnd@pm.me>",
    "download_url": "https://files.pythonhosted.org/packages/77/af/5a8196bb718fc244c95551a39be0764c2e655f6721e58c64bd8d171fade8/surjectors-0.3.3.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "Surjection layers for density estimation with normalizing flows",
    "version": "0.3.3",
    "project_urls": {
        "homepage": "https://github.com/dirmeier/surjectors"
    },
    "split_keywords": [
        "density estimation",
        " normalizing flows",
        " surjections"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "8296b50bc6e5b0fcd611d768c553437edb43965c293ad2ee75478d75ac40681e",
                "md5": "2c53fdfe573cc3cd3586833832b6562c",
                "sha256": "844d60b46a82e23b410cd84c66928f0a59916ba9cd6f93e3effbd1d6444dfd52"
            },
            "downloads": -1,
            "filename": "surjectors-0.3.3-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "2c53fdfe573cc3cd3586833832b6562c",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 49338,
            "upload_time": "2024-08-17T21:35:49",
            "upload_time_iso_8601": "2024-08-17T21:35:49.139504Z",
            "url": "https://files.pythonhosted.org/packages/82/96/b50bc6e5b0fcd611d768c553437edb43965c293ad2ee75478d75ac40681e/surjectors-0.3.3-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "77af5a8196bb718fc244c95551a39be0764c2e655f6721e58c64bd8d171fade8",
                "md5": "72ae99ec96d7678c34e3ef79024cceee",
                "sha256": "21e7251328baf5be5c1f7eb5c5aefb0aa965877235c6bc4a76570d77a9bf3f1f"
            },
            "downloads": -1,
            "filename": "surjectors-0.3.3.tar.gz",
            "has_sig": false,
            "md5_digest": "72ae99ec96d7678c34e3ef79024cceee",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 188977,
            "upload_time": "2024-08-17T21:35:50",
            "upload_time_iso_8601": "2024-08-17T21:35:50.970781Z",
            "url": "https://files.pythonhosted.org/packages/77/af/5a8196bb718fc244c95551a39be0764c2e655f6721e58c64bd8d171fade8/surjectors-0.3.3.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-08-17 21:35:50",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "dirmeier",
    "github_project": "surjectors",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "surjectors"
}
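The `sha256` values under `digests` can be used to check that a downloaded file is intact. Below is a minimal sketch using only the standard library's `hashlib`, assuming the wheel or sdist has already been downloaded to a local path; the helper names are hypothetical.

```python
import hashlib


def sha256_of_chunks(chunks):
    """Hex SHA-256 digest of an iterable of byte strings, hashed incrementally."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def sha256_of_file(path, chunk_size=1 << 16):
    """Hex SHA-256 digest of a file, read in fixed-size chunks to bound memory use."""
    with open(path, "rb") as handle:
        return sha256_of_chunks(iter(lambda: handle.read(chunk_size), b""))


def digest_matches(actual_hex, expected_hex):
    """Compare two hex digests, ignoring case."""
    return actual_hex.lower() == expected_hex.lower()


# sha256 of surjectors-0.3.3-py3-none-any.whl, copied from the "digests" entry above
WHEEL_SHA256 = "844d60b46a82e23b410cd84c66928f0a59916ba9cd6f93e3effbd1d6444dfd52"
```

For an unmodified download of that wheel, `digest_matches(sha256_of_file("surjectors-0.3.3-py3-none-any.whl"), WHEEL_SHA256)` is expected to return `True`.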
        