surjectors

Name: surjectors
Version: 0.3.0
Summary: Surjection layers for density estimation with normalizing flows
Upload time: 2024-02-01 09:45:49
Requires Python: >=3.8
Keywords: density estimation, normalizing flows, surjections
# surjectors

[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml)
[![version](https://img.shields.io/pypi/v/surjectors.svg?colorB=black&style=flat)](https://pypi.org/project/surjectors/)

> Surjection layers for density estimation with normalizing flows

## About

Surjectors is a lightweight library for density estimation using
inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality.
Surjectors builds on Distrax and Haiku and is fully compatible with both of them.
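To make the dimensionality-reducing case concrete, here is a minimal pure-Python sketch of the log-likelihood of a slice surjection in the inference direction. It does not use Surjectors itself. It assumes the input is split into a kept part, scored under a standard normal base distribution, and a sliced-off part, scored under a diagonal Gaussian whose means and log-scales come from a decoder. `make_toy_decoder` and the other function names are hypothetical stand-ins for the MLP decoder in the example further down, and in a full flow the kept part would first pass through further layers.

```python
import math


def normal_log_prob(value, mean, log_scale):
    """Log density of a univariate Gaussian with the given mean and scale exp(log_scale)."""
    z = (value - mean) / math.exp(log_scale)
    return -0.5 * z * z - log_scale - 0.5 * math.log(2.0 * math.pi)


def make_toy_decoder(n_dim):
    """Hypothetical stand-in for decoder_fn(n_dim): returns means and log-scales
    for n_dim sliced-off dimensions. Every mean is the average of the kept
    values and every scale is 1 (log-scale 0)."""
    def _fn(kept):
        shift = sum(kept) / len(kept)
        return [shift] * n_dim, [0.0] * n_dim
    return _fn


def slice_log_prob(x, n_keep, decoder):
    """log p(x) = log p_base(x[:n_keep]) + log p(x[n_keep:] | x[:n_keep])."""
    kept, dropped = x[:n_keep], x[n_keep:]
    # Base density: independent standard normals over the kept dimensions.
    log_base = sum(normal_log_prob(v, 0.0, 0.0) for v in kept)
    # Decoder density over the sliced-off dimensions, conditioned on the kept ones.
    means, log_scales = decoder(kept)
    log_dec = sum(normal_log_prob(v, m, s) for v, m, s in zip(dropped, means, log_scales))
    return log_base + log_dec


x = [0.3, -1.2, 0.8, 0.1, 2.0, -0.4, 0.0, 1.1, -0.7, 0.5]  # 10-dimensional input
print(slice_log_prob(x, n_keep=5, decoder=make_toy_decoder(5)))
```

The likelihood stays exact even though the map from 10 to 5 dimensions is not invertible, because the decoder term accounts for the information that the slice discards.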

Surjectors makes use of

- Haiku's module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.

## Examples

You can, for instance, construct a simple normalizing flow like this:

```python
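import distrax
from jax import numpy as jnp
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp

# make_mlp creates Haiku modules, so this code is meant to run
# inside a function that is transformed with haiku.transform.

def decoder_fn(n_dim):
    # Maps the kept dimensions to a diagonal Gaussian over the
    # n_dim dimensions that the slice removes.
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
    return _fn

# 5-dimensional standard normal base distribution (event shape (5,))
base_distribution = distrax.Independent(distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1)
# Assuming 10-dimensional data: Slice keeps 5 dimensions and models the other
# 5 with decoder_fn(5); LULinear(5) matches the 5-dimensional base distribution.
transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
pushforward = TransformedDistribution(base_distribution, transform)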
```
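Assuming `LULinear` follows the standard LU-decomposed linear layer used in normalizing flows, the following pure-Python sketch (not Surjectors code) shows why that parameterization is convenient. With W = L U, where L is unit lower triangular and U is upper triangular, det L = 1, so log|det W| is just the sum of log|U_ii|, which is cheap to compute for the log-Jacobian term. The helper names are hypothetical, and the fixed permutation matrix that some implementations add is omitted.

```python
import math


def matmul(a, b):
    """Multiply two square matrices given as lists of rows."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def lu_weight(lower_params, upper_params):
    """Build W = L @ U from unconstrained parameters.

    L uses the strictly lower part of lower_params with ones on the diagonal;
    U uses the upper triangle (including the diagonal) of upper_params.
    """
    n = len(lower_params)
    lower = [[1.0 if i == j else (lower_params[i][j] if j < i else 0.0) for j in range(n)]
             for i in range(n)]
    upper = [[upper_params[i][j] if j >= i else 0.0 for j in range(n)] for i in range(n)]
    return matmul(lower, upper), upper


def lu_log_abs_det(upper):
    """log|det W| = sum_i log|U_ii|, because det L = 1."""
    return sum(math.log(abs(upper[i][i])) for i in range(len(upper)))


def det(m):
    """Reference determinant via Laplace expansion (only for tiny matrices)."""
    if len(m) == 1:
        return m[0][0]
    return sum((-1) ** j * m[0][j] * det([row[:j] + row[j + 1:] for row in m[1:]])
               for j in range(len(m)))


W, U = lu_weight([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [-1.0, 2.0, 0.0]],
                 [[2.0, 1.0, -1.0], [0.0, -3.0, 0.5], [0.0, 0.0, 0.25]])
# Both values equal log(1.5) up to floating-point rounding.
print(lu_log_abs_det(U), math.log(abs(det(W))))
```

The O(d) diagonal sum replaces a full determinant, which is the usual reason for choosing this parameterization in flows.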

More self-contained examples can be found in [examples](https://github.com/dirmeier/surjectors/tree/main/examples).

## Documentation

Documentation can be found [here](https://surjectors.readthedocs.io/en/latest/).

## Installation

Make sure to have a working `JAX` installation. Depending on whether you want to use CPU, GPU, or TPU,
please follow [these instructions](https://github.com/google/jax#installation).

To install the package from PyPI, call:

```bash
pip install surjectors
```

To install a release directly from GitHub, replace `<RELEASE>` with the release tag and call the following on the command line:

```bash
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
```
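After installing, a quick sanity check can save debugging time. This is a minimal sketch using only the Python standard library; it assumes the import names `jax`, `haiku` (from the `dm-haiku` package), `distrax`, `optax`, and `surjectors`, and the helper `check_installation` is hypothetical, not part of Surjectors.

```python
import importlib.util


def check_installation(packages=("jax", "haiku", "distrax", "optax", "surjectors")):
    """Report which packages can be found in the current environment."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    for name, found in check_installation().items():
        print(f"{name}: {'ok' if found else 'missing'}")
    # If JAX is present, list the devices it can see (CPU, GPU, or TPU).
    if importlib.util.find_spec("jax") is not None:
        import jax
        print(jax.devices())
```

If `jax.devices()` only lists CPU devices although you expected a GPU or TPU, the JAX installation instructions linked above are the place to look.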

## Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
[good first issue](https://github.com/dirmeier/surjectors/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).

In order to contribute:

1) Clone `Surjectors` and install `hatch` via `pip install hatch`,
2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) implement your contribution and ideally a test case,
4) test it by calling `hatch run test` on the (Unix) command line,
5) submit a PR 🙂

## Author

Simon Dirmeier <a href="mailto:sfyrbnd @ pm me">sfyrbnd @ pm me</a>

            

Raw data

{
    "_id": null,
    "home_page": "",
    "name": "surjectors",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.8",
    "maintainer_email": "",
    "keywords": "density estimation,normalizing flows,surjections",
    "author": "",
    "author_email": "Simon Dirmeier <sfyrbnd@pm.me>",
    "download_url": "https://files.pythonhosted.org/packages/e9/63/5ab798b462c2c20e4b698b76263c5cb2d7df9ab34b1735a95ce94cb3c5b3/surjectors-0.3.0.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "",
    "summary": "Surjection layers for density estimation with normalizing flows",
    "version": "0.3.0",
    "project_urls": {
        "homepage": "https://github.com/dirmeier/surjectors"
    },
    "split_keywords": [
        "density estimation",
        "normalizing flows",
        "surjections"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "91cc7fb0bec8d92f6b07f3ec4d92e87897386ee53fe28c142688630ff54de883",
                "md5": "a39e1c83e01c932e98ebb18ae49c760c",
                "sha256": "a6940c1974e3116fb189fe2c5292f2f009c36b16a2d61670c033e32de3c871fe"
            },
            "downloads": -1,
            "filename": "surjectors-0.3.0-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "a39e1c83e01c932e98ebb18ae49c760c",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.8",
            "size": 43445,
            "upload_time": "2024-02-01T09:45:46",
            "upload_time_iso_8601": "2024-02-01T09:45:46.281842Z",
            "url": "https://files.pythonhosted.org/packages/91/cc/7fb0bec8d92f6b07f3ec4d92e87897386ee53fe28c142688630ff54de883/surjectors-0.3.0-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "e9635ab798b462c2c20e4b698b76263c5cb2d7df9ab34b1735a95ce94cb3c5b3",
                "md5": "cb269083fea6a70dcba0603c871a3203",
                "sha256": "4a25163e5f0b09200187144d188704d7bc4658bdd7e730744fcfb9c74f9ad2cc"
            },
            "downloads": -1,
            "filename": "surjectors-0.3.0.tar.gz",
            "has_sig": false,
            "md5_digest": "cb269083fea6a70dcba0603c871a3203",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.8",
            "size": 194320,
            "upload_time": "2024-02-01T09:45:49",
            "upload_time_iso_8601": "2024-02-01T09:45:49.407462Z",
            "url": "https://files.pythonhosted.org/packages/e9/63/5ab798b462c2c20e4b698b76263c5cb2d7df9ab34b1735a95ce94cb3c5b3/surjectors-0.3.0.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-02-01 09:45:49",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "dirmeier",
    "github_project": "surjectors",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "surjectors"
}
        