kernax


Name: kernax
Version: 0.1.9
Summary: Regularized Stein thinning using JAX
Upload time: 2023-10-10 15:29:52
Author: Brian Staber
Requires Python: >=3.9
License: MIT License
Keywords: machine learning, statistics, mcmc, thinning, stein
<h1 align="center">Kernax: regularized Stein thinning</h1>

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

from kernax import SteinThinning, RegularizedSteinThinning, laplace_log_p_softplus
from kernax.utils import median_heuristic

# 1000 samples in dimension 2 (a stand-in for an MCMC output)
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000, 2))

# Target density: standard bivariate Gaussian
def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))

# Score function: gradient of the log-density
score_fn = jax.grad(logprob_fn)

# Log-density values and score values at every sample
log_p = jax.vmap(logprob_fn, 0)(x)
score_values = jax.vmap(score_fn, 0)(x)

# Kernel lengthscale chosen with the median heuristic
lengthscale = jnp.array([median_heuristic(x)])

# Standard Stein thinning: indices of the 100 selected samples
stein_fn = SteinThinning(x, score_values, lengthscale)
stein_indices = stein_fn(100)

# Regularized Stein thinning additionally uses log-density values
# and the Laplacian correction values
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)
reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
reg_indices = reg_stein_fn(100)
```
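For readers curious about what plain Stein thinning computes, here is a minimal from-scratch sketch of greedy kernel Stein discrepancy (KSD) thinning in plain JAX. It is not kernax's implementation: the Gaussian (RBF) base kernel, this particular median-heuristic variant, and the helper names `rbf`, `stein_kernel`, `median_lengthscale` and `greedy_stein_thinning` are assumptions made for illustration. At each step the sketch picks the candidate minimizing half its Stein-kernel self-similarity plus its summed Stein-kernel similarity to the points already selected, which is the candidate giving the smallest KSD once added to the selected set.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def rbf(a, b, lengthscale):
    # Gaussian (RBF) base kernel; an illustrative choice, not necessarily kernax's default
    return jnp.exp(-jnp.sum((a - b) ** 2) / (2.0 * lengthscale**2))


def stein_kernel(a, b, sa, sb, lengthscale):
    # Langevin Stein kernel k_p(a, b), derived from the base kernel with autodiff;
    # sa and sb are the score values grad log p(a) and grad log p(b)
    grad_a = jax.grad(rbf, argnums=0)(a, b, lengthscale)
    grad_b = jax.grad(rbf, argnums=1)(a, b, lengthscale)
    # Mixed second derivatives d^2 k / (da_j db_i); only the trace is needed
    mixed = jax.jacfwd(jax.grad(rbf, argnums=1), argnums=0)(a, b, lengthscale)
    return (
        jnp.trace(mixed)
        + grad_a @ sb
        + grad_b @ sa
        + rbf(a, b, lengthscale) * (sa @ sb)
    )


def median_lengthscale(x):
    # One common median-heuristic variant: median pairwise Euclidean distance
    diffs = x[:, None, :] - x[None, :, :]
    dists = jnp.sqrt(jnp.sum(diffs**2, axis=-1))
    upper = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(dists[upper])


def greedy_stein_thinning(x, scores, lengthscale, m):
    # Full Stein-kernel Gram matrix: gram[i, j] = k_p(x_i, x_j)
    row = jax.vmap(stein_kernel, in_axes=(None, 0, None, 0, None))
    gram = jax.vmap(row, in_axes=(0, None, 0, None, None))(
        x, x, scores, scores, lengthscale
    )
    diag = jnp.diag(gram)
    running = jnp.zeros(x.shape[0])  # sum of k_p(x_s, x_i) over already selected s
    selected = []
    for _ in range(m):
        # Candidate giving the smallest KSD once added (repeated picks are allowed)
        i = int(jnp.argmin(0.5 * diag + running))
        selected.append(i)
        running = running + gram[i]
    return jnp.array(selected)


# Usage on the same toy target as the example above: a standard bivariate Gaussian
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (200, 2))


def logprob_fn(z):
    return multivariate_normal.logpdf(z, mean=jnp.zeros(2), cov=jnp.eye(2))


scores = jax.vmap(jax.grad(logprob_fn))(x)
indices = greedy_stein_thinning(x, scores, median_lengthscale(x), 20)
```

Building the full Gram matrix costs O(n²) memory, which is acceptable for a few thousand samples; a larger-scale implementation would more likely compute kernel rows on the fly.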

## Documentation

Documentation is available at [readthedocs](https://kernax.readthedocs.io/en/latest/?kernax=latest).

## Contributing

This code is not meant to be an evolving library. However, feel free to create issues and merge requests.

## Install guide

### PyPI

```console
pip install kernax
```

### Conda

A conda package will soon be available on the conda-forge channel.

### From source

To install from source, clone this repository, then add the package to your `PYTHONPATH` or simply run:
```console
pip install -e .
```
All the requirements are listed in the file `env.yml`, which can be used to create a conda environment as follows:
```console
cd kernax-main
conda env create -n kernax -f env.yml
```
Activate the new environment:
```console
conda activate kernax
```
Then check that the installation works:
```console
python -c "import kernax; print(dir(kernax))"
```

## Reproducibility

This code implements the regularized Stein thinning algorithm introduced in the paper [Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization](https://arxiv.org/pdf/2301.13528.pdf).
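Besides the scores, the regularized algorithm consumes per-sample second-order information about log p (the `laplace_log_p_values` argument in the usage example at the top). The sketch below only illustrates how such quantities can be obtained with JAX autodiff from a score function. It assumes, based solely on the name `laplace_log_p_softplus`, that positive parts of the diagonal second derivatives of log p are smoothed with a softplus; the exact definition used by kernax and by the paper should be checked against them, and the helper names `laplacian_log_p` and `softplus_diag_hessian_log_p` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def laplacian_log_p(x, score_fn):
    # Laplacian of log p at each sample: trace of the Jacobian of the score
    return jax.vmap(lambda z: jnp.trace(jax.jacfwd(score_fn)(z)))(x)


def softplus_diag_hessian_log_p(x, score_fn):
    # Hypothetical smooth surrogate of sum_j (d^2 log p / dx_j^2)^+:
    # softplus applied to each diagonal entry of the Hessian of log p, then summed
    def one(z):
        hessian = jax.jacfwd(score_fn)(z)  # Jacobian of the score = Hessian of log p
        return jnp.sum(jax.nn.softplus(jnp.diag(hessian)))

    return jax.vmap(one)(x)


def logprob_fn(z):
    return multivariate_normal.logpdf(z, mean=jnp.zeros(2), cov=jnp.eye(2))


score_fn = jax.grad(logprob_fn)
x = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
lap = laplacian_log_p(x, score_fn)  # equals -2 everywhere for a standard bivariate Gaussian
smooth = softplus_diag_hessian_log_p(x, score_fn)
```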

Please consider citing the paper when using this library:
```bibtex
@article{benard2023kernel,
  title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
  author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
  journal={arXiv preprint arXiv:2301.13528},
  year={2023}
}
```

All the numerical experiments presented in the [paper](https://arxiv.org/pdf/2301.13528.pdf) can be reproduced with the scripts provided in the `example` folder.

In particular:

* Figures 1, 2 and 3 can be reproduced with the script `example/mog_randn.py`.

* Each experiment in Section 4 and Appendix 1 can be reproduced with the scripts in the following locations:
    * Gaussian mixture: `example/mog4_mcmc` and `example/mog4_mcmc_dim`
    * Mixture of banana-shaped distributions: `example/mobt2_mcmc` and `example/mobt2_mcmc_dim`
    * Bayesian logistic regression: `example/logistic_regression.py`

* Two additional scripts reproduce figures shown in the supplementary material:
    * Figure 2: `example/mog_weight_weights.py`
    * Figure 6: `example/mog4_mcmc_lambda`

            

Raw data

            {
    "_id": null,
    "home_page": "",
    "name": "kernax",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": "",
    "keywords": "machine learning,statistics,mcmc,thinning,Stein",
    "author": "Brian Staber",
    "author_email": "Brian Staber <brian.staber@safrangroup.com>",
    "download_url": "https://files.pythonhosted.org/packages/45/b1/9fef571957344c930363aa1190cb707de9670629547925844f843330b786/kernax-0.1.9.tar.gz",
    "platform": "Linux",
    "bugtrack_url": null,
    "license": "MIT License",
    "summary": "Regularized Stein thinning using JAX",
    "version": "0.1.9",
    "project_urls": {
        "documentation": "https://kernax.readthedocs.io/en/latest/",
        "homepage": "https://gitlab.com/drti/kernax",
        "repository": "https://gitlab.com/drti/kernax"
    },
    "split_keywords": [
        "machine learning",
        "statistics",
        "mcmc",
        "thinning",
        "stein"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "be287c5d8353b9ecf19ce619ff0298c8fe158a5d8a6e6fdb80812fbf6f5f1f18",
                "md5": "40a88aa608a5adaa3cea26e6a65457ef",
                "sha256": "4e35d9dc2c2cf9fc5a3223ab4e3cc723954456bd9f17545eb3d568c23bee4ee7"
            },
            "downloads": -1,
            "filename": "kernax-0.1.9-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "40a88aa608a5adaa3cea26e6a65457ef",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 35063,
            "upload_time": "2023-10-10T15:29:50",
            "upload_time_iso_8601": "2023-10-10T15:29:50.566381Z",
            "url": "https://files.pythonhosted.org/packages/be/28/7c5d8353b9ecf19ce619ff0298c8fe158a5d8a6e6fdb80812fbf6f5f1f18/kernax-0.1.9-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "45b19fef571957344c930363aa1190cb707de9670629547925844f843330b786",
                "md5": "2efc845e17dd4a99c382759a54b283b5",
                "sha256": "dd5fdda982ca3ba549d1b080788c79fbc6a00842e9fe302be3114b2fe68aec0f"
            },
            "downloads": -1,
            "filename": "kernax-0.1.9.tar.gz",
            "has_sig": false,
            "md5_digest": "2efc845e17dd4a99c382759a54b283b5",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 28347,
            "upload_time": "2023-10-10T15:29:52",
            "upload_time_iso_8601": "2023-10-10T15:29:52.102039Z",
            "url": "https://files.pythonhosted.org/packages/45/b1/9fef571957344c930363aa1190cb707de9670629547925844f843330b786/kernax-0.1.9.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-10-10 15:29:52",
    "github": false,
    "gitlab": true,
    "bitbucket": false,
    "codeberg": false,
    "gitlab_user": "drti",
    "gitlab_project": "kernax",
    "lcname": "kernax"
}
        