jaxutils-nightly

- Version: 0.0.8.dev20240424
- Summary: Utility functions for JaxGaussianProcesses
- Author: Daniel Dodd and Thomas Pinder
- License: LICENSE
- Keywords: gaussian-processes jax machine-learning bayesian
- Uploaded: 2024-04-24 00:06:44
- Requirements: none recorded
----

This project has now been incorporated into [GPJax](https://github.com/JaxGaussianProcesses/GPJax).

----
# [JaxUtils](https://github.com/JaxGaussianProcesses/JaxUtils)

[![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxUtils/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxUtils/tree/master)

`JaxUtils` provides utility functions for the [`JaxGaussianProcesses`](https://github.com/JaxGaussianProcesses) ecosystem.

# Contents

- [PyTree](#pytree)
- [Dataset](#dataset)

# PyTree

## Overview

`jaxutils.PyTree` is a mixin class for [registering a Python class as a JAX PyTree](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees). To make your class a PyTree, subclass it:

```python
import jaxutils

class MyClass(jaxutils.PyTree):
    ...
```
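The mechanism behind such a mixin can be sketched with JAX's own `jax.tree_util.register_pytree_node_class`, called from `__init_subclass__` so that every subclass is registered automatically. This is a minimal sketch, assuming that every instance attribute is a PyTree child; `PyTreeMixin` and `Point` are hypothetical names, and this is not necessarily how `jaxutils.PyTree` is implemented.

```python
import jax
import jax.numpy as jnp


class PyTreeMixin:
    """Hypothetical mixin that registers every subclass as a JAX PyTree.

    Instance attributes (the object's __dict__) become the PyTree children;
    the sorted attribute names are kept as static auxiliary data.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register the new subclass using the tree_flatten / tree_unflatten below.
        jax.tree_util.register_pytree_node_class(cls)

    def tree_flatten(self):
        keys = tuple(sorted(self.__dict__))
        children = tuple(self.__dict__[k] for k in keys)
        return children, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so rebuilding works whatever its signature is.
        obj = object.__new__(cls)
        for key, value in zip(aux_data, children):
            setattr(obj, key, value)
        return obj


class Point(PyTreeMixin):
    def __init__(self, x, y):
        self.x = x
        self.y = y


p = Point(jnp.array(1.0), jnp.array(2.0))
# tree_map rebuilds a Point whose leaves have been doubled.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, p)
print(doubled.x, doubled.y)
```

Storing the attribute names as auxiliary data keeps them static, so only the array values are traced by transformations such as `jax.jit`.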

## Example

```python
import jaxutils

from jaxtyping import Float, Array

class Line(jaxutils.PyTree):
    def __init__(self, gradient: Float[Array, "1"], intercept: Float[Array, "1"]) -> None:
        # Both parameters are stored as attributes of the PyTree.
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x: Float[Array, "N"]) -> Float[Array, "N"]:
        # Evaluate the line at every input point.
        return x * self.gradient + self.intercept
```
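Once a class is registered as a PyTree, its instances can be passed straight through JAX transformations. The sketch below registers a `Line` directly with `jax.tree_util.register_pytree_node_class` (rather than through `jaxutils.PyTree`, so it runs without the package) and differentiates a hypothetical mean-squared-error `loss` with respect to the line's parameters; `jax.grad` returns the gradients as another `Line`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Line:
    def __init__(self, gradient, intercept):
        self.gradient = gradient
        self.intercept = intercept

    def y(self, x):
        return x * self.gradient + self.intercept

    def tree_flatten(self):
        # Both parameters are array leaves; there is no static metadata.
        return (self.gradient, self.intercept), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(line, x, y):
    # Hypothetical objective: mean squared error of the line's predictions.
    return jnp.mean((line.y(x) - y) ** 2)


line = Line(jnp.array(2.0), jnp.array(1.0))
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([1.0, 3.0, 5.0])

# Gradient with respect to the first argument (the whole Line), compiled with jit.
grads = jax.jit(jax.grad(loss))(line, x, y)
print(grads.gradient, grads.intercept)
```

Because these data lie exactly on the line 2x + 1, both gradient fields are zero here.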

# Dataset

## Overview

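`jaxutils.Dataset` is a dataset abstraction that pairs inputs `X` with outputs `y`. In the future, we wish to extend it to heterotopic and isotopic data (in multi-output settings, isotopic data observe every output at the same inputs, whereas heterotopic data may observe different outputs at different inputs).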

## Example

```python
import jaxutils
import jax.numpy as jnp

# Inputs
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Outputs
y = jnp.array([[7.0], [8.0], [9.0]])

# Dataset
D = jaxutils.Dataset(X=X, y=y)

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
```

```
The number of datapoints is 3
The input dimension is 2
The output dimension is 1
The input data is [[1. 2.]
 [3. 4.]
 [5. 6.]]
The output data is [[7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
```
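The attributes printed above can all be read off the array shapes. A minimal sketch, assuming inputs of shape `(n, in_dim)` and outputs of shape `(n, out_dim)`; `SimpleDataset` is a hypothetical stand-in rather than the `jaxutils.Dataset` source, and treating "unsupervised" as outputs without inputs is an assumption.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class SimpleDataset:
    """Hypothetical dataset: X has shape (n, in_dim), y has shape (n, out_dim)."""

    X: Optional[jax.Array] = None
    y: Optional[jax.Array] = None

    def __post_init__(self):
        # Inputs and outputs must describe the same datapoints.
        if self.X is not None and self.y is not None and self.X.shape[0] != self.y.shape[0]:
            raise ValueError("X and y must have the same number of rows")

    @property
    def n(self) -> int:
        # Number of datapoints is the number of rows.
        return (self.X if self.X is not None else self.y).shape[0]

    @property
    def in_dim(self) -> int:
        return self.X.shape[1]

    @property
    def out_dim(self) -> int:
        return self.y.shape[1]

    def is_supervised(self) -> bool:
        return self.X is not None and self.y is not None

    def is_unsupervised(self) -> bool:
        # Assumption: outputs observed without any inputs.
        return self.X is None and self.y is not None


D = SimpleDataset(
    X=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    y=jnp.array([[7.0], [8.0], [9.0]]),
)
print(D.n, D.in_dim, D.out_dim)  # 3 2 1
```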

You can also add datasets together with `+` to concatenate them; the points of the second dataset are appended after those of the first.

```python
# New inputs
X_new = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]])

# New outputs
y_new = jnp.array([[7.0], [8.0], [9.0]])

# New dataset
D_new = jaxutils.Dataset(X=X_new, y=y_new)

# Concatenate the two datasets
D = D + D_new

print(f'The number of datapoints is {D.n}')
print(f'The input dimension is {D.in_dim}')
print(f'The output dimension is {D.out_dim}')
print(f'The input data is {D.X}')
print(f'The output data is {D.y}')
print(f'The data is supervised {D.is_supervised()}')
print(f'The data is unsupervised {D.is_unsupervised()}')
```

```
The number of datapoints is 6
The input dimension is 2
The output dimension is 1
The input data is [[1.  2. ]
 [3.  4. ]
 [5.  6. ]
 [1.5 2.5]
 [3.5 4.5]
 [5.5 6.5]]
The output data is [[7.]
 [8.]
 [9.]
 [7.]
 [8.]
 [9.]]
The data is supervised True
The data is unsupervised False
```
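Concatenation of this kind amounts to stacking rows with `jnp.concatenate` along axis 0. A minimal sketch, assuming both datasets share input and output dimensions; `Pairs` is a hypothetical class name, and raising `ValueError` on mismatched dimensions is an assumption rather than documented `jaxutils` behaviour.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Pairs:
    """Hypothetical minimal dataset: inputs X (n, d_in) and outputs y (n, d_out)."""

    X: jax.Array
    y: jax.Array

    def __add__(self, other: "Pairs") -> "Pairs":
        # Only datasets with matching column dimensions can be stacked.
        if self.X.shape[1] != other.X.shape[1] or self.y.shape[1] != other.y.shape[1]:
            raise ValueError("datasets must share input and output dimensions")
        # Append the other dataset's rows after this one's.
        return Pairs(
            X=jnp.concatenate([self.X, other.X], axis=0),
            y=jnp.concatenate([self.y, other.y], axis=0),
        )


A = Pairs(X=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), y=jnp.array([[7.0], [8.0], [9.0]]))
B = Pairs(X=jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]]), y=jnp.array([[7.0], [8.0], [9.0]]))
C = A + B
print(C.X.shape, C.y.shape)  # (6, 2) (6, 1)
```

Returning a new object instead of mutating `A` fits JAX's immutable-array style.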
