torchrbf


- Name: torchrbf
- Version: 0.0.1
- Summary: Radial Basis Function Interpolation in PyTorch
- Upload time: 2023-03-22 20:43:17
- Author: Arman Maesumi
- Keywords: python, pytorch, rbf, interpolation, radial basis function
- Requirements: none recorded
            
# torchrbf: Radial Basis Function Interpolation in PyTorch

This is a PyTorch module for [Radial Basis Function (RBF) Interpolation](https://en.wikipedia.org/wiki/Radial_basis_function_interpolation), translated from [SciPy's implementation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.RBFInterpolator.html). Because it runs on the GPU, it is significantly faster than SciPy and better suited to larger interpolation problems.

## Installation
```
pip install torchrbf
```

The only dependencies are PyTorch and NumPy. If you want to run the tests and benchmarks, you also need SciPy installed.

## A note on numerical precision
If you are using TF32, you may run into numerical precision issues. TF32 matrix multiplication is enabled by default in PyTorch versions 1.7 to 1.11 (see the [PyTorch CUDA notes](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)). To disable it, use:
```python
torch.backends.cuda.matmul.allow_tf32 = False
``` 

torchrbf will issue a warning if TF32 is enabled.
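
If you only want full precision around the interpolation itself, the flag can also be toggled temporarily. The following is a minimal sketch, not part of torchrbf: the helper name `tf32_disabled` is hypothetical, and it only touches the `torch.backends.cuda.matmul.allow_tf32` flag shown above.

```python
import contextlib

import torch


@contextlib.contextmanager
def tf32_disabled():
    """Temporarily disable TF32 matmuls and restore the previous setting on exit.

    Hypothetical helper for illustration; torchrbf does not ship this.
    """
    previous = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        yield
    finally:
        # Restore whatever the caller had configured before entering the block.
        torch.backends.cuda.matmul.allow_tf32 = previous
```

Building and querying the interpolator inside `with tf32_disabled():` would then leave the rest of the program's TF32 setting unchanged.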

## Usage

Here is a simple example that interpolates 3-component data vectors over a 2D domain:

```python
import torch
import matplotlib.pyplot as plt
from torchrbf import RBFInterpolator

y = torch.rand(100, 2)  # Data coordinates, shape (100, 2)
d = torch.rand(100, 3)  # Data vectors at each point, shape (100, 3)

interpolator = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')

# Query coordinates (100x100 grid of points); named gx/gy so the data
# coordinates `y` above are not overwritten
gx = torch.linspace(0, 1, 100)
gy = torch.linspace(0, 1, 100)
grid_points = torch.meshgrid(gx, gy, indexing='ij')
grid_points = torch.stack(grid_points, dim=-1).reshape(-1, 2)

# Query RBF on grid points
interp_vals = interpolator(grid_points)

# Plot the first component of the interpolated values in 2D
plt.scatter(grid_points[:, 0], grid_points[:, 1], c=interp_vals[:, 0])
plt.title('Interpolated values in 2D')
plt.show()
```
<div style="width: 60%; height: 60%; display: block; margin-left:auto; margin-right:auto">

  ![](imgs/example2d.png)

</div>
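
To make the roles of the kernel and the `smoothing` parameter more concrete, here is a minimal sketch of the linear system behind RBF interpolation. It is a simplification under stated assumptions, not torchrbf's implementation: it uses a Gaussian kernel with a hypothetical shape parameter `epsilon`, omits the polynomial term that kernels such as `thin_plate_spline` require, and the function names are illustrative only.

```python
import torch


def gaussian_kernel(a, b, epsilon):
    """Gaussian RBF phi(r) = exp(-(epsilon * r)^2) between all pairs of rows of a and b."""
    r = torch.cdist(a, b)  # pairwise Euclidean distances, shape (len(a), len(b))
    return torch.exp(-(epsilon * r) ** 2)


def fit_weights(y, d, smoothing=0.0, epsilon=4.0):
    """Solve (K + smoothing * I) w = d for the RBF weights w."""
    K = gaussian_kernel(y, y, epsilon)
    identity = torch.eye(len(y), dtype=y.dtype, device=y.device)
    return torch.linalg.solve(K + smoothing * identity, d)


def evaluate(x, y, weights, epsilon=4.0):
    """Evaluate the fitted interpolant at query points x."""
    return gaussian_kernel(x, y, epsilon) @ weights
```

With `smoothing=0.0` the fitted function passes through the data values at the data coordinates; larger values trade exactness for a smoother fit.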

## Performance versus SciPy

Since the module is implemented in PyTorch, it benefits from GPU acceleration. For larger interpolation problems, torchrbf is significantly faster than SciPy's implementation (more than 100x faster on an RTX 3090):


<div style="width: 60%; height: 60%; display: block; margin-left:auto; margin-right:auto">

  ![](imgs/forwards_per_second.png)

</div>
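
When measuring throughput like this yourself, keep in mind that CUDA kernels launch asynchronously, so the GPU has to be synchronized before reading the clock. The helper below is a rough sketch of such a measurement, not the benchmark script used for the plot above; the name `forwards_per_second` and its parameters are assumptions made for illustration.

```python
import time

import torch


def forwards_per_second(fn, n_warmup=3, n_iters=20):
    """Estimate how many times per second the callable fn can be run."""
    for _ in range(n_warmup):
        fn()  # warm-up runs absorb one-time setup costs
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for pending GPU work before timing
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure all timed GPU work has finished
    elapsed = time.perf_counter() - start
    return n_iters / elapsed
```

For example, passing `lambda: interpolator(grid_points)` would time repeated queries of an already constructed interpolator.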

            

        