----
This project has now been incorporated into [GPJax](https://github.com/JaxGaussianProcesses/GPJax).
----
<p align="center">
<img width="700" height="300" src="https://raw.githubusercontent.com/JaxGaussianProcesses/JaxKern/main/docs/_static/logo/logo.png" alt="JaxKern's logo">
</p>
<h2 align='center'>Kernels in Jax.</h2>
[![codecov](https://codecov.io/gh/JaxGaussianProcesses/JaxKern/branch/main/graph/badge.svg?token=8WD7YYMPFS)](https://codecov.io/gh/JaxGaussianProcesses/JaxKern)
[![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxKern/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxKern/tree/main)
[![Documentation Status](https://readthedocs.org/projects/gpjax/badge/?version=latest)](https://gpjax.readthedocs.io/en/latest/?badge=latest)
[![PyPI version](https://badge.fury.io/py/jaxkern.svg)](https://badge.fury.io/py/jaxkern)
[![Downloads](https://pepy.tech/badge/jaxkern)](https://pepy.tech/project/jaxkern)
[![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw)
## Introduction
JaxKern is a Python library for working with kernel functions in JAX. We currently support the following kernels:
* Stationary
  * Radial basis function (squared exponential)
  * Matérn
  * Powered exponential
  * Rational quadratic
  * White noise
  * Periodic
* Non-stationary
  * Linear
  * Polynomial
* Non-Euclidean
  * Graph kernels
In addition, we implement kernel approximations using the [Random Fourier feature](https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) approach.
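To give a flavour of the random Fourier feature idea, here is a minimal, self-contained sketch (independent of JaxKern's own `RFF` class) that approximates an RBF kernel: frequencies are drawn from the kernel's spectral density, and the Gram matrix is approximated by an inner product of low-dimensional feature maps.

```python
import jax.numpy as jnp
import jax.random as jr


def rff_features(key, X, num_basis_fns, lengthscale=1.0):
    """Random Fourier features approximating an RBF kernel."""
    d = X.shape[1]
    k1, k2 = jr.split(key)
    # Frequencies sampled from the RBF kernel's spectral density.
    omega = jr.normal(k1, (d, num_basis_fns)) / lengthscale
    # Random phase shifts in [0, 2*pi).
    b = jr.uniform(k2, (num_basis_fns,), maxval=2.0 * jnp.pi)
    return jnp.sqrt(2.0 / num_basis_fns) * jnp.cos(X @ omega + b)


key = jr.PRNGKey(0)
X = jr.uniform(key, (10, 1), minval=-3.0, maxval=3.0)
phi = rff_features(key, X, num_basis_fns=500)
# Low-rank approximation to the Gram matrix exp(-||x - y||^2 / (2 * lengthscale^2)).
K_approx = phi @ phi.T
```

With enough basis functions the approximation concentrates around the exact kernel; JaxKern's `RFF` class (shown in the example below) wraps this construction for its supported stationary kernels.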
## Example
The following code snippet demonstrates how the Matérn-3/2 kernel can be computed and, subsequently, approximated using random Fourier features.
```python
import jaxkern as jk
import jax.random as jr

key = jr.PRNGKey(123)

# Define the points on which we'll evaluate the kernel
X = jr.uniform(key, shape=(10, 1), minval=-3.0, maxval=3.0)
Y = jr.uniform(key, shape=(20, 1), minval=-3.0, maxval=3.0)

# Instantiate the kernel and initialise its parameters
kernel = jk.Matern32()
params = kernel.init_params(key)

# Compute the 10x10 Gram matrix
Kxx = kernel.gram(params, X)

# Compute the 10x20 cross-covariance matrix
Kxy = kernel.cross_covariance(params, X, Y)

# Build an RFF approximation with 5 basis functions
approx = jk.RFF(kernel, num_basis_fns=5)
rff_params = approx.init_params(key)

# Build an approximation to the Gram matrix
Qff = approx.gram(rff_params, X)
```
## Code Structure
All kernels supply a `gram` and a `cross_covariance` method. When computing a Gram matrix, there is often structure in the data (e.g., Markov structure) that can be exploited to yield a sparse matrix. To let JAX exploit this structure, `gram` returns a linear operator from [JaxLinOp](https://github.com/JaxGaussianProcesses/JaxLinOp) rather than a dense array.
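JaxLinOp's actual operator classes differ, but the benefit of returning a structured operator instead of a dense array can be sketched with a minimal, hypothetical diagonal operator (the class below is illustrative only, not JaxLinOp's API):

```python
import jax.numpy as jnp


class DiagonalLinearOperator:
    """Illustrative sketch: store only the diagonal of an n x n matrix."""

    def __init__(self, diag):
        self.diag = diag  # O(n) storage instead of O(n^2)

    def to_dense(self):
        # Materialise the full matrix only when explicitly requested.
        return jnp.diag(self.diag)

    def __matmul__(self, v):
        # Matrix-vector product in O(n) rather than O(n^2).
        return self.diag * v


# A white-noise kernel's Gram matrix, for instance, is diagonal.
op = DiagonalLinearOperator(jnp.full(4, 0.5))
v = jnp.ones(4)
result = op @ v          # cheap matvec, no dense matrix built
dense = op.to_dense()    # dense form available on demand
```

Downstream code (e.g., solving linear systems in GP inference) can dispatch on the operator type and use the cheap structured routines, falling back to the dense form only when necessary.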
Within [GPJax](https://github.com/JaxGaussianProcesses/GPJax), all kernel computations are handled using JaxKern.
## Documentation
A full set of documentation is a work in progress. However, many of the details in JaxKern can be found in the [GPJax documentation](https://gpjax.readthedocs.io/en/latest/).