<div align = "center">
<img  width=400 src="assets/kernexlogo.svg" align="center">

<h3 align="center">Differentiable Stencil computations in JAX</h3>

[**Installation**](#Installation)
|[**Description**](#Description)
|[**Quick example**](#QuickExample)
|[**Function mesh**](#FunctionMesh)
|[**More Examples**](#MoreExamples)

</div>



## 🛠️ Installation<a id="Installation"></a>

```shell
pip install kernex
```

## 📖 Description<a id="Description"></a>

Kernex extends `jax.vmap`/``/`jax.pmap` with `kmap` and `jax.lax.scan` with `kscan` for general stencil computations.

The main motivation for this package is to blend the solution process of PDEs into a neural network (NN) setting.

## ⏩ Quick Example <a id="QuickExample"></a>

<div align="center">
<table>
<tr>
<td width="50%" align="center" > kmap </td> <td align="center" > kscan </td>
</tr>
<tr>
<td>

```python
import kernex as kex
import jax.numpy as jnp

@kex.kmap(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6  9 12]
```

</td>
<td>

```python
import kernex as kex
import jax.numpy as jnp

@kex.kscan(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6 13 22]
```

</td>
</tr>
</table>

<table>
<tr>
<td width="50%">

`jax.vmap` is used to sum the contents of each window independently.
<img src="assets/kmap_sum.png" width=400px>
</td>
<td>

`lax.scan` is used to update the array, so the window sums are computed sequentially and each step sees the values written by the previous steps.
The first three rows show the three sequential steps used to obtain the solution in the last row.

<img align="center" src="assets/kscan_sum.png" width=400px>
</td>
</tr>
</table>
</div>
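To make the difference concrete, here is a plain-JAX sketch of both behaviours for a 1D window of size 3 with no padding. The helpers `kmap_like_sum` and `kscan_like_sum` are hypothetical and this is not how kernex is implemented; in particular, writing each sequential result back to the window centre is an assumption inferred from the `[ 6 13 22]` output above.

```python
import jax
import jax.numpy as jnp

def kmap_like_sum(x, k=3):
    # Hypothetical helper: every window is read from the ORIGINAL array,
    # so all window sums can be computed in parallel with jax.vmap.
    starts = jnp.arange(x.shape[0] - k + 1)
    windows = jax.vmap(lambda i: jax.lax.dynamic_slice(x, (i,), (k,)))(starts)
    return jnp.sum(windows, axis=1)

def kscan_like_sum(x, k=3):
    # Hypothetical helper: windows are visited in order with jax.lax.scan and
    # each sum overwrites the window centre (assumption), so later windows
    # see the values written by earlier steps.
    centre = k // 2

    def step(arr, i):
        s = jnp.sum(jax.lax.dynamic_slice(arr, (i,), (k,)))
        return arr.at[i + centre].set(s), s

    starts = jnp.arange(x.shape[0] - k + 1)
    _, sums = jax.lax.scan(step, x, starts)
    return sums

x = jnp.array([1, 2, 3, 4, 5])
print(kmap_like_sum(x))   # [ 6  9 12]
print(kscan_like_sum(x))  # [ 6 13 22]
```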

## 🕸️ Function mesh concept <a id="FunctionMesh"></a>

The objective is to apply `f(x) = x^2` at index 0 and `f(x) = x^3` at indices 1 to 9 of a 10-element array.

To achieve this with `jax.lax.switch`, we would need a list of 10 functions, one for each cell of the example array.
For this reason, kernex adopts a modified version of `jax.lax.switch` that reduces the number of branches required.


```text
# f applies x^2 to the first cell and x^3 to the remaining cells

  f =   │ x^2 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │ x^3 │

 f(     │  1  │  2  │  3  │  4  │  5  │  6  │  7  │  8  │  9  │ 10  │ ) =
        │  1  │  8  │  27 │  64 │ 125 │ 216 │ 343 │ 512 │ 729 │1000 │

# Gradient of this function
df/dx = │ 2x  │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │3x^2 │

 df/dx( │  1  │  2  │  3  │  4  │  5  │  6  │  7  │  8  │  9  │ 10  │ ) =
        │  2  │  12 │ 27  │  48 │ 75  │ 108 │ 147 │ 192 │ 243 │ 300 │
```

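<div align ="center">
<table>
<tr>
<td> Function mesh </td> <td> Array equivalent </td>
</tr>
<tr>
<td>

```python
import jax
import jax.numpy as jnp
import kernex as kex

F = kex.kmap(kernel_size=(1,))  # each window holds a single cell
F[0] = lambda x: x[0]**2        # cell 0
F[1:] = lambda x: x[0]**3       # cells 1..9

array = jnp.arange(1,11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125.,
#  216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75.,
#  108., 147., 192., 243., 300.]
```

</td>
<td>

```python
import jax
import jax.numpy as jnp

def F(x):
    f1 = lambda x: x**2
    f2 = lambda x: x**3
    x = x.at[0].set(f1(x[0]))    # functional update of cell 0
    x = x.at[1:].set(f2(x[1:]))  # functional update of cells 1..9
    return x

array = jnp.arange(1,11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125.,
#  216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75.,
#  108., 147., 192., 243., 300.]
```

</td>
</tr>
</table>
</div>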

Additionally, the function mesh concept can be combined with stencil computations for scientific computing.
See the Linear convection example in the **More examples** section.
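The branch-reduction idea from the `jax.lax.switch` discussion above can be sketched in plain JAX: keep one branch per *distinct* function plus a per-cell table of branch indices, then dispatch with `jax.lax.switch` inside `jax.vmap`. This is only an illustration under that assumption; `F_switch` and `branch_of_cell` are hypothetical names, and kernex's own modified `switch` may work differently. It should reproduce the values and gradients shown above.

```python
import jax
import jax.numpy as jnp

# Two distinct branches instead of one branch per cell.
branches = [lambda v: v**2, lambda v: v**3]

# Hypothetical lookup table: cell 0 uses branch 0, cells 1..9 use branch 1.
branch_of_cell = jnp.array([0] + [1] * 9)

def F_switch(x):
    # Under vmap the batched index turns lax.switch into a select over the
    # two branches, so only two functions are traced regardless of array size.
    return jax.vmap(lambda b, v: jax.lax.switch(b, branches, v))(branch_of_cell, x)

array = jnp.arange(1, 11).astype("float32")
print(F_switch(array))
print(jax.grad(lambda x: jnp.sum(F_switch(x)))(array))
```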



## 🔢 More examples<a id="MoreExamples"></a>

<details><summary>1️⃣ Convolution operation</summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(
    kernel_size= (3,3,3),
    padding = ('valid','same','same'))
def kernex_conv2d(x,w):
    # JAX channel-first conv2d with a 3x3x3 kernel:
    # for x of shape (3, H, W) and w of shape (3, 3, 3), the output is (1, H, W)
    return jnp.sum(x*w)
```

</details>
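As a cross-check, the same channel-first operation can be written with `jax.lax.conv_general_dilated`, which computes a cross-correlation (no kernel flip) and therefore matches `jnp.sum(x*w)` over each window. This sketch assumes kernex's `'same'` padding on a width-3 axis means one zero cell on each side, which is what `padding="SAME"` does in `lax`; `conv_reference` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def conv_reference(x, w):
    # x: (C, H, W) channel-first image; w: (C, kh, kw) kernel spanning all channels.
    out = jax.lax.conv_general_dilated(
        x[None],                 # add batch axis -> (1, C, H, W)
        w[None],                 # one output channel -> (1, C, kh, kw)
        window_strides=(1, 1),
        padding="SAME",          # zero-pad the spatial axes only
        dimension_numbers=("NCHW", "OIHW", "NCHW"))
    return out[0]                # (1, H, W): the channel axis is reduced, like 'valid'

x = jnp.arange(3 * 5 * 5, dtype=jnp.float32).reshape(3, 5, 5)
w = jnp.arange(3 * 3 * 3, dtype=jnp.float32).reshape(3, 3, 3)
print(conv_reference(x, w).shape)  # (1, 5, 5)
```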
<details><summary>2️⃣ Laplacian operation</summary>

```python
# see also
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(
    kernel_size=(3,3),
    padding= 'valid',
    relative=True) # `relative=True` enables relative indexing: x[0,0] is the window centre
def laplacian(x):
    return ( 0*x[1,-1]  + 1*x[1,0]   + 0*x[1,1] +
             1*x[0,-1]  +-4*x[0,0]   + 1*x[0,1] +
             0*x[-1,-1] + 1*x[-1,0]  + 0*x[-1,1] )

# apply the Laplacian; 'valid' padding shrinks a 10x10 input to 8x8
>>> print(laplacian(jnp.ones([10,10])))
[[0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]]
```

</details>
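The same 5-point stencil can also be written with plain array slicing, which makes a handy reference for testing. On the quadratic field `u = i**2 + j**2` the discrete Laplacian equals 2 + 2 = 4 at every interior point, because (i+1)² + (i-1)² - 2i² = 2. `laplacian_slices` is a hypothetical helper, not part of kernex.

```python
import jax.numpy as jnp

def laplacian_slices(u):
    # 5-point stencil on interior points only ('valid'): output shape is (H-2, W-2)
    return (u[2:, 1:-1] + u[:-2, 1:-1]      # neighbours along axis 0
            + u[1:-1, 2:] + u[1:-1, :-2]    # neighbours along axis 1
            - 4 * u[1:-1, 1:-1])            # centre

i, j = jnp.meshgrid(jnp.arange(10.0), jnp.arange(10.0), indexing="ij")
print(laplacian_slices(jnp.ones([10, 10])).shape)  # (8, 8), all zeros
print(laplacian_slices(i**2 + j**2)[0, 0])         # 4.0
```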


<details><summary>3️⃣ Get Patches of an array</summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(kernel_size=(3,3), relative=True)
def identity(x):
    # similar to numba.stencil
    # this function returns the top-left cell of the padded/unpadded kernel view,
    # or the centre cell if `relative=True`
    return x[0,0]

# unlike numba.stencil, vector output is allowed in kernex
# this function is similar to
# `jax.lax.conv_general_dilated_patches(x[None, None], (3, 3), (1, 1), padding='SAME')`
# up to the layout of the output
@kex.kmap(kernel_size=(3,3), padding='same')
def get_3x3_patches(x):
    # returns a 5x5x3x3 array for a 5x5 input
    return x

mat = jnp.arange(1,26).reshape(5,5)
>>> print(mat)
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]
 [11 12 13 14 15]
 [16 17 18 19 20]
 [21 22 23 24 25]]

# get the view at array index = (0,0)
>>> print(get_3x3_patches(mat)[0,0])
[[0 0 0]
 [0 1 2]
 [0 6 7]]
```

</details>
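For reference, the `'same'` 3x3 patch extraction can be sketched in plain JAX by zero-padding one cell on each side and slicing a window at every position; this reproduces the `(0,0)` view shown above. `patches_3x3_same` is a hypothetical helper, and kernex does not necessarily work this way internally.

```python
import jax
import jax.numpy as jnp

def patches_3x3_same(x):
    # Zero-pad one cell on every side so each position has a full 3x3 window ('same').
    xp = jnp.pad(x, 1)
    rows = jnp.arange(x.shape[0])
    cols = jnp.arange(x.shape[1])
    # The window for output (i, j) starts at (i, j) in the padded array,
    # i.e. it is centred on x[i, j].
    window = lambda i, j: jax.lax.dynamic_slice(xp, (i, j), (3, 3))
    return jax.vmap(lambda i: jax.vmap(lambda j: window(i, j))(cols))(rows)

mat = jnp.arange(1, 26).reshape(5, 5)
p = patches_3x3_same(mat)
print(p.shape)  # (5, 5, 3, 3)
print(p[0, 0])
# [[0 0 0]
#  [0 1 2]
#  [0 6 7]]
```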

<details><summary>4️⃣ Linear convection </summary>

$\Large {\partial u \over \partial t} + c {\partial u \over \partial x} = 0$ <br> <br>
$\Large u_i^{n} = u_i^{n-1} - c \frac{\Delta t}{\Delta x}(u_i^{n-1}-u_{i-1}^{n-1})$

<div align ="center">
<table>
<tr>
<td> Problem setup </td> <td> Stencil view  </td>
</tr>
<tr>
<td>
<img src="assets/linear_convection_init.png" width="500px">
</td>
<td>
<img src="assets/linear_convection_view.png" width="500px">
</td>
</tr>
</table>
</div>

```python
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt

# see

tmax,xmax = 0.5,2.0
nt,nx = 151,51
dt,dx = tmax/(nt-1) , xmax/(nx-1)
u = jnp.ones([nt,nx])
c = 0.5

# kscan moves sequentially in row-major order and updates in-place using lax.scan.

F = kex.kscan(
        kernel_size = (3,3),
        padding = ((1,1),(1,1)),
        named_axis={0:'n',1:'i'},  # n for time axis, i for spatial axis (optional naming)
        relative=True)             # index relative to the current cell, e.g. 'n-1' is the previous time step

# boundary condition as a function
def bc(u):
    return 1

# initial condition as a function
def ic1(u):
    return 1

def ic2(u):
    return 2

def linear_convection(u):
    return ( u['i','n-1'] - (c*dt/dx) * (u['i','n-1'] - u['i-1','n-1']) )

F[:,0]  = F[:,-1] = bc # assign 1 for left and right boundary for all t

# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2

# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0  [1:]
F[1:,1:-1] = linear_convection

kx_solution = F(jnp.array(u))

# plot every 20th time step
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0, xmax, nx), line)
```

</details>
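A plain-JAX version of the same first-order upwind scheme, written as a `jax.lax.scan` over time steps, can be used to check the kernex solution. It is a sketch that assumes the setup above (a square wave of height 2 on cells `int((nx-1)/4)+1` to `int((nx-1)/2)-1`, both boundaries fixed at 1); `upwind_step` and `sigma` are hypothetical names. Since the Courant number `sigma = c*dt/dx` is about 0.042 here, each update is a convex combination of `u_i` and `u_{i-1}`, so the solution stays between 1 and 2.

```python
import jax
import jax.numpy as jnp

tmax, xmax = 0.5, 2.0
nt, nx = 151, 51
dt, dx = tmax / (nt - 1), xmax / (nx - 1)
c = 0.5
sigma = c * dt / dx  # Courant number, about 0.042 here

# Square-wave initial condition: 2 on the middle cells, 1 elsewhere.
u0 = jnp.ones(nx).at[int((nx - 1) / 4) + 1 : int((nx - 1) / 2)].set(2.0)

def upwind_step(u_prev, _):
    # u_i^n = u_i^{n-1} - sigma * (u_i^{n-1} - u_{i-1}^{n-1}) on interior cells;
    # cells 0 and nx-1 keep their boundary value of 1.
    interior = u_prev[1:-1] - sigma * (u_prev[1:-1] - u_prev[:-2])
    u_next = u_prev.at[1:-1].set(interior)
    return u_next, u_next

_, history = jax.lax.scan(upwind_step, u0, None, length=nt - 1)
solution = jnp.concatenate([u0[None], history], axis=0)  # shape (nt, nx)
print(solution.shape)  # (151, 51)
```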


<details><summary>5️⃣ Gaussian blur</summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

def gaussian_blur(image, sigma, kernel_size):
    # 1D Gaussian weights exp(-x^2 / (2 sigma^2)) sampled on a centred grid
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size- 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    # separable 2D kernel, normalised so the weights sum to 1
    w = jnp.outer(w, w)
    w = w / w.sum()

    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)
```

</details>
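Because the Gaussian weights are symmetric, the stencil `jnp.sum(x * w)` with `'same'` padding coincides with an ordinary 2D convolution, so `jax.scipy.signal.convolve2d(..., mode="same")` gives a convenient reference. This sketch assumes kernex's `'same'` padding fills with zeros, as `convolve2d` does; `gaussian_kernel` and `gaussian_blur_reference` are hypothetical names. Since the weights sum to 1, a constant image is left unchanged away from the zero-padded border.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def gaussian_kernel(sigma, kernel_size):
    # Same construction as gaussian_blur: 1D Gaussian, outer product, normalise.
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    g = jnp.exp(-0.5 * jnp.square(x / sigma))
    w = jnp.outer(g, g)
    return w / w.sum()

def gaussian_blur_reference(image, sigma, kernel_size):
    # Symmetric kernel, so convolution equals the correlation computed by the stencil.
    return convolve2d(image, gaussian_kernel(sigma, kernel_size), mode="same")

out = gaussian_blur_reference(jnp.ones((8, 8)), sigma=1.0, kernel_size=3)
print(out[1:-1, 1:-1])  # interior values are ~1.0 because the weights sum to 1
```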


<details><summary>6️⃣ Depthwise convolution </summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.vmap  # map over the channel axis of both x and w
@kex.kmap(
    kernel_size= (3,3),
    padding = ('same','same'))
def kernex_depthwise_conv2d(x,w):
    # Channel-first depthwise convolution
    # jax.debug.print("x=\n{a}\nw=\n{b} \n\n",a=x, b=w)
    return jnp.sum(x*w)

h,w,c = 5,5,2
k = 3

x = jnp.arange(1,h*w*c+1).reshape(c,h,w)
w = jnp.arange(1,k*k*c+1).reshape(c,k,k)  # note: the name `w` is reused for the kernel
print(kernex_depthwise_conv2d(x,w).shape)  # (2, 5, 5)
```

</details>
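The same depthwise operation can be expressed with `jax.lax.conv_general_dilated` by setting `feature_group_count` to the number of channels, so each channel is convolved only with its own kernel. This is a reference sketch, assuming kernex's `'same'` padding matches `padding="SAME"` for a 3x3 kernel; `depthwise_reference` is a hypothetical helper, and `wd` is used for the image width to avoid clashing with the kernel `w`.

```python
import jax
import jax.numpy as jnp

def depthwise_reference(x, w):
    # x: (C, H, W); w: (C, kh, kw). Each channel is convolved with its own kernel.
    c = x.shape[0]
    out = jax.lax.conv_general_dilated(
        x[None].astype(jnp.float32),      # (1, C, H, W)
        w[:, None].astype(jnp.float32),   # (C, 1, kh, kw): C outputs, 1 input per group
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=c)            # one group per channel, so no channel mixing
    return out[0]                         # (C, H, W)

h, wd, c, k = 5, 5, 2, 3
x = jnp.arange(1, h * wd * c + 1).reshape(c, h, wd)
w = jnp.arange(1, k * k * c + 1).reshape(c, k, k)
print(depthwise_reference(x, w).shape)  # (2, 5, 5)
```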


<details><summary>7️⃣ Maxpooling2D and Averagepooling2D </summary>

```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.vmap # vectorize over the channel dimension
@kex.kmap(kernel_size=(3,3), strides=(2,2))
def maxpool_2d(x):
    # define the kernel for the Maxpool operation over the spatial dimensions
    return jnp.max(x)

@jax.vmap # vectorize over the channel dimension
@kex.kmap(kernel_size=(3,3), strides=(2,2))
def avgpool_2d(x):
    # define the kernel for the Average pool operation over the spatial dimensions
    return jnp.mean(x)
```

</details>
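The same pooling can be expressed with `jax.lax.reduce_window`. This sketch assumes that `kex.kmap` without a `padding` argument uses `'valid'` windows (as the quick example's 5 → 3 output suggests), so windows start at indices 0, 2, 4, ... along each spatial axis; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

def maxpool_2d_reference(x):
    # x: (C, H, W). 3x3 windows with stride 2 and no padding, per channel.
    init = jnp.array(-jnp.inf, dtype=x.dtype)
    return jax.lax.reduce_window(x, init, jax.lax.max,
                                 window_dimensions=(1, 3, 3),
                                 window_strides=(1, 2, 2),
                                 padding="VALID")

def avgpool_2d_reference(x):
    # Sum each 3x3 window, then divide by its 9 elements.
    init = jnp.array(0, dtype=x.dtype)
    summed = jax.lax.reduce_window(x, init, jax.lax.add,
                                   window_dimensions=(1, 3, 3),
                                   window_strides=(1, 2, 2),
                                   padding="VALID")
    return summed / 9.0

x = jnp.arange(25, dtype=jnp.float32).reshape(1, 5, 5)
print(maxpool_2d_reference(x))  # windows start at (0,0), (0,2), (2,0), (2,2)
print(avgpool_2d_reference(x))
```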

<details><summary>8️⃣ Runge-Kutta integration</summary>

```python
# let's solve dy/dt = y, where y0 = 1 and y(t) = e^t
# using the 4th-order Runge-Kutta method
# f(t,y) = y
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex as kex

t = jnp.linspace(0, 1, 5)
y = jnp.zeros(5)
x = jnp.stack([y, t], axis=0)  # row 0: y, row 1: t
dt = t[1] - t[0]  # 0.25
f = lambda tn, yn: yn

def ic(x):
    """ initial condition y0 = 1 """
    return 1.

def rk4(x):
    """ runge kutta 4th order integration step """
    # ┌────┬────┬────┐      ┌──────┬──────┬──────┐
    # │ y0 │*y1*│ y2 │      │[0,-1]│[0, 0]│[0, 1]│
    # ├────┼────┼────┤ ==>  ├──────┼──────┼──────┤
    # │ t0 │ t1 │ t2 │      │[1,-1]│[1, 0]│[1, 1]│
    # └────┴────┴────┘      └──────┴──────┴──────┘
    t0 = x[1, -1]
    y0 = x[0, -1]
    k1 = dt * f(t0, y0)
    k2 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k1)
    k3 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k2)
    k4 = dt * f(t0 + dt, y0 + k3)
    yn_1 = y0 + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return yn_1

F = kex.kscan(kernel_size=(2, 3), relative=True, padding=(0, 1))  # 2x3 window: both rows (y, t) and three time steps

F[0:1, 1:] = rk4
F[0, 0] = ic
# compile the solver
solver = jax.jit(F.__call__)
y = solver(x)[0, :]

plt.plot(t, y, '-o', label='rk4')
plt.plot(t, jnp.exp(t), '-o', label='analytical')
plt.legend()
```

</details>
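For comparison, the same RK4 integration can be written directly with `jax.lax.scan`, carrying only the current value of y. It uses the same `f`, step size and grid as above; `rk4_step` is a hypothetical helper. With h = 0.25 the RK4 error at t = 1 for this problem is well below 1e-3.

```python
import jax
import jax.numpy as jnp

f = lambda tn, yn: yn  # dy/dt = y, exact solution y(t) = exp(t) for y(0) = 1

def rk4_step(y, tn, dt):
    # classic 4th-order Runge-Kutta step from (tn, y) to (tn + dt, y_next)
    k1 = dt * f(tn, y)
    k2 = dt * f(tn + dt / 2, y + k1 / 2)
    k3 = dt * f(tn + dt / 2, y + k2 / 2)
    k4 = dt * f(tn + dt, y + k3)
    return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6

t = jnp.linspace(0, 1, 5)
dt = t[1] - t[0]  # 0.25

def body(y, tn):
    y_next = rk4_step(y, tn, dt)
    return y_next, y_next

_, ys = jax.lax.scan(body, jnp.array(1.0), t[:-1])
y = jnp.concatenate([jnp.array([1.0]), ys])  # prepend the initial condition
print(jnp.max(jnp.abs(y - jnp.exp(t))))      # maximum error against exp(t)
```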




