# numba4jax

- Version: 0.0.12
- Summary: Use numba in jax-compiled kernels.
- Uploaded: 2024-02-07 23:23:43
- Requires Python: >=3.9
- License: MIT
- Keywords: jax, numba, compile, jit, kernel



A small, experimental Python package that lets you call numba-jitted functions from within JAX
with no overhead.

It uses Numba's CFFI interface to expose the C function pointer of your compiled
function to XLA, and it works for both CPU and GPU functions.

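This package exports a single decorator, `@njit4jax`, which takes one argument: either a function
or a tuple describing the output shape of the decorated function.
In the brief example below, `compute_type` receives the inputs and returns a value with the shape
and dtype of the output (here, the same as the first input), while the kernel receives a single
tuple holding the output buffer first, followed by the inputs, and fills the output in place.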

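```python

import jax
import jax.numpy as jnp

from numba4jax import njit4jax


def compute_type(*x):
    # Receives the inputs (as abstract shape/dtype values) and returns the
    # description of the output: here, the same shape and dtype as the first input.
    return x[0]


@njit4jax(compute_type)
def test(args):
    # The kernel receives one tuple: output buffer(s) first, then the inputs.
    y, x, x2 = args  # x2 is the second input; unused here
    # Write the result into the output buffer in place.
    y[:] = x[:] + 1


z = jnp.ones((1, 2), dtype=float)

# Inspect the jaxpr that jax traces for the call.
print(jax.make_jaxpr(test)(z, z))

# Call the function eagerly and under jax.jit.
print("output: ", test(z, z))
print("output: ", jax.jit(test)(z, z))

# New input shapes trigger a new trace; compute_type derives the new output shape.
z = jnp.ones((2, 3), dtype=float)
print("output: ", jax.jit(test)(z, z))

z = jnp.ones((1, 3, 1), dtype=float)
print("output: ", jax.jit(test)(z, z))

```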
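To make the calling convention of the example concrete, here is a pure-NumPy emulation that runs the kernel eagerly in Python instead of through XLA. `emulate_njit4jax` is a hypothetical helper, not part of numba4jax. It assumes, as the example suggests, that the output description is computed from the inputs and that the kernel receives preallocated outputs followed by the inputs; it emulates only the function form of the argument, and treating a returned tuple as several outputs is an additional assumption.

```python
import numpy as np


def emulate_njit4jax(out_spec):
    """Hypothetical stand-in for @njit4jax that runs the kernel eagerly with NumPy.

    out_spec receives the inputs and returns a template array (or a tuple of
    templates) whose shape and dtype describe the output(s).
    """
    def decorator(kernel):
        def wrapper(*inputs):
            inputs = tuple(np.asarray(x) for x in inputs)
            spec = out_spec(*inputs)
            specs = spec if isinstance(spec, tuple) else (spec,)
            # Preallocate one buffer per output; the kernel fills them in place.
            outputs = tuple(np.empty(s.shape, dtype=s.dtype) for s in specs)
            kernel(outputs + inputs)  # outputs first, then inputs
            return outputs if isinstance(spec, tuple) else outputs[0]
        return wrapper
    return decorator


def compute_type(*x):
    # The output has the same shape and dtype as the first input.
    return x[0]


@emulate_njit4jax(compute_type)
def add_one(args):
    y, x, _x2 = args
    y[:] = x[:] + 1


z = np.ones((2, 3))
print(add_one(z, z))
```

Preallocating the outputs outside the kernel mirrors why the real kernel needs an output-shape description at all: the buffers must exist before the compiled code runs.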

## Backend support

This package supports both the CPU and GPU backends of JAX.
The GPU backend is supported only on Linux and is highly experimental.
It requires CUDA to be installed in a standard path.
CUDA is located through `numba.cuda`, so first check that `numba.cuda`
works on your machine.

            

## Package metadata

- Author: Filippo Vicentini <filippovicentini@gmail.com>
- Homepage and repository: http://github.com/PhilipVinc/numba4jax
- Distribution files (uploaded 2024-02-07):
  - `numba4jax-0.0.12-py3-none-any.whl` (wheel, 16224 bytes), SHA256 `84852e16cc51c16be6fae32c92d0541838d16318cd3d3a85894f96f59b18473e`
  - `numba4jax-0.0.12.tar.gz` (sdist, 12033 bytes), SHA256 `e1faf6a0566f4fb941abf8821b9c854b7398eb08a0c8157927f8b4717a393446`
        