FiniteDiffX

- Name: FiniteDiffX
- Version: 0.0.5
- Home page: https://github.com/ASEM000/FiniteDiffX
- Summary: Finite difference tools in JAX.
- Upload time: 2023-06-08 12:45:53
- Author: Mahmoud Asem
- Requires Python: >=3.8
- License: MIT
- Keywords: python, jax
            <h5 align="center">
<img width="200px" src="assets/finitediffx_logo.svg"> <br>

<br>

[**Installation**](#installation)
|[**Examples**](#examples)

![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/tests.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11_-red)
![pyver](https://img.shields.io/badge/jax-0.4+-red)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Downloads](https://pepy.tech/badge/FiniteDiffX)](https://pepy.tech/project/FiniteDiffX)
[![codecov](https://codecov.io/github/ASEM000/FiniteDiffX/branch/main/graph/badge.svg?token=VD45Y4HLWV)](https://codecov.io/github/ASEM000/FiniteDiffX)  
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ASEM000/FiniteDiffX/blob/main/FiniteDiffX%20Examples.ipynb)
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/ASEM000/FiniteDiffX)
![PyPI](https://img.shields.io/pypi/v/FiniteDiffX)

</h5>

Differentiable finite difference tools in `jax`.

Implements:

**`Array` accepting functions:**

- `difference(array, axis, accuracy, step_size, method, derivative)`
- `gradient(array, accuracy, method, step_size)`
- `jacobian(array, accuracy, method, step_size)`
- `divergence(array, accuracy, step_size, method, keepdims)`
- `hessian(array, accuracy, method, step_size)`
- `laplacian(array, accuracy, method, step_size)`
- `curl(array, step_size, method, keep_dims)`

**Function transformation:**

- `fgrad` and `value_and_fgrad`: similar to `jax.grad` and `jax.value_and_grad`, but using finite difference approximations.
- `define_fdjvp`: define `custom_jvp` rules using finite difference approximation (see example below).

## 🛠️ Installation<a id="installation"></a>

```bash
pip install FiniteDiffX
```

**Install development version**

```bash
pip install git+https://github.com/ASEM000/FiniteDiffX
```

**If you find it useful, consider giving it a star! 🌟**

<br>

## ⏩ Examples<a id="examples"></a>

### **`Array` accepting functions:**

```python

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy.testing as npt

import finitediffx as fdx

# let's first define a vector-valued function F: R^3 -> R^3
# F = (F1, F2, F3)
# F1 = x^2 + y^3
# F2 = x^4 + y^3
# F3 = 0
# F = [x**2 + y**3, x**4 + y**3, 0]

x, y, z = [jnp.linspace(0, 1, 100)] * 3
dx, dy, dz = x[1] - x[0], y[1] - y[0], z[1] - z[0]
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")
F1 = X**2 + Y**3
F2 = X**4 + Y**3
F3 = jnp.zeros_like(F1)
F = jnp.stack([F1, F2, F3], axis=0)

```

<details>

<summary>Difference</summary>

```python

# ∂F1/∂x : differentiate F1 with respect to x (i.e. axis=0)
dF1dx = fdx.difference(F1, axis=0, step_size=dx, accuracy=6, method="central")
dF1dx_exact = 2 * X
npt.assert_allclose(dF1dx, dF1dx_exact, atol=1e-7)

# ∂F2/∂y : differentiate F2 with respect to y (i.e. axis=1)
dF2dy = fdx.difference(F2, axis=1, step_size=dy, accuracy=6)
dF2dy_exact = 3 * Y**2
npt.assert_allclose(dF2dy, dF2dy_exact, atol=1e-7)

```

</details>
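
<details>

<summary>Higher-order difference (sketch)</summary>

The `difference` signature above also lists a `derivative` argument. The snippet below is a minimal sketch, assuming `derivative=2` selects the order of the derivative along the chosen axis; that semantics is an assumption here, not taken from the examples above.

```python

# ∂²F1/∂x² : second derivative of F1 with respect to x
# (derivative=2 is assumed to select the derivative order)
d2F1dx2 = fdx.difference(F1, axis=0, step_size=dx, accuracy=6, derivative=2)
d2F1dx2_exact = 2 * jnp.ones_like(X)
npt.assert_allclose(d2F1dx2, d2F1dx2_exact, atol=1e-6)

```

</details>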

<details>

<summary>Divergence</summary>

```python

# ∇.F : the divergence of F
divF = fdx.divergence(F, step_size=(dx, dy, dz), keepdims=False, accuracy=6, method="central")
divF_exact = 2 * X + 3 * Y**2
npt.assert_allclose(divF, divF_exact, atol=1e-7)

```

</details>

<details>

<summary>Gradient</summary>

```python

# ∇F1 : the gradient of F1
gradF1 = fdx.gradient(F1, step_size=(dx, dy, dz), accuracy=6, method="central")
gradF1_exact = jnp.stack([2 * X, 3 * Y**2, 0 * X], axis=0)
npt.assert_allclose(gradF1, gradF1_exact, atol=1e-7)

```

</details>

<details>

<summary>Laplacian</summary>

```python

# ΔF1 : laplacian of F1
lapF1 = fdx.laplacian(F1, step_size=(dx, dy, dz), accuracy=6, method="central")
lapF1_exact = 2 + 6 * Y
npt.assert_allclose(lapF1, lapF1_exact, atol=1e-7)

```

</details>

<details>

<summary>Curl</summary>

```python

# ∇xF : the curl of F
curlF = fdx.curl(F, step_size=(dx, dy, dz), accuracy=6, method="central")
curlF_exact = jnp.stack([F1 * 0, F1 * 0, 4 * X**3 - 3 * Y**2], axis=0)
npt.assert_allclose(curlF, curlF_exact, atol=1e-7)

```

</details>

<details>

<summary>Jacobian</summary>

```python

# Jacobian of F
JF = fdx.jacobian(F, accuracy=4, step_size=(dx, dy, dz), method="central")
JF_exact = jnp.array(
    [
        [2 * X, 3 * Y**2, jnp.zeros_like(X)],
        [4 * X**3, 3 * Y**2, jnp.zeros_like(X)],
        [jnp.zeros_like(X), jnp.zeros_like(X), jnp.zeros_like(X)],
    ]
)
npt.assert_allclose(JF, JF_exact, atol=1e-7)

```

</details>

<details>

<summary>Hessian</summary>

```python

# Hessian of F1
HF1 = fdx.hessian(F1, accuracy=4, step_size=(dx, dy, dz), method="central")
HF1_exact = jnp.array(
    [
        [
            2 * jnp.ones_like(X),  # ∂2F1/∂x2
            0 * jnp.ones_like(X),  # ∂2F1/∂xy
            0 * jnp.ones_like(X),  # ∂2F1/∂xz
        ],
        [
            0 * jnp.ones_like(X),  # ∂2F1/∂yx
            6 * Y,                 # ∂2F1/∂y2
            0 * jnp.ones_like(X),  # ∂2F1/∂yz
        ],
        [
            0 * jnp.ones_like(X),  # ∂2F1/∂zx
            0 * jnp.ones_like(X),  # ∂2F1/∂zy
            0 * jnp.ones_like(X),  # ∂2F1/∂z2
        ],
    ]
)
npt.assert_allclose(HF1, HF1_exact, atol=1e-7)

```

</details>
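
Because these routines are implemented with `jax` operations, they compose with `jax` transformations such as `jax.jit`. A minimal sketch, reusing `F1`, `Y`, and the step sizes defined above (the composability shown here is inferred from the library being written in `jax`, not from a documented guarantee):

```python

import jax

# jit-compile a laplacian computation; F1, Y, dx, dy, dz come from the setup above
lap_fn = jax.jit(lambda a: fdx.laplacian(a, step_size=(dx, dy, dz), accuracy=6, method="central"))
npt.assert_allclose(lap_fn(F1), 2 + 6 * Y, atol=1e-7)

```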

<br><br>

### **Function transformation:**

**`fgrad`**:

`fgrad` can be used in a similar way to `jax.grad`; however, `fgrad` differentiates a function using finite difference rules.

<details> <summary> Example </summary>

```python

import jax
from jax import numpy as jnp
import numpy as onp  # Not jax-traceable
import finitediffx as fdx
import functools as ft
from jax.experimental import enable_x64

with enable_x64():

    @fdx.fgrad
    @fdx.fgrad
    def np_rosenbrock2_fdx_style_1(x, y):
        """Compute the Rosenbrock function of two variables in numpy."""
        return onp.power(1 - x, 2) + 100 * onp.power(y - onp.power(x, 2), 2)

    @ft.partial(fdx.fgrad, derivative=2)
    def np_rosenbrock2_fdx_style_2(x, y):
        """Compute the Rosenbrock function of two variables."""
        return onp.power(1 - x, 2) + 100 * onp.power(y - onp.power(x, 2), 2)

    @jax.grad
    @jax.grad
    def jnp_rosenbrock2(x, y):
        """Compute the Rosenbrock function of two variables."""
        return jnp.power(1 - x, 2) + 100 * jnp.power(y - jnp.power(x, 2), 2)

    print(np_rosenbrock2_fdx_style_1(1.0, 2.0))
    print(np_rosenbrock2_fdx_style_2(1.0, 2.0))
    print(jnp_rosenbrock2(1.0, 2.0))
# 402.0000951997936
# 402.0000000002219
# 402.0
```

`fgrad` also works on pytrees:

```python

import finitediffx as fdx

params = {"a":1., "b":2., "c":3.}

@fdx.fgrad
def func(params):
    return params["a"]**2+params["b"]

func(params)
# {'a': Array(1.9995117, dtype=float32),
#  'b': Array(0.9995117, dtype=float32),
#  'c': Array(0., dtype=float32)}

```
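
`value_and_fgrad` (listed above) is described as the finite-difference counterpart of `jax.value_and_grad`; the snippet below is a minimal sketch, assuming it follows the same `(value, gradient)` return convention:

```python

# assuming value_and_fgrad mirrors jax.value_and_grad's calling and return convention
def loss(params):
    return params["a"]**2 + params["b"]

value, grads = fdx.value_and_fgrad(loss)(params)
# value -> 3.0 (= 1.0**2 + 2.0)
# grads -> pytree of finite-difference gradients with the same structure as `params`

```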

</details>

<br>

**`define_fdjvp`**

`define_fdjvp` combines `custom_jvp` and `fgrad` to define custom finite difference rules; when used with `pure_callback`, it can make non-traceable code work within the `jax` machinery.

<details> <summary> Example </summary>

_This example is based on a comment in the proposed `jax` [`JEP`](https://github.com/google/jax/issues/15425)._

For example, the following code fails under `jax` transformations because it uses `numpy` functions.

```python
import numpy as onp
import jax


def numpy_func(x: onp.ndarray) -> onp.ndarray:
    return onp.power(x, 2)


try:
    jax.grad(numpy_func)(2.0)
except jax.errors.TracerArrayConversionError as e:
    print(e)

# The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
#   primal = 2.0
#   tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
#     pval = (ShapedArray(float32[], weak_type=True), None)
#     recipe = LambdaBinding()
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```

We can use `define_fdjvp` to make this work with non-`jax` code:

```python

import functools as ft

import jax
import jax.numpy as jnp
import numpy as onp

import finitediffx as fdx


def wrap_pure_callback(func):
    @ft.wraps(func)
    def wrapper(*args, **kwargs):
        args = [jnp.asarray(arg) for arg in args]
        func_ = lambda *a, **k: func(*a, **k).astype(a[0].dtype)
        dtype_ = jax.ShapeDtypeStruct(
            jnp.broadcast_shapes(*[ai.shape for ai in args]),
            args[0].dtype,
        )
        return jax.pure_callback(func_, dtype_, *args, **kwargs, vectorized=True)

    return wrapper


@jax.jit  # -> can compile
@jax.grad  # -> can take gradient
@ft.partial(
    fdx.define_fdjvp,
    # automatically generate offsets
    offsets=fdx.Offset(accuracy=4),
    # manually set step size
    step_size=1e-3,
)
@wrap_pure_callback
def numpy_func(x: onp.ndarray) -> onp.ndarray:
    return onp.power(x, 2)


print(numpy_func(1.0))
# 1.9999794

@jax.jit  # -> can compile
@jax.grad  # -> can take gradient
@ft.partial(
    fdx.define_fdjvp,
    # provide the desired evaluation points for the finite difference stencil
    # in this case it's a centered finite difference: (f(x + step_size) - f(x - step_size)) / (2 * step_size)
    offsets=jnp.array([1, -1]),
    # manually set step size
    step_size=1e-3,
)
@wrap_pure_callback
def numpy_func(x: onp.ndarray) -> onp.ndarray:
    return onp.power(x, 2)

print(numpy_func(1.0))
# 2.0000048
```

</details>

            
