torch-kde

- Name: torch-kde
- Version: 0.1.4
- Summary: A differentiable implementation of kernel density estimation in PyTorch
- Author: Klaus-Rudolf Kladny
- Upload time: 2025-03-09 18:54:01
- License: MIT
- Keywords: density estimation, kde, pytorch, differentiable
- Homepage: https://github.com/rudolfwilliam/torch-kde
- Requirements: asttokens==3.0.0, comm==0.2.2, contourpy==1.3.1, cycler==0.12.1, debugpy==1.8.9, decorator==5.1.1, executing==2.1.0, filelock==3.16.1, fonttools==4.55.2, fsspec==2024.10.0, ipykernel==6.29.5, ipython==8.30.0, jedi==0.19.2, Jinja2==3.1.4, jupyter_client==8.6.3, jupyter_core==5.7.2, kiwisolver==1.4.7, MarkupSafe==3.0.2, matplotlib==3.9.3, matplotlib-inline==0.1.7, mpmath==1.3.0, nest-asyncio==1.6.0, networkx==3.4.2, numpy==2.2.0, packaging==24.2, parso==0.8.4, pexpect==4.9.0, pillow==11.0.0, platformdirs==4.3.6, prompt_toolkit==3.0.48, psutil==6.1.0, ptyprocess==0.7.0, pure_eval==0.2.3, Pygments==2.18.0, pyparsing==3.2.0, python-dateutil==2.9.0.post0, pyzmq==26.2.0, SciencePlots==2.1.1, scipy==1.14.1, scikit-learn==1.6.1, six==1.17.0, stack-data==0.6.3, sympy==1.13.1, torch==2.5.1, tornado==6.4.2, traitlets==5.14.3, triton==3.1.0, typing_extensions==4.12.2, wcwidth==0.2.13
# TorchKDE

![Python Version](https://img.shields.io/badge/python-3.11.11%2B-blue.svg)
![PyTorch Version](https://img.shields.io/badge/pytorch-2.5.1-brightgreen.svg)
![Tests](https://github.com/rudolfwilliam/torch-kde/actions/workflows/ci.yml/badge.svg)
[![DOI](https://zenodo.org/badge/901331908.svg)](https://doi.org/10.5281/zenodo.14674657)

A differentiable implementation of [kernel density estimation](https://en.wikipedia.org/wiki/Kernel_density_estimation) in PyTorch by Klaus-Rudolf Kladny.

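$$\hat{f}(x) = \frac{1}{|H|^{\frac{1}{2}} n} \sum_{i=1}^n K \left( H^{-\frac{1}{2}} \left( x - x_i \right) \right)$$

where $x_1, \dots, x_n$ are the data points, $K$ is the kernel, and $H$ is the symmetric positive definite bandwidth matrix with determinant $|H|$.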
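To make the formula concrete, here is a minimal pure-PyTorch sketch of $\hat{f}$ for a Gaussian kernel. It assumes an isotropic bandwidth matrix $H = h^2 I$, so that a float bandwidth $h$ acts like a standard deviation (as in scikit-learn); this README does not state whether torchkde uses exactly this parameterization. The function `gaussian_kde_log_density` is hypothetical and not part of the torchkde API. It works in log space and stays differentiable with respect to both the query points and the data.

```python
import math

import torch


def gaussian_kde_log_density(x_query: torch.Tensor, x_data: torch.Tensor, h: float) -> torch.Tensor:
    """Hypothetical helper: log f_hat(x) for each row of x_query.

    Assumes a standard Gaussian kernel K(u) = (2*pi)^(-p/2) * exp(-||u||^2 / 2)
    and an isotropic bandwidth matrix H = h^2 * I, so H^(-1/2) = I / h and |H|^(1/2) = h^p.
    """
    n, p = x_data.shape
    # u[j, i] = H^(-1/2) (x_j - x_i) for query j and data point i, shape (m, n, p)
    u = (x_query[:, None, :] - x_data[None, :, :]) / h
    # log K(u) for every (query, data) pair, shape (m, n)
    log_kernel = -0.5 * (u ** 2).sum(dim=-1) - 0.5 * p * math.log(2 * math.pi)
    # log of the sum over i, minus log(|H|^(1/2) * n) = p * log(h) + log(n)
    return torch.logsumexp(log_kernel, dim=1) - p * math.log(h) - math.log(n)


torch.manual_seed(0)
X = torch.randn(1000, 2, requires_grad=True)  # data points x_1, ..., x_n
X_new = torch.randn(100, 2)                   # query points
logprob = gaussian_kde_log_density(X_new, X, h=1.0)
logprob.sum().backward()                      # gradients flow back to the data X
print(logprob.shape, X.grad.shape)            # torch.Size([100]) torch.Size([1000, 2])
```

Using `torch.logsumexp` avoids underflow when a query point is far from every data point. The broadcast difference tensor needs memory proportional to (number of queries) × n × p, so for large inputs one would likely process the queries in batches.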

## Installation Instructions

The torch-kde package can be installed via `pip`. Run

```bash
pip install torch-kde
```

Now you are ready to go! If you would also like to run the code from the Jupyter notebooks or contribute to this package, please also install the packages listed in `requirements.txt`:

```bash
pip install -r requirements.txt
```

## What's included?

### Kernel Density Estimation

The `KernelDensity` class supports the same operations as the [KernelDensity class in scikit-learn](https://scikit-learn.org/dev/modules/generated/sklearn.neighbors.KernelDensity.html), but is implemented in PyTorch and is differentiable with respect to the input data. Here is a little taste:

```python
from torchkde import KernelDensity
import torch

multivariate_normal = torch.distributions.MultivariateNormal(torch.ones(2), torch.eye(2))
X = multivariate_normal.sample((1000,)) # create data
X.requires_grad = True # enable differentiation
kde = KernelDensity(bandwidth=1.0, kernel='gaussian') # create kde object with isotropic bandwidth matrix
_ = kde.fit(X) # fit kde to data

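X_new = multivariate_normal.sample((100,)) # create new data
logprob = kde.score_samples(X_new) # log-density of each new data point

print(logprob.grad_fn is not None) # True: logprob is part of the autograd graph
logprob.sum().backward() # backpropagate to the training data X
print(X.grad.shape) # torch.Size([1000, 2])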
```

You may also check out `demo_kde.ipynb` for a simple demo on the [Bart Simpson distribution](https://www.stat.cmu.edu/~larry/=sml/densityestimation.pdf).

### Tophat Kernel Approximation

The Tophat kernel is not differentiable on the boundary of its support (two points in one dimension) and has zero derivative everywhere else, so its gradients carry no useful information. We therefore provide a differentiable approximation via a generalized Gaussian (see, e.g., [Pascal et al.](https://arxiv.org/pdf/1302.6498) for reference):

$$K^{\text{tophat}}(x; \beta) = \frac{\beta \Gamma \left( \frac{p}{2} \right) }{\pi^{\frac{p}{2}} \Gamma \left( \frac{p}{2\beta} \right) 2^{\frac{p}{2\beta}}} \text{exp} \left( - \frac{\| x \|_2^{2\beta}}{2} \right),$$

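where $p$ is the dimensionality of $x$. As $\beta \to \infty$, $\| x \|_2^{2\beta}$ tends to $0$ inside the unit ball and to $\infty$ outside of it, so the kernel approaches the Tophat kernel, i.e., the uniform density on the unit ball. Large values of $\beta$ therefore yield a close approximation.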

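We note that for $\beta = 1$, this kernel reduces to the standard Gaussian kernel $(2\pi)^{-\frac{p}{2}} \exp \left( - \frac{\| x \|_2^2}{2} \right)$. While larger values of $\beta$ improve the approximation, they also increase the magnitude of the gradients with respect to the input, which become concentrated near the boundary of the unit ball. The choice of $\beta$ is thus a tradeoff between approximation quality and well-behaved gradients.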
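These properties can be checked numerically with a minimal sketch of the kernel, written directly from the formula above (the function `generalized_gaussian_kernel` is hypothetical and not taken from torchkde). It evaluates the kernel on a fine one-dimensional grid ($p = 1$) and, for several values of $\beta$, records the approximate integral, which should be close to 1, and the largest gradient magnitude $|dK/dx|$, which should grow with $\beta$. The normalization constant is computed in log space with `math.lgamma`, which is numerically safer than evaluating the Gamma functions directly.

```python
import math

import torch


def generalized_gaussian_kernel(x: torch.Tensor, beta: float) -> torch.Tensor:
    """Hypothetical helper: evaluates K^tophat(x; beta) for each row of x, shape (..., p)."""
    p = x.shape[-1]
    # log of beta * Gamma(p/2) / (pi^(p/2) * Gamma(p/(2*beta)) * 2^(p/(2*beta)))
    log_norm = (
        math.log(beta)
        + math.lgamma(p / 2)
        - (p / 2) * math.log(math.pi)
        - math.lgamma(p / (2 * beta))
        - (p / (2 * beta)) * math.log(2)
    )
    sq_norm = (x ** 2).sum(dim=-1)  # ||x||_2^2
    # ||x||_2^(2*beta) == (||x||_2^2)^beta
    return torch.exp(log_norm - 0.5 * sq_norm ** beta)


# Fine 1-D grid (p = 1); the kernel is negligible outside [-6, 6] for beta >= 1
grid = torch.linspace(-6.0, 6.0, 40001, dtype=torch.float64).unsqueeze(-1)

results = {}
for beta in (1.0, 2.0, 10.0):
    x = grid.clone().requires_grad_(True)
    k = generalized_gaussian_kernel(x, beta)
    (grad,) = torch.autograd.grad(k.sum(), x)  # dK/dx at every grid point
    integral = torch.trapezoid(k.detach(), grid.squeeze(-1)).item()
    results[beta] = (integral, grad.abs().max().item())
    print(f"beta={beta:4.1f}  integral={integral:.4f}  max|dK/dx|={results[beta][1]:.3f}")
```

With $\beta = 1$ the sketch reproduces the standard Gaussian density, while for large $\beta$ the one-dimensional kernel is roughly $\frac{1}{2}$ on $(-1, 1)$ and close to $0$ outside, i.e., a one-dimensional Tophat, at the price of steeper gradients near $|x| = 1$.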

## Supported Settings

The current implementation provides the following functionality:

<div align="center">

| Feature                  | Supported Values            |
|--------------------------|-----------------------------|
| Kernels                  | Gaussian, Epanechnikov, Exponential, Tophat Approximation      |
| Tree Algorithms          | Standard                    |
| Bandwidths               | Float (Isotropic bandwidth matrix), Scott, Silverman |

</div>
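For reference, the Scott and Silverman rules are commonly stated as the following bandwidth factors for $n$ data points in $p$ dimensions; this is how `scipy.stats.gaussian_kde` defines its `scotts_factor` and `silverman_factor`. Whether torchkde uses these factors directly as the isotropic bandwidth or additionally rescales them by the spread of the data is not settled by this README, so treat the sketch as an assumption. The helper names are hypothetical.

```python
def scott_factor(n: int, p: int) -> float:
    """Hypothetical helper for Scott's rule: n^(-1 / (p + 4))."""
    return n ** (-1.0 / (p + 4))


def silverman_factor(n: int, p: int) -> float:
    """Hypothetical helper for Silverman's rule: (n * (p + 2) / 4)^(-1 / (p + 4))."""
    return (n * (p + 2) / 4.0) ** (-1.0 / (p + 4))


# 1000 two-dimensional data points, as in the KernelDensity example above
print(scott_factor(1000, 2), silverman_factor(1000, 2))
```

For $p = 2$ the two rules coincide, since $(p + 2)/4 = 1$; in other dimensions they differ slightly.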

## Got an Extension? Create a Pull Request!

In case you do not know how to do that, here are the necessary steps:

1. Fork the repo
2. Create your feature branch (`git checkout -b cool_tree_algorithm`)
3. Run the unit tests (`python -m tests.test_kde`) and only proceed if the script outputs "OK".
4. Commit your changes (`git commit -am 'Add cool tree algorithm'`)
5. Push to the branch (`git push origin cool_tree_algorithm`)
6. Open a Pull Request

## Issues?

If you discover a bug or do not understand something, please create an issue or let me know directly at *kkladny [at] tuebingen [dot] mpg [dot] de*! I am also happy to take requests for implementing specific functionalities.


<div align="center">

> "In God we trust. All others must bring data."
> 
> — W. Edwards Deming
> 
</div>

            
