<h1 align="center">Kernax: regularized Stein thinning</h1>
```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

from kernax import SteinThinning, RegularizedSteinThinning, laplace_log_p_softplus
from kernax.utils import median_heuristic

# Draw 1000 samples from a 2D standard Gaussian
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000, 2))

# Target log-density and its score (gradient of the log-density)
def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))

score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)

# Kernel lengthscale chosen with the median heuristic
lengthscale = jnp.array([median_heuristic(x)])

# Standard Stein thinning: select 100 points
stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)

# Regularized Stein thinning also needs the log-density values
# and the Laplacian-based correction term
log_p = jax.vmap(logprob_fn, 0)(x)
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
```
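To make the selection step more concrete, here is a minimal from-scratch sketch of greedy Stein thinning in plain JAX. It builds the Langevin Stein kernel from an inverse multiquadric (IMQ) base kernel using automatic differentiation, then repeatedly picks the candidate that most decreases the kernel Stein discrepancy of the current selection (a point may be picked more than once). This is not kernax's implementation: the helpers `median_lengthscale`, `imq_kernel`, `stein_kernel` and `greedy_stein_thinning` are hypothetical, the IMQ kernel and the median-of-distances lengthscale are assumptions, and the sketch omits the regularization terms (`log_p`, `laplace_log_p_values`) that `RegularizedSteinThinning` adds.

```python
import jax
import jax.numpy as jnp


def median_lengthscale(x):
    # Hypothetical helper: median of pairwise Euclidean distances
    # (kernax.utils.median_heuristic may use a different convention)
    sq = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    iu = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(jnp.sqrt(sq[iu]))


def imq_kernel(x, y, lengthscale):
    # Assumed IMQ base kernel: k(x, y) = (1 + |x - y|^2 / l^2)^(-1/2)
    return (1.0 + jnp.sum((x - y) ** 2) / lengthscale**2) ** -0.5


def stein_kernel(x, y, sx, sy, lengthscale):
    # Langevin Stein kernel built from the base kernel and the scores sx = s(x), sy = s(y)
    k = imq_kernel(x, y, lengthscale)
    grad_x = jax.grad(imq_kernel, argnums=0)(x, y, lengthscale)
    grad_y = jax.grad(imq_kernel, argnums=1)(x, y, lengthscale)
    # Mixed second derivatives d^2 k / (dx_i dy_j); only the trace is needed
    mixed = jax.jacfwd(jax.grad(imq_kernel, argnums=0), argnums=1)(x, y, lengthscale)
    return jnp.trace(mixed) + grad_x @ sy + grad_y @ sx + k * (sx @ sy)


def greedy_stein_thinning(x, scores, lengthscale, m):
    # Full Stein Gram matrix over all candidate pairs, shape (n, n)
    row = jax.vmap(stein_kernel, in_axes=(None, 0, None, 0, None))
    gram = jax.vmap(row, in_axes=(0, None, 0, None, None))(x, x, scores, scores, lengthscale)
    diag = jnp.diag(gram)

    def step(running, _):
        # running[i] = sum of k_p(x_j, x_i) over the points selected so far;
        # minimizing 0.5 * k_p(x_i, x_i) + running[i] greedily minimizes the KSD
        i = jnp.argmin(0.5 * diag + running)
        return running + gram[i], i

    _, indices = jax.lax.scan(step, jnp.zeros(x.shape[0]), None, length=m)
    return indices


# Usage on a 2D standard Gaussian, whose score is s(x) = -x
x = jax.random.normal(jax.random.PRNGKey(0), (500, 2))
indices = greedy_stein_thinning(x, -x, median_lengthscale(x), 50)
```

Materializing the full n × n Stein Gram matrix keeps the sketch short but costs O(n²) memory; for large samples one would compute kernel rows on the fly instead.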
## Documentation
Documentation is available at [readthedocs](https://kernax.readthedocs.io/en/latest/?kernax=latest).
## Contributing
This code is not meant to be an evolving library. However, feel free to create issues and merge requests.
## Install guide
### PyPI
```console
pip install kernax
```
### Conda
A conda package will soon be available on the conda-forge channel.
### From source
To install from source, clone this repository, then either add the package to your `PYTHONPATH` or run:
```console
pip install -e .
```
All the requirements are listed in the file `env.yml`, which can be used to create a conda environment as follows:
```console
cd kernax-main
conda env create -n kernax -f env.yml
```
Activate the new environment:
```console
conda activate kernax
```
Then check that the installation works:
```console
python -c "import kernax; print(dir(kernax))"
```
## Reproducibility
This code implements the regularized Stein thinning algorithm introduced in the paper [Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization](https://arxiv.org/pdf/2301.13528.pdf).
Please consider citing the paper when using this library:
```bibtex
@article{benard2023kernel,
title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
journal={arXiv preprint arXiv:2301.13528},
year={2023}
}
```
All the numerical experiments presented in the [paper](https://arxiv.org/pdf/2301.13528.pdf) can be reproduced with the scripts made available in the `example` folder.
In particular:
* Figures 1, 2 & 3 can be reproduced with the script `example/mog_randn.py`.
* Each experiment in Section 4 and Appendix 1 can be reproduced with the scripts gathered in the following folders:
  * Gaussian mixture: `example/mog4_mcmc` and `example/mog4_mcmc_dim`
  * Mixture of banana-shaped distributions: `example/mobt2_mcmc` and `example/mobt2_mcmc_dim`
  * Bayesian logistic regression: `example/logistic_regression.py`
* Two additional scripts are also available to reproduce figures shown in the supplementary material:
  * Figure 2: `example/mog_weight_weights.py`
  * Figure 6: `example/mog4_mcmc_lambda`
Raw data
{
"_id": null,
"home_page": "",
"name": "kernax",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": "",
"keywords": "machine learning,statistics,mcmc,thinning,Stein",
"author": "Brian Staber",
"author_email": "Brian Staber <brian.staber@safrangroup.com>",
"download_url": "https://files.pythonhosted.org/packages/45/b1/9fef571957344c930363aa1190cb707de9670629547925844f843330b786/kernax-0.1.9.tar.gz",
"platform": "Linux",
"bugtrack_url": null,
"license": "MIT License",
"summary": "Regularized Stein thinning using JAX",
"version": "0.1.9",
"project_urls": {
"documentation": "https://kernax.readthedocs.io/en/latest/",
"homepage": "https://gitlab.com/drti/kernax",
"repository": "https://gitlab.com/drti/kernax"
},
"split_keywords": [
"machine learning",
"statistics",
"mcmc",
"thinning",
"stein"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "be287c5d8353b9ecf19ce619ff0298c8fe158a5d8a6e6fdb80812fbf6f5f1f18",
"md5": "40a88aa608a5adaa3cea26e6a65457ef",
"sha256": "4e35d9dc2c2cf9fc5a3223ab4e3cc723954456bd9f17545eb3d568c23bee4ee7"
},
"downloads": -1,
"filename": "kernax-0.1.9-py3-none-any.whl",
"has_sig": false,
"md5_digest": "40a88aa608a5adaa3cea26e6a65457ef",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 35063,
"upload_time": "2023-10-10T15:29:50",
"upload_time_iso_8601": "2023-10-10T15:29:50.566381Z",
"url": "https://files.pythonhosted.org/packages/be/28/7c5d8353b9ecf19ce619ff0298c8fe158a5d8a6e6fdb80812fbf6f5f1f18/kernax-0.1.9-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "45b19fef571957344c930363aa1190cb707de9670629547925844f843330b786",
"md5": "2efc845e17dd4a99c382759a54b283b5",
"sha256": "dd5fdda982ca3ba549d1b080788c79fbc6a00842e9fe302be3114b2fe68aec0f"
},
"downloads": -1,
"filename": "kernax-0.1.9.tar.gz",
"has_sig": false,
"md5_digest": "2efc845e17dd4a99c382759a54b283b5",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 28347,
"upload_time": "2023-10-10T15:29:52",
"upload_time_iso_8601": "2023-10-10T15:29:52.102039Z",
"url": "https://files.pythonhosted.org/packages/45/b1/9fef571957344c930363aa1190cb707de9670629547925844f843330b786/kernax-0.1.9.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-10-10 15:29:52",
"github": false,
"gitlab": true,
"bitbucket": false,
"codeberg": false,
"gitlab_user": "drti",
"gitlab_project": "kernax",
"lcname": "kernax"
}