----
This project has now been incorporated into [GPJax](https://github.com/JaxGaussianProcesses/GPJax).
----
# [JaxUtils](https://github.com/JaxGaussianProcesses/JaxUtils)
[![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxUtils/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxUtils/tree/master)
`JaxUtils` provides utility functions for the [`JaxGaussianProcesses`](https://github.com/JaxGaussianProcesses) ecosystem.
# Contents
- [PyTree](#pytree)
- [Dataset](#dataset)
# PyTree
## Overview
`jaxutils.PyTree` is a mixin class for [registering a Python class as a JAX PyTree](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees). To use it, subclass `jaxutils.PyTree` when defining your class:
```python
import jaxutils

class MyClass(jaxutils.PyTree):
    ...
```
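Conceptually, a mixin like this only has to tell JAX how to split an instance into leaves and rebuild it. The sketch below is a hypothetical `PyTreeMixin`, not the actual `jaxutils.PyTree` source: it assumes every instance attribute is a leaf and uses JAX's `jax.tree_util.register_pytree_node` inside `__init_subclass__` so that each subclass is registered automatically. `Point` is likewise a made-up example class.

```python
import jax
import jax.numpy as jnp


class PyTreeMixin:
    """Hypothetical mixin: every subclass is registered as a JAX PyTree.

    Assumption: all instance attributes are treated as leaves. This is an
    illustrative sketch, not jaxutils' actual implementation.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register the subclass with JAX's pytree registry at class creation.
        jax.tree_util.register_pytree_node(
            cls, cls._tree_flatten, cls._tree_unflatten
        )

    def _tree_flatten(self):
        # Sort attribute names so flattening order is deterministic.
        keys = tuple(sorted(vars(self)))
        children = tuple(getattr(self, k) for k in keys)
        return children, keys  # the names become static auxiliary data

    @classmethod
    def _tree_unflatten(cls, keys, children):
        # Bypass __init__ so subclasses with custom signatures still rebuild.
        obj = object.__new__(cls)
        for k, v in zip(keys, children):
            setattr(obj, k, v)
        return obj


class Point(PyTreeMixin):  # hypothetical example subclass
    def __init__(self, x, y):
        self.x = x
        self.y = y


p = Point(jnp.array(1.0), jnp.array(2.0))
# tree_map visits the leaves (x, y) and rebuilds a new Point.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, p)
print(doubled.x, doubled.y)
```

Treating all attributes as leaves keeps the sketch short; a real implementation may instead need a way to mark some attributes as static (non-differentiable) data.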
## Example
```python
import jaxutils
from jaxtyping import Array, Float


class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None:
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]:
        # Evaluate the line at each input point.
        return x * self.gradient + self.intercept
```
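Registering `Line` as a PyTree is what lets JAX transformations such as `jax.grad` treat its attributes as differentiable leaves. The sketch below does not depend on jaxutils: it registers an equivalent `Line` with JAX's own `register_pytree_node_class` decorator, then differentiates a hypothetical mean-squared-error `loss` with respect to the line's parameters and takes one illustrative gradient step.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Line:
    """Stand-in for the `Line` example, registered with JAX's own decorator."""

    def __init__(self, gradient, intercept):
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x):
        return x * self.gradient + self.intercept

    def tree_flatten(self):
        # Both parameters are leaves; there is no static auxiliary data.
        return (self.gradient, self.intercept), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(line, x, y_true):
    # Hypothetical objective: mean squared error of the line's predictions.
    return jnp.mean((line.y(x) - y_true) ** 2)


x = jnp.array([0.0, 1.0, 2.0])
y_true = jnp.array([1.0, 3.0, 5.0])  # generated by gradient=2, intercept=1
line = Line(jnp.array([1.0]), jnp.array([0.0]))

# jax.grad returns a Line whose attributes hold the partial derivatives.
grads = jax.grad(loss)(line, x, y_true)

# One gradient-descent step, applied leaf-by-leaf with tree_map.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, line, grads)
```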
# Dataset
## Overview
`jaxutils.Dataset` is a dataset abstraction. In the future, we wish to extend it to support both heterotopic and isotopic data.
## Example
```python
import jaxutils
import jax.numpy as jnp
# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])
# Dataset
D = jaxutils.Dataset(X=X, y=y)
print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
```
```
The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
```
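Every attribute printed above can be read off the shapes of `X` and `y`. The following is a minimal sketch under that assumption, using a hypothetical `SimpleDataset` class rather than the real `jaxutils.Dataset` source. In particular, treating "unsupervised" as "inputs present, outputs absent" is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class SimpleDataset:
    """Hypothetical stand-in for `jaxutils.Dataset` (not its real source)."""

    X: Optional[jnp.ndarray] = None  # inputs, assumed shaped [n, in_dim]
    y: Optional[jnp.ndarray] = None  # outputs, assumed shaped [n, out_dim]

    @property
    def n(self) -> int:
        # Number of datapoints = number of input rows.
        return self.X.shape[0]

    @property
    def in_dim(self) -> int:
        return self.X.shape[1]

    @property
    def out_dim(self) -> int:
        return self.y.shape[1]

    def is_supervised(self) -> bool:
        # Supervised data has both inputs and outputs.
        return self.X is not None and self.y is not None

    def is_unsupervised(self) -> bool:
        # Assumption: unsupervised means inputs without outputs.
        return self.X is not None and self.y is None


D = SimpleDataset(
    X=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    y=jnp.array([[7.0], [8.0], [9.0]]),
)
print(D.n, D.in_dim, D.out_dim, D.is_supervised(), D.is_unsupervised())
```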
Datasets can also be added together with `+`, which concatenates them row-wise.
```python
# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])
# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])
# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)
# Concatenate the two datasets
D = D + D_new
print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
```
```
The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
```
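The concatenation above stacks the second dataset's rows beneath the first. A minimal sketch of how `+` could provide this, assuming a hypothetical `ConcatDataset` class (not the actual `jaxutils.Dataset.__add__`), is to implement `__add__` with `jnp.concatenate` along axis 0 after checking that the column counts agree.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class ConcatDataset:
    """Hypothetical minimal dataset showing how `+` could concatenate rows."""

    X: jnp.ndarray  # inputs, shaped [n, in_dim]
    y: jnp.ndarray  # outputs, shaped [n, out_dim]

    def __add__(self, other: "ConcatDataset") -> "ConcatDataset":
        # Row-wise concatenation only makes sense if column counts agree.
        if self.X.shape[1] != other.X.shape[1] or self.y.shape[1] != other.y.shape[1]:
            raise ValueError("Datasets must share input and output dimensions.")
        return ConcatDataset(
            X=jnp.concatenate([self.X, other.X], axis=0),
            y=jnp.concatenate([self.y, other.y], axis=0),
        )


D = ConcatDataset(
    X=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    y=jnp.array([[7.0], [8.0], [9.0]]),
)
D_new = ConcatDataset(
    X=jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]]),
    y=jnp.array([[7.0], [8.0], [9.0]]),
)
D = D + D_new
print(D.X.shape, D.y.shape)
```

Raising on mismatched widths, rather than letting `jnp.concatenate` fail, gives a clearer error message; whether the real class does this is not stated here.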
Raw data
{
"_id": null,
"home_page": null,
"name": "jaxutils-nightly",
"maintainer": null,
"docs_url": null,
"requires_python": null,
"maintainer_email": null,
"keywords": "gaussian-processes jax machine-learning bayesian",
"author": "Daniel Dodd and Thomas Pinder",
"author_email": "tompinder@live.co.uk",
"download_url": "https://files.pythonhosted.org/packages/68/a4/f64287bd1df59950afc8aef0e802226837706c0d5bc508309914fd35c13b/jaxutils-nightly-0.0.8.dev20241222.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "LICENSE",
"summary": "Utility functions for JaxGaussianProcesses",
"version": "0.0.8.dev20241222",
"project_urls": {
"Documentation": "https://JaxUitls.readthedocs.io/en/latest/",
"Source": "https://github.com/JaxGaussianProcesses/JaxUitls"
},
"split_keywords": [
"gaussian-processes",
"jax",
"machine-learning",
"bayesian"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "2074e7bf1f5637c3f826c0e03bfb726c6092a30d45ea630e0b1bbd3a645a7fdf",
"md5": "823b8e19b4b909267d55873080704560",
"sha256": "785239bc3811fe3245aba9e022aa6a32a8733106c7c5a96d566b0edf81acf27f"
},
"downloads": -1,
"filename": "jaxutils_nightly-0.0.8.dev20241222-py3-none-any.whl",
"has_sig": false,
"md5_digest": "823b8e19b4b909267d55873080704560",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": null,
"size": 17978,
"upload_time": "2024-12-22T00:07:38",
"upload_time_iso_8601": "2024-12-22T00:07:38.151925Z",
"url": "https://files.pythonhosted.org/packages/20/74/e7bf1f5637c3f826c0e03bfb726c6092a30d45ea630e0b1bbd3a645a7fdf/jaxutils_nightly-0.0.8.dev20241222-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "68a4f64287bd1df59950afc8aef0e802226837706c0d5bc508309914fd35c13b",
"md5": "0d490c816c79cddc3b1eefb9e7789ff8",
"sha256": "289c43b7d2aee5051976f25cd53ee0cba9bf87f6ed65af118e49ddd1b94ec867"
},
"downloads": -1,
"filename": "jaxutils-nightly-0.0.8.dev20241222.tar.gz",
"has_sig": false,
"md5_digest": "0d490c816c79cddc3b1eefb9e7789ff8",
"packagetype": "sdist",
"python_version": "source",
"requires_python": null,
"size": 30164,
"upload_time": "2024-12-22T00:07:40",
"upload_time_iso_8601": "2024-12-22T00:07:40.755939Z",
"url": "https://files.pythonhosted.org/packages/68/a4/f64287bd1df59950afc8aef0e802226837706c0d5bc508309914fd35c13b/jaxutils-nightly-0.0.8.dev20241222.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-12-22 00:07:40",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "JaxGaussianProcesses",
"github_project": "JaxUitls",
"github_not_found": true,
"lcname": "jaxutils-nightly"
}