# torchrbf: Radial Basis Function Interpolation in PyTorch
This is a PyTorch module for [Radial Basis Function (RBF) Interpolation](https://en.wikipedia.org/wiki/Radial_basis_function_interpolation), translated from [SciPy's implementation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.RBFInterpolator.html). Because it can run on the GPU, it is significantly faster than the SciPy version and better suited to larger interpolation problems.
## Installation
```
pip install torchrbf
```
The only dependencies are PyTorch and NumPy. If you want to run the tests and benchmarks, you also need SciPy installed.
## A note on numerical precision
If you are using TF32 (TensorFloat-32, available on NVIDIA Ampere and newer GPUs), you may experience numerical precision issues. TF32 is enabled by default for matrix multiplications in PyTorch versions 1.7 to 1.11 (see the [PyTorch CUDA notes](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)). To disable it, set:
```python
torch.backends.cuda.matmul.allow_tf32 = False
```
torchrbf will issue a warning if TF32 is enabled.
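If you only want full float32 precision around the interpolation calls and prefer not to change the global setting permanently, one option is a small context manager. This is a minimal sketch: `tf32_disabled` is a hypothetical helper name and is not part of torchrbf; it only toggles the `torch.backends.cuda.matmul.allow_tf32` flag shown above.

```python
import contextlib

import torch


@contextlib.contextmanager
def tf32_disabled():
    """Temporarily disable TF32 matmuls (hypothetical helper, not part of torchrbf)."""
    previous = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        yield
    finally:
        # Restore whatever the caller had configured before entering the block.
        torch.backends.cuda.matmul.allow_tf32 = previous
```

Building and querying the interpolator inside `with tf32_disabled(): ...` would then run those matrix multiplications in full float32, and the previous setting is restored on exit.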
## Usage
Here is a simple example that interpolates 3-dimensional data vectors over a 2D domain:
```python
import torch
import matplotlib.pyplot as plt
from torchrbf import RBFInterpolator
y = torch.rand(100, 2) # Data coordinates
d = torch.rand(100, 3) # Data vectors at each point
interpolator = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')
# Query coordinates (100x100 grid of points)
# Use distinct names so the data coordinates `y` above are not overwritten
grid_x = torch.linspace(0, 1, 100)
grid_y = torch.linspace(0, 1, 100)
grid_points = torch.meshgrid(grid_x, grid_y, indexing='ij')
grid_points = torch.stack(grid_points, dim=-1).reshape(-1, 2)
# Query RBF on grid points
interp_vals = interpolator(grid_points)
# Plot the interpolated values in 2D
plt.scatter(grid_points[:, 0], grid_points[:, 1], c=interp_vals[:, 0])
plt.title('Interpolated values in 2D')
plt.show()
```
<div style="width: 60%; height: 60%; display: block; margin-left:auto; margin-right:auto">
![](imgs/example2d.png)
</div>
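For intuition about what a thin-plate-spline interpolator computes, the sketch below solves the standard RBF linear system directly in PyTorch: kernel weights `w` plus a degree-1 polynomial with coefficients `c`. It is an illustration under simplifying assumptions, not torchrbf's or SciPy's actual code (those also rescale the inputs, among other details), and the names `tps_kernel`, `fit_tps`, and `evaluate_tps` are hypothetical.

```python
import torch


def pairwise_dist(a, b):
    # Euclidean distances between rows of a (m, dim) and rows of b (n, dim).
    return torch.linalg.norm(a[:, None, :] - b[None, :, :], dim=-1)


def tps_kernel(r):
    # Thin plate spline: phi(r) = r^2 * log(r), with phi(0) = 0.
    tiny = torch.finfo(r.dtype).tiny
    return r.pow(2) * torch.log(r.clamp_min(tiny))


def fit_tps(y, d, smoothing=0.0):
    """Solve [[K + s*I, P], [P^T, 0]] [w; c] = [d; 0] for w and c."""
    n, dim = y.shape
    K = tps_kernel(pairwise_dist(y, y)) + smoothing * torch.eye(n, dtype=y.dtype)
    P = torch.cat([torch.ones(n, 1, dtype=y.dtype), y], dim=1)  # (n, dim + 1)
    A = torch.zeros(n + dim + 1, n + dim + 1, dtype=y.dtype)
    A[:n, :n] = K
    A[:n, n:] = P
    A[n:, :n] = P.T
    rhs = torch.zeros(n + dim + 1, d.shape[1], dtype=y.dtype)
    rhs[:n] = d
    coeffs = torch.linalg.solve(A, rhs)
    return coeffs[:n], coeffs[n:]  # kernel weights, polynomial coefficients


def evaluate_tps(x, y, w, c):
    # Interpolant at query points x: kernel part plus polynomial part.
    P = torch.cat([torch.ones(x.shape[0], 1, dtype=x.dtype), x], dim=1)
    return tps_kernel(pairwise_dist(x, y)) @ w + P @ c
```

With `smoothing=0.0` the fitted function passes through the data points; a positive `smoothing` relaxes that constraint, so the result approximates the data instead of matching it exactly.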
## Performance versus SciPy
Since the module is implemented in PyTorch, it benefits from GPU acceleration. For larger interpolation problems, torchrbf is significantly faster than SciPy's implementation (more than 100x faster on an RTX 3090):
<div style="width: 60%; height: 60%; display: block; margin-left:auto; margin-right:auto">
![](imgs/forwards_per_second.png)
</div>
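To get a rough throughput number on your own hardware, a timing helper along the following lines may be useful. This is a sketch, not the benchmark script that produced the figure above: `forwards_per_second` is a hypothetical helper, and the warm-up and iteration counts are arbitrary defaults.

```python
import time

import torch


def forwards_per_second(fn, n_warmup=3, n_iters=20):
    """Estimate how many times per second `fn()` can be called (hypothetical helper)."""
    for _ in range(n_warmup):
        fn()  # warm-up calls absorb one-time costs such as CUDA initialization
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # GPU kernels run asynchronously; wait before timing
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure all queued work has finished
    elapsed = time.perf_counter() - start
    return n_iters / elapsed
```

For example, `forwards_per_second(lambda: interpolator(grid_points))` would time the query from the usage example; wrapping the corresponding SciPy call the same way gives a comparable number.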
Raw data
{
"_id": null,
"home_page": "",
"name": "torchrbf",
"maintainer": "",
"docs_url": null,
"requires_python": "",
"maintainer_email": "",
"keywords": "python,pytorch,rbf,interpolation,radial basis function",
"author": "Arman Maesumi",
"author_email": "arman.maesumi@gmail.com",
"download_url": "https://files.pythonhosted.org/packages/98/05/223f2037d4fa84f9471343334586912b0f92f43ef835dd3376fc6b682ec7/torchrbf-0.0.1.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "",
"summary": "Radial Basis Function Interpolation in PyTorch",
"version": "0.0.1",
"split_keywords": [
"python",
"pytorch",
"rbf",
"interpolation",
"radial basis function"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "971c1d34324b9b7d88971c5e2c20960d14f06e56889fa4258502b259f1e209b4",
"md5": "a7aa99e1fcf3691dff1cfaa184c592c3",
"sha256": "03dba76c92df073d6bc785fb8353602fbd4a7aad64cc27197f89485e970b26a8"
},
"downloads": -1,
"filename": "torchrbf-0.0.1-py3-none-any.whl",
"has_sig": false,
"md5_digest": "a7aa99e1fcf3691dff1cfaa184c592c3",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": null,
"size": 12909,
"upload_time": "2023-03-22T20:43:14",
"upload_time_iso_8601": "2023-03-22T20:43:14.161158Z",
"url": "https://files.pythonhosted.org/packages/97/1c/1d34324b9b7d88971c5e2c20960d14f06e56889fa4258502b259f1e209b4/torchrbf-0.0.1-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "9805223f2037d4fa84f9471343334586912b0f92f43ef835dd3376fc6b682ec7",
"md5": "768f06ce0e70e1f8dad6b04556ef26d7",
"sha256": "591a9de6f6beb0e14024328c034417739a7fee6056e648ce374b910f50c47db3"
},
"downloads": -1,
"filename": "torchrbf-0.0.1.tar.gz",
"has_sig": false,
"md5_digest": "768f06ce0e70e1f8dad6b04556ef26d7",
"packagetype": "sdist",
"python_version": "source",
"requires_python": null,
"size": 8968,
"upload_time": "2023-03-22T20:43:17",
"upload_time_iso_8601": "2023-03-22T20:43:17.016897Z",
"url": "https://files.pythonhosted.org/packages/98/05/223f2037d4fa84f9471343334586912b0f92f43ef835dd3376fc6b682ec7/torchrbf-0.0.1.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-03-22 20:43:17",
"github": false,
"gitlab": false,
"bitbucket": false,
"lcname": "torchrbf"
}