<div align = "center">
<img width=400 src="assets/kernexlogo.svg" align="center">
<h3 align="center">Differentiable Stencil computations in JAX</h3>
[**Installation**](#Installation)
|[**Description**](#Description)
|[**Quick example**](#QuickExample)
|[**Function mesh**](#FunctionMesh)
|[**More Examples**](#MoreExamples)
|[**Benchmarking**](#Benchmarking)
![Tests](https://github.com/ASEM000/kernex/actions/workflows/tests.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.8%203.8%203.9%203.11-red)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Downloads](https://pepy.tech/badge/kernex)](https://pepy.tech/project/kernex)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14UEqKzIyZsDzQ9IMeanvztXxbbbatTYV?usp=sharing)
[![codecov](https://codecov.io/gh/ASEM000/kernex/branch/main/graph/badge.svg?token=3KLL24Z94I)](https://codecov.io/gh/ASEM000/kernex)
[![DOI](https://zenodo.org/badge/512400616.svg)](https://zenodo.org/badge/latestdoi/512400616)
</div>
## 🛠️ Installation<a id="Installation"></a>
```bash
pip install kernex
```
## 📖 Description<a id="Description"></a>
Kernex extends `jax.vmap`/`jax.lax.map`/`jax.pmap` with `kmap` and `jax.lax.scan` with `kscan` for general stencil computations.
The primary motivation for this package is to blend the process of solving PDEs into a neural network (NN) setting.
## ⏩ Quick Example <a id="QuickExample"></a>
<div align="center">
<table>
<tr>
<td width="50%" align="center" > kmap </td> <td align="center" > kscan </td>
</tr>
<tr>
<td>
```python
import kernex as kex
import jax.numpy as jnp

@kex.kmap(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6  9 12]
```
</td>
<td>
```python
import kernex as kex
import jax.numpy as jnp

@kex.kscan(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6 13 22]
```
</td>
</tr>
</table>
<table>
<tr>
<td width="50%">
`jax.vmap` is used to sum the contents of each window.
<img src="assets/kmap_sum.png" width=400px>
</td>
<td>
`lax.scan` updates the array in place, so the window sums are computed sequentially and each window sees the results of the previous steps.
The first three rows show the three sequential steps used to obtain the solution in the last row.
<img align="center" src="assets/kscan_sum.png" width=400px>
</td>
</tr>
</table>
</div>
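To make the difference between the two transforms concrete, the 1D window sum above can be written directly in plain JAX. This is a minimal sketch, not kernex's implementation: the helper names `kmap_sum_1d` and `kscan_sum_1d` are hypothetical, only `'valid'` padding is handled, and the `kscan` version assumes (as the example output suggests) that each window result is written back at the window centre before the next window is read.

```python
import jax
import jax.numpy as jnp

def kmap_sum_1d(x, k=3):
    # hypothetical helper mimicking kmap with 'valid' padding:
    # gather every length-k window, then apply the kernel to each with jax.vmap
    n_out = x.shape[0] - k + 1
    starts = jnp.arange(n_out)
    windows = jax.vmap(lambda i: jax.lax.dynamic_slice(x, (i,), (k,)))(starts)
    return jax.vmap(jnp.sum)(windows)

def kscan_sum_1d(x, k=3):
    # hypothetical helper mimicking kscan: each window sum is written back
    # into the array at the window centre, so later windows see earlier results
    n_out = x.shape[0] - k + 1

    def step(arr, i):
        s = jnp.sum(jax.lax.dynamic_slice(arr, (i,), (k,)))
        arr = arr.at[i + k // 2].set(s)
        return arr, s

    _, out = jax.lax.scan(step, x, jnp.arange(n_out))
    return out

x = jnp.array([1, 2, 3, 4, 5])
print(kmap_sum_1d(x))   # values 6, 9, 12
print(kscan_sum_1d(x))  # values 6, 13, 22
```

The only difference between the two helpers is whether the array is carried and updated between windows, which is why `kmap` can run all windows in parallel while `kscan` must run them in order.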
## 🕸️ Function mesh concept <a id="FunctionMesh"></a>
<details>
The objective is to apply `f(x) = x^2` at index `0` and `f(x) = x^3` at indices `1` through `9` of a 10-element array.
Implementing this with `jax.lax.switch` would require a list of 10 functions, one for each cell of the example array.
For this reason, kernex adopts a modified version of `jax.lax.switch` that reduces the number of branches required.
```
# f applies x^2 to the first cell and x^3 to the remaining cells

        ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
   f =  │ x^2 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │
        └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘

        ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
   f(   │  1  │  2  │  3  │  4  │  5  │  6  │  7  │  8  │  9  │ 10  │ ) =
        └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
        ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
        │  1  │  8  │ 27  │ 64  │ 125 │ 216 │ 343 │ 512 │ 729 │1000 │
        └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘

# Gradient of this function
        ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
df/dx = │ 2x  │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │
        └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘

        ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
df/dx(  │  1  │  2  │  3  │  4  │  5  │  6  │  7  │  8  │  9  │ 10  │ ) =
        └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
        ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
        │  2  │ 12  │ 27  │ 48  │ 75  │ 108 │ 147 │ 192 │ 243 │ 300 │
        └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
```
<div align ="center">
<table>
<tr>
<td> Function mesh </td> <td> Array equivalent </td>
</tr>
<tr>
<td>
```python
import jax
import jax.numpy as jnp
import kernex as kex

F = kex.kmap(kernel_size=(1,))
F[0] = lambda x: x[0]**2   # first cell
F[1:] = lambda x: x[0]**3  # remaining cells

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125.,
#  216., 343., 512., 729., 1000.]
print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75.,
#  108., 147., 192., 243., 300.]
```
</td>
<td>
```python
import jax
import jax.numpy as jnp

def F(x):
    f1 = lambda x: x**2
    f2 = lambda x: x**3
    x = x.at[0].set(f1(x[0]))    # first cell
    x = x.at[1:].set(f2(x[1:]))  # remaining cells
    return x

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125.,
#  216., 343., 512., 729., 1000.]
print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75.,
#  108., 147., 192., 243., 300.]
```
</td>
</tr>
</table>
The function mesh concept can also be combined with stencil computations for scientific computing.
See the linear convection example in the **More examples** section.
</div>
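The idea behind the function mesh can be illustrated with plain `jax.lax.switch`: instead of one branch per cell, each cell stores the index of one of a small number of branches. The sketch below is an illustration under that assumption, not kernex's internal code; `branches`, `branch_index` and `F_switch` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# two branches instead of ten: x^2 and x^3
branches = (lambda v: v**2, lambda v: v**3)

# the "mesh": branch 0 for the first cell, branch 1 for the remaining nine
branch_index = jnp.array([0] + [1] * 9)

def F_switch(x):
    # jax.lax.switch picks the branch for each (index, value) pair;
    # under jax.vmap a batched index is lowered to a select over both branches
    return jax.vmap(lambda i, v: jax.lax.switch(i, branches, v))(branch_index, x)

array = jnp.arange(1, 11).astype("float32")
print(F_switch(array))
print(jax.grad(lambda x: jnp.sum(F_switch(x)))(array))
```

Both printed arrays should match the function mesh and array-equivalent results in the table above.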
</details>
## 🔢 More examples<a id="MoreExamples"></a>
<details>
<summary>1️⃣ Convolution operation</summary>
```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@kex.kmap(
    kernel_size=(3,3,3),
    padding=('valid','same','same'))
def kernex_conv2d(x, w):
    # JAX channel-first conv2d with a 3x3x3 kernel:
    # the window spans all 3 input channels ('valid') and 3x3 spatial cells ('same')
    return jnp.sum(x*w)
```
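For comparison, the same channel-first computation (a 3x3x3 window summed against `w`, zero-padded only along the spatial axes) can be expressed with `jax.lax.conv_general_dilated`. This is a sketch assuming `x` has shape `(C, H, W)` and `w` has shape `(C, 3, 3)`; `conv2d_reference` is a hypothetical name. `lax.conv_general_dilated` computes a cross-correlation (the kernel is not flipped), which matches `jnp.sum(x*w)` over each window.

```python
import jax
import jax.numpy as jnp

def conv2d_reference(x, w):
    # x: (C, H, W) input, w: (C, 3, 3) kernel -> output (1, H, W)
    out = jax.lax.conv_general_dilated(
        x[None],                # add batch dim: (1, C, H, W), NCHW layout
        w[None],                # add out-channel dim: (1, C, 3, 3), OIHW layout
        window_strides=(1, 1),
        padding="SAME",         # zero padding keeps H and W unchanged
    )
    return out[0]

x = jnp.arange(3 * 5 * 5, dtype=jnp.float32).reshape(3, 5, 5)
w = jnp.ones((3, 3, 3), dtype=jnp.float32)
y = conv2d_reference(x, w)
print(y.shape)  # (1, 5, 5)
```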
</details>
<details>
<summary>2️⃣ Laplacian operation</summary>
```python
# see also
# https://numba.pydata.org/numba-doc/latest/user/stencil.html#basic-usage
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(
    kernel_size=(3,3),
    padding='valid',
    relative=True)  # relative=True enables relative indexing: x[0,0] is the window centre
def laplacian(x):
    return ( 0*x[1,-1]  + 1*x[1,0]  + 0*x[1,1] +
             1*x[0,-1]  - 4*x[0,0]  + 1*x[0,1] +
             0*x[-1,-1] + 1*x[-1,0] + 0*x[-1,1] )

# apply laplacian ('valid' padding turns a 10x10 input into an 8x8 output)
>>> print(laplacian(jnp.ones([10,10])))
[[0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]]
```
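The same 5-point stencil can be written with array slicing in plain JAX, which gives a convenient cross-check for the `kmap` version. A minimal sketch, assuming unit grid spacing and `'valid'` padding (interior points only); `laplacian_reference` is a hypothetical name. On the quadratic field `u = i**2 + j**2` the discrete Laplacian is exactly 4 at every interior point.

```python
import jax.numpy as jnp

def laplacian_reference(u):
    # 5-point stencil on interior points: output shape is (H - 2, W - 2)
    return (u[2:, 1:-1] + u[:-2, 1:-1]     # neighbours in the next and previous row
            + u[1:-1, 2:] + u[1:-1, :-2]   # neighbours in the next and previous column
            - 4.0 * u[1:-1, 1:-1])         # centre cell

print(laplacian_reference(jnp.ones([10, 10])).shape)  # (8, 8)

i, j = jnp.meshgrid(jnp.arange(10.0), jnp.arange(10.0), indexing="ij")
print(laplacian_reference(i**2 + j**2)[0, 0])  # 4.0
```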
</details>
<details><summary>3️⃣ Get Patches of an array</summary>
```python
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(kernel_size=(3,3), relative=True)
def identity(x):
    # similar to numba.stencil:
    # x[0,0] is the top-left cell of the padded/unpadded kernel view,
    # or the centre cell when `relative=True`
    return x[0,0]

# unlike numba.stencil, vector output is allowed in kernex
# this function is similar to
# `jax.lax.conv_general_dilated_patches(x[None, None], (3, 3), (1, 1), padding="SAME")`
# up to the layout of the output
@jax.jit
@kex.kmap(kernel_size=(3,3), padding='same')
def get_3x3_patches(x):
    # returns a 5x5x3x3 array for a 5x5 input
    return x

mat = jnp.arange(1,26).reshape(5,5)
>>> print(mat)
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]
 [11 12 13 14 15]
 [16 17 18 19 20]
 [21 22 23 24 25]]

# get the view at array index = (0,0)
>>> print(get_3x3_patches(mat)[0,0])
[[0 0 0]
 [0 1 2]
 [0 6 7]]
```
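The patch extraction done by `get_3x3_patches` can also be sketched in plain JAX by zero-padding the array and slicing one window per output location. This assumes `'same'` padding with an odd `k`; `get_patches` is a hypothetical name, and the result uses the same `(H, W, k, k)` layout as the `kmap` version above.

```python
import jax
import jax.numpy as jnp

def get_patches(x, k=3):
    # zero-pad by k // 2 on every side so each cell gets a full k x k window ('same')
    xp = jnp.pad(x, k // 2)
    H, W = x.shape
    rows, cols = jnp.meshgrid(jnp.arange(H), jnp.arange(W), indexing="ij")

    def patch(i, j):
        # window whose top-left corner in the padded array is (i, j),
        # i.e. the window centred on cell (i, j) of the original array
        return jax.lax.dynamic_slice(xp, (i, j), (k, k))

    # nested vmap over the two spatial axes -> shape (H, W, k, k)
    return jax.vmap(jax.vmap(patch))(rows, cols)

mat = jnp.arange(1, 26).reshape(5, 5)
patches = get_patches(mat)
print(patches.shape)  # (5, 5, 3, 3)
print(patches[0, 0])
```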
</details>
<details>
<summary>4️⃣ Linear convection </summary>
$\Large {\partial u \over \partial t} + c {\partial u \over \partial x} = 0$ <br> <br>
$\Large u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x}(u_i^{n-1}-u_{i-1}^{n-1})$
<div align ="center">
<table>
<tr>
<td> Problem setup </td> <td> Stencil view </td>
</tr>
<tr>
<td>
<img src="assets/linear_convection_init.png" width="500px">
</td>
<td>
<img src="assets/linear_convection_view.png" width="500px">
</td>
</tr>
</table>
</div>
```python
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt

# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb

tmax, xmax = 0.5, 2.0
nt, nx = 151, 51
dt, dx = tmax/(nt-1), xmax/(nx-1)
u = jnp.ones([nt, nx])
c = 0.5

# kscan moves sequentially in row-major order and updates in place using lax.scan.
F = kex.kscan(
    kernel_size=(3,3),
    padding=((1,1),(1,1)),
    named_axis={0:'n', 1:'i'},  # n for the time axis, i for the spatial axis (optional naming)
    relative=True
)

# boundary condition as a function
def bc(u):
    return 1

# initial condition as a function
def ic1(u):
    return 1

def ic2(u):
    return 2

def linear_convection(u):
    # first-order upwind update using values from the previous time step n-1
    return ( u['i','n-1'] - (c*dt/dx) * (u['i','n-1'] - u['i-1','n-1']) )

F[:,0] = F[:,-1] = bc  # value 1 at the left and right boundaries for all t

# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2

# assign the linear convection function to
# the interior spatial locations [1:-1]
# starting from t > 0 [1:]
F[1:,1:-1] = linear_convection

kx_solution = F(u)

plt.figure(figsize=(20,7))
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0,xmax,nx), line)
```
![image](assets/linear_convection.svg)
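To check the kernex result, the same first-order upwind scheme can be written as a plain `jax.lax.scan` over time steps. This sketch reuses the grid parameters and square-wave initial condition of the example above and keeps the two boundary cells fixed at 1; it is not how kernex implements `kscan` internally, and `step`, `u0` and `solution` are illustrative names.

```python
import jax
import jax.numpy as jnp

tmax, xmax = 0.5, 2.0
nt, nx = 151, 51
dt, dx = tmax / (nt - 1), xmax / (nx - 1)
c = 0.5

# square wave: u = 2 on cells (nx-1)//4 + 1 ... (nx-1)//2 - 1, u = 1 elsewhere
u0 = jnp.ones(nx).at[(nx - 1) // 4 + 1 : (nx - 1) // 2].set(2.0)

def step(u, _):
    # u_i^n = u_i^{n-1} - c dt/dx (u_i^{n-1} - u_{i-1}^{n-1}) on interior cells;
    # u[0] and u[-1] are never updated, so they keep the boundary value 1
    interior = u[1:-1] - c * dt / dx * (u[1:-1] - u[:-2])
    u_next = u.at[1:-1].set(interior)
    return u_next, u_next

_, history = jax.lax.scan(step, u0, None, length=nt - 1)
solution = jnp.concatenate([u0[None], history], axis=0)  # shape (nt, nx)
```

Since `c*dt/dx` is well below 1 here, each update is a convex combination of two neighbouring values, so the solution stays between 1 and 2.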
</details>
<details><summary>5️⃣ Gaussian blur</summary>
```python
import jax
import jax.numpy as jnp
import kernex as kex

def gaussian_blur(image, sigma, kernel_size):
    # 1D Gaussian exp(-x^2 / (2 sigma^2)) sampled at offsets centred on zero
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    # separable 2D kernel, normalised so the weights sum to one
    w = jnp.outer(w, w)
    w = w / w.sum()

    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)
```
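A plain-JAX cross-check for the blur can use `jax.scipy.signal.convolve2d` with `mode="same"` (zero-padded, output the same size as the input). Because the Gaussian kernel is symmetric, convolution and the window-wise `jnp.sum(x * w)` of the `kmap` version give the same result. This sketch assumes an odd `kernel_size`; `gaussian_kernel` and `gaussian_blur_reference` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def gaussian_kernel(sigma, kernel_size):
    # 1D Gaussian sampled at integer offsets centred on zero
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    g = jnp.exp(-0.5 * jnp.square(x / sigma))
    w = jnp.outer(g, g)   # separable 2D kernel
    return w / w.sum()    # weights sum to one, so a constant image is preserved

def gaussian_blur_reference(image, sigma, kernel_size):
    # zero-padded 'same' convolution: output has the same shape as `image`
    return convolve2d(image, gaussian_kernel(sigma, kernel_size), mode="same")

image = jnp.ones((9, 9))
blurred = gaussian_blur_reference(image, sigma=1.0, kernel_size=3)
print(blurred.shape)  # (9, 9)
```

Interior pixels of the constant image stay at 1.0, while border pixels are darkened by the zero padding.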
</details>
<details> <summary>6️⃣ Depthwise convolution </summary>
```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@jax.vmap  # map over the channel axis: each channel gets its own kernel
@kex.kmap(
    kernel_size=(3,3),
    padding=('same','same'))
def kernex_depthwise_conv2d(x, w):
    # channel-first depthwise convolution on one (H, W) channel slice
    return jnp.sum(x * w)

C, H, W = 2, 5, 5  # channels, height, width
k = 3

x = jnp.arange(1, H*W*C+1).reshape(C, H, W)
w = jnp.arange(1, k*k*C+1).reshape(C, k, k)
print(kernex_depthwise_conv2d(x, w))
```
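The depthwise operation corresponds to a grouped convolution with one group per channel. As a sketch for cross-checking, assuming `x` of shape `(C, H, W)`, `w` of shape `(C, k, k)` and `'same'` zero padding, it can be written with the `feature_group_count` argument of `jax.lax.conv_general_dilated`; `depthwise_conv2d_reference` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def depthwise_conv2d_reference(x, w):
    # x: (C, H, W), w: (C, k, k); each channel is correlated with its own kernel
    C = x.shape[0]
    out = jax.lax.conv_general_dilated(
        x[None],                 # (1, C, H, W)
        w[:, None],              # (C, 1, k, k): one single-channel filter per group
        window_strides=(1, 1),
        padding="SAME",
        feature_group_count=C,   # C independent groups of one channel each
    )
    return out[0]                # (C, H, W)

C, H, W, k = 2, 5, 5, 3
x = jnp.arange(1, H * W * C + 1, dtype=jnp.float32).reshape(C, H, W)
w = jnp.arange(1, k * k * C + 1, dtype=jnp.float32).reshape(C, k, k)
print(depthwise_conv2d_reference(x, w).shape)  # (2, 5, 5)
```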
</details>
<details> <summary>7️⃣ Maxpooling2D and Averagepooling2D </summary>
```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.vmap  # vectorize over the channel dimension
@kex.kmap(kernel_size=(3,3), strides=(2,2))
def maxpool_2d(x):
    # kernel for the max pool operation over the spatial dimensions
    return jnp.max(x)

@jax.vmap  # vectorize over the channel dimension
@kex.kmap(kernel_size=(3,3), strides=(2,2))
def avgpool_2d(x):
    # kernel for the average pool operation over the spatial dimensions
    return jnp.mean(x)
```
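The same pooling windows can be cross-checked with `jax.lax.reduce_window`. This sketch assumes channel-first input of shape `(C, H, W)`, a 3x3 window with stride 2, and no padding, which is what we assume `kmap` uses by default; `maxpool_2d_reference` and `avgpool_2d_reference` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def maxpool_2d_reference(x):
    # x: (C, H, W); window of 1 along channels and 3x3 spatially, stride 2, no padding
    return jax.lax.reduce_window(x, -jnp.inf, jax.lax.max,
                                 (1, 3, 3), (1, 2, 2), "VALID")

def avgpool_2d_reference(x):
    # sum over each 3x3 window, then divide by the window size
    s = jax.lax.reduce_window(x, 0.0, jax.lax.add,
                              (1, 3, 3), (1, 2, 2), "VALID")
    return s / 9.0

x = jnp.arange(2 * 5 * 5, dtype=jnp.float32).reshape(2, 5, 5)
print(maxpool_2d_reference(x).shape)  # (2, 2, 2)
```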
</details>
<details><summary>8️⃣ Runge-Kutta integration</summary>
```python
# let's solve dy/dt = y with y(0) = 1, whose exact solution is y(t) = e^t,
# using the 4th-order Runge-Kutta method with f(t, y) = y
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex as kex

t = jnp.linspace(0, 1, 5)
y = jnp.zeros(5)
x = jnp.stack([y, t], axis=0)  # row 0: y values, row 1: time values
dt = t[1] - t[0]  # 0.25
f = lambda tn, yn: yn

def ic(x):
    """initial condition y0 = 1"""
    return 1.

def rk4(x):
    """Runge-Kutta 4th-order integration step"""
    # ┌────┬────┬────┐      ┌──────┬──────┬──────┐
    # │ y0 │*y1*│ y2 │      │[0,-1]│[0, 0]│[0, 1]│
    # ├────┼────┼────┤ ==>  ├──────┼──────┼──────┤
    # │ t0 │ t1 │ t2 │      │[1,-1]│[1, 0]│[1, 1]│
    # └────┴────┴────┘      └──────┴──────┴──────┘
    t0 = x[1, -1]
    y0 = x[0, -1]
    k1 = dt * f(t0, y0)
    k2 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k1)
    k3 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k2)
    k4 = dt * f(t0 + dt, y0 + k3)
    yn_1 = y0 + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return yn_1

# 2x3 window over the (y, t) rows and three neighbouring time steps;
# padding=(0, 1): no padding along the rows, one cell on each side of the time axis
F = kex.kscan(kernel_size=(2, 3), relative=True, padding=(0, 1))

F[0:1, 1:] = rk4
F[0, 0] = ic

# compile the solver
solver = jax.jit(F.__call__)
y = solver(x)[0, :]

plt.plot(t, y, '-o', label='rk4')
plt.plot(t, jnp.exp(t), '-o', label='analytical')
plt.legend()
```
![img](assets/rk4.svg)
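For comparison, the same RK4 scheme can be written as a plain `jax.lax.scan` over the time grid, carrying only the current value of y. This is a sketch assuming a uniform grid; `rk4_step` and `solve` are hypothetical names. With dt = 0.25, the numerical solution should agree with e^t to within about 1e-4 at every grid point.

```python
import jax
import jax.numpy as jnp

def rk4_step(f, t, y, dt):
    # classic 4th-order Runge-Kutta step for dy/dt = f(t, y)
    k1 = dt * f(t, y)
    k2 = dt * f(t + dt / 2, y + k1 / 2)
    k3 = dt * f(t + dt / 2, y + k2 / 2)
    k4 = dt * f(t + dt, y + k3)
    return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6

def solve(f, y0, t):
    dt = t[1] - t[0]                     # uniform grid assumed
    y0 = jnp.asarray(y0, dtype=t.dtype)

    def step(y, tn):
        y_next = rk4_step(f, tn, y, dt)
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, t[:-1])  # one step from each t_n to t_{n+1}
    return jnp.concatenate([y0[None], ys])

t = jnp.linspace(0, 1, 5)
y = solve(lambda tn, yn: yn, 1.0, t)
```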
</details>