| Field | Value |
| --- | --- |
| Name | numba4jax |
| Version | 0.0.14 |
| Home page | None |
| Summary | Use numba in jax-compiled kernels. |
| Upload time | 2024-08-05 11:25:36 |
| Maintainer | None |
| Docs URL | None |
| Author | None |
| Requires Python | >=3.10 |
| License | MIT |
| Keywords | jax, numba, compile, jit, kernel |
| Requirements | No requirements were recorded. |
# numba4jax
A small, experimental Python package that lets you call numba-jitted functions from within jax
with no overhead.
This package uses Numba's C-callback (CFFI) support to expose the C function pointer of your compiled
function to XLA. It works for both CPU and GPU functions.
This package exports a single decorator, `@njit4jax`, which takes one argument describing the output
of the decorated function: either a function that computes the output shape from the inputs, or a
tuple describing the output shape.
See the brief example below.
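```python
import jax
import jax.numpy as jnp
from numba4jax import ShapedArray, njit4jax

# The shape function receives the abstract values of the inputs and returns
# the abstract value of the output: here, same shape and dtype as the first input.
def compute_type(*x):
    return x[0]

# The kernel receives one tuple holding the output buffer first, then the
# inputs (here two), and writes its result into the output in place.
@njit4jax(compute_type)
def test(args):
    y, x, x2 = args
    y[:] = x[:] + 1

z = jnp.ones((1, 2), dtype=float)
# The kernel can be traced like any other jax function.
print(jax.make_jaxpr(test)(z, z))

print("output: ", test(z, z))
print("output: ", jax.jit(test)(z, z))

# New input shapes make jax.jit retrace, producing a new output shape.
z = jnp.ones((2, 3), dtype=float)
print("output: ", jax.jit(test)(z, z))

z = jnp.ones((1, 3, 1), dtype=float)
print("output: ", jax.jit(test)(z, z))
```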
## Backend support
This package supports both the CPU and GPU backends of jax.
The GPU backend is supported only on Linux and is highly experimental.
It requires CUDA to be installed in a standard path.
CUDA is located through `numba.cuda`, so you should first check that `numba.cuda`
works.
Raw data
{
"_id": null,
"home_page": null,
"name": "numba4jax",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.10",
"maintainer_email": null,
"keywords": "Jax, Numba, compile, jit, kernel",
"author": null,
"author_email": "Filippo Vicentini <filippovicentini@gmail.com>",
"download_url": "https://files.pythonhosted.org/packages/e2/a4/f97a263f88bcd6aed229214ee6508d44ae5dbb22bcd748d071fda9c3a54c/numba4jax-0.0.14.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "Usa numba in jax-compiled kernels.",
"version": "0.0.14",
"project_urls": {
"homepage": "http://github.com/PhilipVinc/numba4jax",
"repository": "http://github.com/PhilipVinc/numba4jax"
},
"split_keywords": [
"jax",
" numba",
" compile",
" jit",
" kernel"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "49ded1a5b8df5efaeed0cafd79a9b32f893291274034761ef523de3452e2b123",
"md5": "06a20dab06cffc8d8b6d64857e2338ac",
"sha256": "cd4a23b5e25a3a4fc5e9adb21ca06cb1cccaf07a31e0dfa979619bf0447d33c2"
},
"downloads": -1,
"filename": "numba4jax-0.0.14-py3-none-any.whl",
"has_sig": false,
"md5_digest": "06a20dab06cffc8d8b6d64857e2338ac",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.10",
"size": 16218,
"upload_time": "2024-08-05T11:25:35",
"upload_time_iso_8601": "2024-08-05T11:25:35.409952Z",
"url": "https://files.pythonhosted.org/packages/49/de/d1a5b8df5efaeed0cafd79a9b32f893291274034761ef523de3452e2b123/numba4jax-0.0.14-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "e2a4f97a263f88bcd6aed229214ee6508d44ae5dbb22bcd748d071fda9c3a54c",
"md5": "9de84a067a64a4447dfd89c7d2a67d47",
"sha256": "a16911c3d3d1ac72cd6d9fdd003c285b4b86fe365ca072b8187c228c5011630f"
},
"downloads": -1,
"filename": "numba4jax-0.0.14.tar.gz",
"has_sig": false,
"md5_digest": "9de84a067a64a4447dfd89c7d2a67d47",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.10",
"size": 12030,
"upload_time": "2024-08-05T11:25:36",
"upload_time_iso_8601": "2024-08-05T11:25:36.857935Z",
"url": "https://files.pythonhosted.org/packages/e2/a4/f97a263f88bcd6aed229214ee6508d44ae5dbb22bcd748d071fda9c3a54c/numba4jax-0.0.14.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-08-05 11:25:36",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "PhilipVinc",
"github_project": "numba4jax",
"travis_ci": false,
"coveralls": true,
"github_actions": true,
"lcname": "numba4jax"
}