<!-- <h1 align="center" style="font-family:Monospace" >Py🌲Class</h1> -->
<h5 align="center">
<img width="250px" src="assets/pytc%20logo.svg"> <br>
<br>
[**Installation**](#installation)
|[**Description**](#description)
|[**Quick Example**](#quick_example)
|[**StatefulComputation**](#stateful_computation)
|[**More**](#more)
|[**Acknowledgements**](#acknowledgements)




[Open in Colab](https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/intro.ipynb)
|[Downloads](https://pepy.tech/project/pytreeclass)
|[Coverage](https://codecov.io/gh/ASEM000/pytreeclass)
|[Documentation](https://pytreeclass.readthedocs.io/en/latest/?badge=latest)
|[DOI](https://zenodo.org/badge/latestdoi/512717921)

</h5>
**For the previous version of `PyTreeClass`, use the v0.1 branch.**
## 🛠️ Installation<a id="installation"></a>

```bash
pip install pytreeclass
```

**Install the development version**

```bash
pip install git+https://github.com/ASEM000/PyTreeClass
```
## 📖 Description<a id="description"></a>

`PyTreeClass` is a JAX-compatible `dataclass`-like decorator to create and operate on stateful JAX PyTrees.

The package aims to achieve _two goals_:

1. 🔒 To maintain safe and correct behaviour by using _immutable_ modules with a _functional_ API.
2. To achieve the **most intuitive** user experience in the `JAX` ecosystem by:
   - 🏗️ Defining layers similar to the `PyTorch` or `TensorFlow` subclassing style.
   - ☝️ Filtering/indexing layer values similar to `jax.numpy.at[].{get,set,apply,...}`.
   - 🎨 Visualizing defined layers in a plethora of ways.
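
The first goal can be illustrated with the standard library alone. This is a minimal sketch, not how `PyTreeClass` is implemented: a frozen `dataclasses.dataclass` rejects in-place mutation, and `dataclasses.replace` produces an updated copy, which is the same out-of-place style that `.at[]` follows. The `Layer` class is hypothetical.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Layer:  # hypothetical example class
    weight: float = 1.0
    bias: float = 0.0

layer = Layer()

# Direct mutation is rejected on a frozen dataclass
mutation_blocked = False
try:
    layer.weight = 2.0
except dataclasses.FrozenInstanceError:
    mutation_blocked = True

# A functional update returns a new instance and leaves `layer` untouched
new_layer = dataclasses.replace(layer, weight=2.0)
```
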
## ⏩ Quick Example <a id="quick_example"></a>

### 🏗️ Simple Tree example
<div align="center">
<table>
<tr><td align="center">Code</td> <td align="center">PyTree representation</td></tr>
<tr>
<td>

```python
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
```

</td>
<td>

```python
# leaves are parameters

Tree
 ├── a=1
 ├── b:tuple
 │   ├── [0]=2.0
 │   └── [1]=3.0
 └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

</td>
</tr>
</table>
</div>
### 🎨 Visualize<a id="Viz"></a>
<details> <summary>Visualize PyTrees in five different ways</summary>
<div align="center">
<table>
<tr>
<td align = "center"> tree_summary</td>
<td align = "center">tree_diagram</td>
<td align = "center">[tree_mermaid](https://mermaid.js.org) (native support in GitHub/Notion)</td>
<td align= "center"> tree_repr </td>
<td align="center" > tree_str </td>
</tr>
<tr>
<td>

```python
print(pytc.tree_summary(tree, depth=1))
┌────┬──────┬─────┬──────┐
│Name│Type  │Count│Size  │
├────┼──────┼─────┼──────┤
│a   │int   │1    │28.00B│
├────┼──────┼─────┼──────┤
│b   │tuple │2    │48.00B│
├────┼──────┼─────┼──────┤
│c   │f32[3]│3    │12.00B│
├────┼──────┼─────┼──────┤
│Σ   │Tree  │6    │88.00B│
└────┴──────┴─────┴──────┘
```

</td>
<td>

```python
print(pytc.tree_diagram(tree, depth=1))
Tree
 ├── a=1
 ├── b=(..., ...)
 └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

</td>
<td>

```python
print(pytc.tree_mermaid(tree, depth=1))
```

```mermaid
flowchart LR
    id15696277213149321320(<b>Tree</b>)
    id15696277213149321320--->|"1 leaf<br>28.00B"|id4205845433746830897("<b>a</b>:int=1")
    id15696277213149321320--->|"2 leaf<br>48.00B"|id4682191244783855647("<b>b</b>:tuple=(..., ...)")
    id15696277213149321320--->|"3 leaf<br>12.00B"|id14652085615030570957("<b>c</b>:ArrayImpl=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])")
```

</td>
<td>

```python
print(pytc.tree_repr(tree, depth=1))
Tree(a=1, b=(..., ...), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```

</td>
<td>

```python
print(pytc.tree_str(tree, depth=1))
Tree(a=1, b=(..., ...), c=[4. 5. 6.])
```

</td>
</tr>
<tr>
<td>

```python
print(pytc.tree_summary(tree, depth=2))
┌────┬──────┬─────┬──────┐
│Name│Type  │Count│Size  │
├────┼──────┼─────┼──────┤
│a   │int   │1    │28.00B│
├────┼──────┼─────┼──────┤
│b[0]│float │1    │24.00B│
├────┼──────┼─────┼──────┤
│b[1]│float │1    │24.00B│
├────┼──────┼─────┼──────┤
│c   │f32[3]│3    │12.00B│
├────┼──────┼─────┼──────┤
│Σ   │Tree  │6    │88.00B│
└────┴──────┴─────┴──────┘
```

</td>
<td>

```python
print(pytc.tree_diagram(tree, depth=2))
Tree
 ├── a=1
 ├── b:tuple
 │   ├── [0]=2.0
 │   └── [1]=3.0
 └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

</td>
<td>

```python
print(pytc.tree_mermaid(tree, depth=2))
```

```mermaid
flowchart LR
    id15696277213149321320(<b>Tree</b>)
    id15696277213149321320--->id4205845433746830897("<b>a</b>:int=1")
    id15696277213149321320--->|"1 leaf<br>24.00B"|id8168961130706115346("<b>b</b>:tuple")
    id8168961130706115346--->|"1 leaf<br>24.00B"|id2766159651176208202("<b>[0]</b>:float=2.0")
    id15696277213149321320--->|"1 leaf<br>24.00B"|id12408280303145007954("<b>b</b>:tuple")
    id12408280303145007954--->|"1 leaf<br>24.00B"|id7897116322308127883("<b>[1]</b>:float=3.0")
    id15696277213149321320--->id14652085615030570957("<b>c</b>:ArrayImpl=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])")
```

</td>
<td>

```python
print(pytc.tree_repr(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```

</td>
<td>

```python
print(pytc.tree_str(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
```

</td>
</tr>
</table>
</div>
</details>
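
The branch characters and the effect of `depth` can be reproduced with a short recursive function. This is a minimal sketch for plain nested lists and tuples, not the library's renderer: `render_diagram` is hypothetical, expands a container only while the current level is below `depth`, and shows deeper containers as `[..., ...]` or `(..., ...)`.

```python
def render_diagram(tree, depth=1):
    """Render nested lists/tuples as a text tree, expanding `depth` levels."""
    lines = [type(tree).__name__]

    def elide(container):
        # Summarize a container whose children lie beyond the depth budget
        inner = ", ".join("..." for _ in container)
        return f"[{inner}]" if isinstance(container, list) else f"({inner})"

    def walk(node, prefix, level):
        for i, child in enumerate(node):
            last = i == len(node) - 1
            branch = "└── " if last else "├── "
            if isinstance(child, (list, tuple)) and level < depth:
                # Expand this container and recurse one level deeper
                lines.append(f"{prefix}{branch}[{i}]:{type(child).__name__}")
                walk(child, prefix + ("    " if last else "│   "), level + 1)
            elif isinstance(child, (list, tuple)):
                lines.append(f"{prefix}{branch}[{i}]={elide(child)}")
            else:
                lines.append(f"{prefix}{branch}[{i}]={child!r}")

    walk(tree, " ", 1)
    return "\n".join(lines)
```
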
### 🏃 Working with `jax` transformations

<details> <summary>Make arbitrary PyTrees work with jax transformations</summary>

Parameters are defined at the top of the `Tree` class definition, similar to defining
`dataclasses.dataclass` fields.
Let's optimize our parameters:

```python
@jax.grad
def loss_func(tree:Tree, x:jax.Array):
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error

@jax.jit
def train_step(tree:Tree, x:jax.Array):
    grads = loss_func(tree, x)
    # apply a small gradient step
    return jax.tree_util.tree_map(lambda x, g: x - 1e-3*g, tree, grads)

# Freeze the non-differentiable parts of the tree:
# any leaf that is not of an inexact (floating-point) type should be frozen
# to make the tree differentiable and compatible with jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))

print(jaxable_tree)
# the frozen params have a "#" prefix
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])

# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])
```
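
In JAX, `jax.grad` accepts only inexact (floating-point or complex) inputs, which is why the integer field `a` is frozen above. The pure-Python sketch below mimics the idea on nested tuples: non-float leaves are wrapped so that the update step skips them, and unwrapping restores the original values. `Frozen`, `map_unfrozen`, `freeze_nondiff`, and `unfreeze` are hypothetical stand-ins, not `pytc.freeze`/`pytc.unfreeze`.

```python
class Frozen:
    """Hypothetical wrapper that hides a leaf from the update step."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"  # mimic the "#" prefix shown above


def map_unfrozen(func, tree):
    """Apply `func` to every leaf of nested tuples; Frozen leaves pass through."""
    if isinstance(tree, tuple):
        return tuple(map_unfrozen(func, child) for child in tree)
    if isinstance(tree, Frozen):
        return tree
    return func(tree)


def freeze_nondiff(tree):
    """Wrap every leaf that is not a float."""
    if isinstance(tree, tuple):
        return tuple(freeze_nondiff(child) for child in tree)
    return tree if isinstance(tree, float) else Frozen(tree)


def unfreeze(tree):
    """Strip the Frozen wrappers again."""
    if isinstance(tree, tuple):
        return tuple(unfreeze(child) for child in tree)
    return tree.value if isinstance(tree, Frozen) else tree


params = (1, (2.0, 3.0))                           # like Tree(a=1, b=(2.0, 3.0))
frozen = freeze_nondiff(params)                    # (#1, (2.0, 3.0))
stepped = map_unfrozen(lambda x: x - 0.5, frozen)  # only the float leaves change
restored = unfreeze(stepped)                       # (1, (1.5, 2.5))
```
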
</details>
#### ☝️ Advanced Indexing with `.at[]` <a id="Indexing"></a>

<details> <summary>Out-of-place updates using mask, attribute name or index</summary>

`PyTreeClass` offers three ways to index through `.at[]`:

1. Indexing by boolean mask.
2. Indexing by attribute name.
3. Indexing by leaf index.

**Since `treeclass`-wrapped classes are immutable, `.at[]` operations return a new instance of the tree.**
#### Index update by boolean mask

```python
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))

# Let's create a mask for values > 4
mask = jax.tree_util.tree_map(lambda x: x>4, tree)

print(mask)
# Tree(a=False, b=(False, False), c=[False True True])

print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5 6])

print(tree.at[mask].set(10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])

print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
```
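
The boolean-mask semantics above can be sketched in plain Python on nested tuples of scalars: `get` keeps the selected leaves and blanks the rest, while `set` and `apply` return a new tree. `zip_map`, `mask_get`, `mask_set`, and `mask_apply` are hypothetical helpers; unlike `.at[mask]`, they do not handle array leaves with elementwise masks.

```python
def zip_map(func, tree, mask):
    """Apply `func(leaf, flag)` over two nested tuples of the same structure."""
    if isinstance(tree, tuple):
        return tuple(zip_map(func, t, m) for t, m in zip(tree, mask))
    return func(tree, mask)


def mask_get(tree, mask):
    # Keep selected leaves; unselected leaves become None
    return zip_map(lambda x, m: x if m else None, tree, mask)


def mask_set(tree, mask, value):
    # New tree with every selected leaf replaced by `value`
    return zip_map(lambda x, m: value if m else x, tree, mask)


def mask_apply(tree, mask, func):
    # New tree with `func` applied to every selected leaf
    return zip_map(lambda x, m: func(x) if m else x, tree, mask)


tree = (1, (2, 3), 5, 6)
mask = zip_map(lambda x, _: x > 4, tree, tree)  # (False, (False, False), True, True)
```
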
#### Index update by attribute name
```python
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))
print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)
print(tree.at["a"].set(10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
```
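
Indexing by attribute name maps naturally onto field-level updates of an immutable record. The sketch below uses a frozen dataclass with hypothetical `name_get`, `name_set`, and `name_apply` helpers; it is an illustration, not the library's `.at[]`. One difference: this sketch blanks a non-selected field as a whole (`b=None`), whereas the output above keeps its structure (`b=(None, None)`).

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class SketchTree:  # hypothetical stand-in for the `Tree` above
    a: int = 1
    b: tuple = (2, 3)


def name_get(tree, name):
    # Keep the named field; every other field becomes None
    values = {f.name: getattr(tree, f.name) if f.name == name else None
              for f in dataclasses.fields(tree)}
    return type(tree)(**values)


def name_set(tree, name, value):
    # Out-of-place replacement of a single field
    return dataclasses.replace(tree, **{name: value})


def name_apply(tree, name, func):
    # Out-of-place update of a single field through `func`
    return dataclasses.replace(tree, **{name: func(getattr(tree, name))})
```
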
#### Index update by integer index
```python
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))
print(tree.at[1].at[0].get())
# Tree(a=None, b=(2.0, None), c=None)
print(tree.at[1].at[0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
print(tree.at[1].at[0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
```
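
Chained integer indexing such as `.at[1].at[0]` amounts to following a path of positions. This pure-Python sketch models it on nested tuples with hypothetical `path_get` and `path_set` helpers, where the path `(1, 0)` plays the role of `.at[1].at[0]`; array leaves are not handled.

```python
def blank(tree):
    """Replace every leaf of nested tuples with None, keeping the structure."""
    if isinstance(tree, tuple):
        return tuple(blank(child) for child in tree)
    return None


def path_get(tree, path):
    """Keep only the leaf at `path` (a tuple of integer indices)."""
    if not path:
        return tree
    index, rest = path[0], path[1:]
    return tuple(
        path_get(child, rest) if i == index else blank(child)
        for i, child in enumerate(tree)
    )


def path_set(tree, path, value):
    """Return a copy of `tree` with the leaf at `path` replaced by `value`."""
    if not path:
        return value
    index, rest = path[0], path[1:]
    return tuple(
        path_set(child, rest, value) if i == index else child
        for i, child in enumerate(tree)
    )


tree = (1, (2.0, 3.0), 4.0)  # stands in for (a, b, c)
```
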
</details>
<details>
<summary>

## 📜 Stateful computations<a id="stateful_computation"></a> </summary>

Under `jax.jit`, [JAX requires state to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state): for any class instance, the variables need to be separated from the class and passed explicitly. With `@pytc.treeclass` there is no need to separate the instance variables; instead, the whole instance is passed as the state.

Using the following pattern, state can be updated **functionally** under `jax.jit`:

```python
import jax
import pytreeclass as pytc

@pytc.treeclass
class Counter:
    calls : int = 0

    def increment(self):
        self.calls += 1

counter = Counter()  # Counter(calls=0)
```

Next, we define the update function. Since the `increment` method mutates the internal state, we need to update the state functionally via `.at`. Calling `.at[method_name].__call__(*args, **kwargs)` returns the method's return value together with a _new_ instance holding the updated state.

```python
@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls)  # 10
```
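
The pattern behind `.at[method_name]()` can be sketched without JAX: copy the instance, run the mutating method on the copy, and return both the method's result and the copy. `call_functionally` is a hypothetical helper and the library's own mechanism may differ; for example, here `increment` returns the new count so there is a value to inspect.

```python
import copy
import dataclasses

@dataclasses.dataclass
class Counter:
    calls: int = 0

    def increment(self):
        self.calls += 1
        return self.calls


def call_functionally(instance, method_name, *args, **kwargs):
    """Run a mutating method on a deep copy; return (result, new_instance)."""
    new_instance = copy.deepcopy(instance)
    result = getattr(new_instance, method_name)(*args, **kwargs)
    return result, new_instance


original = Counter()
value, updated = call_functionally(original, "increment")
# original.calls is still 0; updated.calls and value are both 1

counter = original
for _ in range(10):
    _, counter = call_functionally(counter, "increment")
# counter.calls is now 10
```
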
</details>
## ➕ More<a id="more"></a>

<details><summary>[Advanced] Register custom user-defined classes to work with visualization and indexing tools. </summary>

Similar to [`jax.tree_util.register_pytree_node`](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees), `PyTreeClass` registers common data structures and `treeclass`-wrapped classes so that it can determine the name, type, index, and metadata of each leaf along its path.

Here is an example of registering a user-defined class:
```python
import jax
import pytreeclass as pytc

class Tree:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(a={self.a}, b={self.b})"

# jax flatten rule
def tree_flatten(tree):
    return (tree.a, tree.b), None

# jax unflatten rule
def tree_unflatten(_, children):
    return Tree(*children)

# PyTreeClass flatten rule
def pytc_tree_flatten(tree):
    names = ("a", "b")
    types = (type(tree.a), type(tree.b))
    indices = (0, 1)
    metadatas = (None, None)
    return [*zip(names, types, indices, metadatas)]

# Register with `jax`
jax.tree_util.register_pytree_node(Tree, tree_flatten, tree_unflatten)

# Register the `Tree` class trace function to support indexing
pytc.register_pytree_node_trace(Tree, pytc_tree_flatten)

tree = Tree(1, 2)

# works with jax
jax.tree_util.tree_leaves(tree)  # [1, 2]

# works with PyTreeClass viz tools
print(pytc.tree_summary(tree))
# ┌────┬────┬─────┬──────┐
# │Name│Type│Count│Size  │
# ├────┼────┼─────┼──────┤
# │a   │int │1    │28.00B│
# ├────┼────┼─────┼──────┤
# │b   │int │1    │28.00B│
# ├────┼────┼─────┼──────┤
# │Σ   │Tree│2    │56.00B│
# └────┴────┴─────┴──────┘
```
After registration, you can use internal tools such as:
- `pytc.tree_map_with_trace`
- `pytc.tree_leaves_with_trace`
- `pytc.tree_flatten_with_trace`
More details on that soon.
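
To see why both rules are needed, here is a minimal pure-Python sketch of a registry that pairs a flatten rule (the children) with a trace rule (name, type, index, metadata). `register`, `leaves_with_names`, and `Pair` are hypothetical; they only mimic the idea behind `pytc.register_pytree_node_trace` and `pytc.tree_leaves_with_trace`, whose actual signatures and outputs may differ.

```python
FLATTEN_RULES = {}  # class -> function returning the children
TRACE_RULES = {}    # class -> function returning (name, type, index, metadata) entries


def register(cls, flatten_rule, trace_rule):
    FLATTEN_RULES[cls] = flatten_rule
    TRACE_RULES[cls] = trace_rule


def leaves_with_names(tree, path=()):
    """Return (path_of_names, leaf) pairs; unregistered objects are leaves."""
    cls = type(tree)
    if cls not in FLATTEN_RULES:
        return [(path, tree)]
    pairs = []
    children = FLATTEN_RULES[cls](tree)
    for child, (name, _type, _index, _meta) in zip(children, TRACE_RULES[cls](tree)):
        pairs.extend(leaves_with_names(child, path + (name,)))
    return pairs


class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b


register(
    Pair,
    lambda t: (t.a, t.b),
    lambda t: [("a", type(t.a), 0, None), ("b", type(t.b), 1, None)],
)
```
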
</details>
<details> <summary>Validate or convert inputs using callbacks</summary>
`pytc.field` accepts `callbacks`, a sequence of functions that are applied to the input when the attribute is set. Callbacks are useful in several cases, for instance to ensure that an input has a certain type or lies within a valid range. For example:

```python
import jax
import pytreeclass as pytc

def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value

@pytc.treeclass
class Tree:
    in_features:int = pytc.field(callbacks=[positive_int_callback])

tree = Tree(1)
# no error

tree = Tree(0)
# ValueError: Error for field=`in_features`:
# Value must be positive

tree = Tree(1.0)
# TypeError: Error for field=`in_features`:
# Value must be an integer
```
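
One way such callbacks could work is sketched below with the standard `dataclasses` module: callbacks are stored in field metadata and run in `__post_init__`, and errors are re-raised with the field name. This is a hypothetical stand-in for `pytc.field(callbacks=...)`, not its implementation; in particular, it only validates at construction time.

```python
import dataclasses


def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


def run_callbacks(instance):
    """Pass each field value through its callbacks (stored in field metadata)."""
    for f in dataclasses.fields(instance):
        value = getattr(instance, f.name)
        for callback in f.metadata.get("callbacks", ()):
            try:
                value = callback(value)
            except (TypeError, ValueError) as error:
                # Re-raise the same error type, prefixed with the field name
                raise type(error)(f"Error for field=`{f.name}`:\n{error}") from error
        object.__setattr__(instance, f.name, value)  # bypass frozen=True


@dataclasses.dataclass(frozen=True)
class Layer:
    in_features: int = dataclasses.field(metadata={"callbacks": [positive_int_callback]})

    def __post_init__(self):
        run_callbacks(self)


layer = Layer(3)  # passes validation

errors = {}
for bad_value in (0, 1.0):
    try:
        Layer(bad_value)
    except (TypeError, ValueError) as error:
        errors[bad_value] = (type(error).__name__, str(error))
```
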
</details>
<details> <summary> Add leafwise math operations to PyTreeClass wrapped class</summary>
```python
import functools as ft
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()

tree + 100
# Tree(a=101, b=(102.0, 103.0), c=f32[3](μ=105.00, σ=0.82, ∈[104.00,106.00]))

@jax.grad
def loss_func(tree:Tree, x:jax.Array):
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error

@jax.jit
def train_step(tree:Tree, x:jax.Array):
    grads = loss_func(tree, x)
    return tree - grads*1e-3  # <--- eliminate `tree_map`

# Freeze the non-differentiable parts of the tree:
# any leaf that is not of an inexact (floating-point) type should be frozen
# to make the tree differentiable and compatible with jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))

print(jaxable_tree)
# the frozen params have a "#" prefix
# Tree(a=#1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783 3.024264 ])

# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783 3.024264 ])
```
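
What `leafwise=True` adds can be approximated with operator overloading. The hypothetical `Leafwise` mixin below, for frozen dataclasses, maps each arithmetic operator over the leaves of every field and broadcasts scalars; it covers only `+`, `-`, and `*` on tuple and scalar fields and is not the library's implementation. It shows why an expression like `tree - grads*1e-3` needs no explicit `tree_map`.

```python
import dataclasses
import operator


def map_leaves(func, node):
    # Recurse into tuples; treat everything else as a leaf
    if isinstance(node, tuple):
        return tuple(map_leaves(func, child) for child in node)
    return func(node)


def zip_leaves(func, lhs, rhs):
    # Combine two nested tuples of identical structure leaf by leaf
    if isinstance(lhs, tuple):
        return tuple(zip_leaves(func, lhs_leaf, rhs_leaf) for lhs_leaf, rhs_leaf in zip(lhs, rhs))
    return func(lhs, rhs)


class Leafwise:
    """Hypothetical mixin: arithmetic operators act on every leaf of every field."""

    def _binary(self, other, op):
        updates = {}
        for f in dataclasses.fields(self):
            lhs = getattr(self, f.name)
            if isinstance(other, type(self)):
                updates[f.name] = zip_leaves(op, lhs, getattr(other, f.name))
            else:
                # A scalar is broadcast against every leaf
                updates[f.name] = map_leaves(lambda x: op(x, other), lhs)
        return dataclasses.replace(self, **updates)

    def __add__(self, other):
        return self._binary(other, operator.add)

    def __sub__(self, other):
        return self._binary(other, operator.sub)

    def __mul__(self, other):
        return self._binary(other, operator.mul)


@dataclasses.dataclass(frozen=True)
class SketchTree(Leafwise):
    a: int = 1
    b: tuple = (2.0, 3.0)
```
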
</details>
<details> <summary>Eliminate tree_map using bcmap + treeclass(..., leafwise=True) </summary>
TL;DR

```python
import functools as ft
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

tree = Tree()

print(pytc.bcmap(jnp.where)(tree>2, tree+100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])
```
`bcmap(func, is_leaf)` maps a function over the leaves of [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html) with automatic broadcasting for scalar arguments.

`bcmap` is a function transformation that broadcasts a scalar to match the structure of the function's first argument. This lets us convert a function like `jnp.where` to work with arbitrary tree structures without writing a specific function for each broadcasting case.

For example, say we want to use `jnp.where` to zero out all non-positive values in an arbitrary tree structure such as `([1], {"a":1, "b":2}, (1,), -1,)`.

We can use `jax.tree_util.tree_map` to apply `jnp.where` to the tree, but we need to write a specific function that encodes the scalar:
```python
import jax.numpy as jnp
import jax.tree_util as jtu

tree = ([1], {"a":1, "b":2}, (1,), -1,)

def map_func(leaf):
    # here we encoded the scalar `0` inside the function
    return jnp.where(leaf>0, leaf, 0)

jtu.tree_map(map_func, tree)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))
```
Now suppose we want `jnp.where` to take the replacement value from the corresponding leaf of another tree:

```python
def map_func2(lhs_leaf, rhs_leaf):
    # here the replacement value comes from the matching leaf of the second tree
    return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)

tree2 = jtu.tree_map(lambda x: 1000, tree)

jtu.tree_map(map_func2, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))
```
`bcmap` makes this easier by working out the broadcasting case automatically:
```python
broadcastable_where = pytc.bcmap(jnp.where)
mask = jtu.tree_map(lambda x: x>0, tree)
```
Case 1: broadcast a scalar to the tree
```python
broadcastable_where(mask, tree, 0)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(0, dtype=int32, weak_type=True))
```
Case 2: broadcast a tree to a tree
```python
broadcastable_where(mask, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(1000, dtype=int32, weak_type=True))
```
Let's take this a step further and eliminate `mask` from the equation
by using `pytreeclass` with `leafwise=True`:

```python
@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    tree : tuple = ([1], {"a":1, "b":2}, (1,), -1,)

tree = Tree()
# Tree(tree=([1], {a:1, b:2}, (1), -1))
```
Case 1: broadcast a scalar to the tree

```python
print(broadcastable_where(tree>0, tree, 0))
# Tree(tree=([1], {a:1, b:2}, (1), 0))
```

Case 2: broadcast a tree to a tree

```python
print(broadcastable_where(tree>0, tree, tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
```
`bcmap` also works with keyword arguments of the wrapped function:
```python
print(broadcastable_where(tree>0, x=tree, y=tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
```
In conclusion, `bcmap` is a function transformation that makes functions work with
arbitrary tree structures without writing a specific function for each broadcasting case.

Moreover, `bcmap` becomes even more powerful when combined with `pytreeclass`, since it
lets arbitrary functions operate on `PyTree` objects without the need for `tree_map`.
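
The broadcasting rule described above can be sketched in plain Python: the first argument fixes the structure, container arguments are indexed in lockstep with it, and non-container arguments are repeated at every leaf. `bcmap_sketch` and the pure-Python `where` are hypothetical stand-ins for `pytc.bcmap` and `jnp.where`; the real `bcmap` also handles keyword arguments and an `is_leaf` predicate, which this sketch omits.

```python
def bcmap_sketch(func):
    """Map `func` leafwise over nested lists/tuples/dicts, broadcasting scalars."""

    def is_container(x):
        return isinstance(x, (list, tuple, dict))

    def wrapped(tree, *args):
        if isinstance(tree, dict):
            return {
                key: wrapped(value, *(arg[key] if is_container(arg) else arg for arg in args))
                for key, value in tree.items()
            }
        if isinstance(tree, (list, tuple)):
            return type(tree)(
                wrapped(leaf, *(arg[i] if is_container(arg) else arg for arg in args))
                for i, leaf in enumerate(tree)
            )
        return func(tree, *args)  # leaf level: all arguments are now scalars

    return wrapped


def where(condition, x, y):
    # Pure-Python stand-in for jnp.where on scalar leaves
    return x if condition else y


tree = ([1], {"a": 1, "b": 2}, (1,), -1)
mask = bcmap_sketch(lambda x: x > 0)(tree)
tree2 = bcmap_sketch(lambda x: 1000)(tree)
broadcastable_where = bcmap_sketch(where)

case1 = broadcastable_where(mask, tree, 0)      # scalar broadcast to the tree
case2 = broadcastable_where(mask, tree, tree2)  # tree-to-tree
```
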
</details>
<details><summary>Use PyTreeClass visualization tools with arbitrary PyTrees </summary>
```python
import jax
import pytreeclass as pytc

tree = [1, [2,3], 4]

print(pytc.tree_diagram(tree,depth=1))
# list
#  ├── [0]=1
#  ├── [1]=[..., ...]
#  └── [2]=4

print(pytc.tree_diagram(tree,depth=2))
# list
#  ├── [0]=1
#  ├── [1]:list
#  │   ├── [0]=2
#  │   └── [1]=3
#  └── [2]=4

print(pytc.tree_summary(tree,depth=1))
# ┌────┬────┬─────┬───────┐
# │Name│Type│Count│Size   │
# ├────┼────┼─────┼───────┤
# │[0] │int │1    │28.00B │
# ├────┼────┼─────┼───────┤
# │[1] │list│2    │56.00B │
# ├────┼────┼─────┼───────┤
# │[2] │int │1    │28.00B │
# ├────┼────┼─────┼───────┤
# │Σ   │list│4    │112.00B│
# └────┴────┴─────┴───────┘

print(pytc.tree_summary(tree,depth=2))
# ┌──────┬────┬─────┬───────┐
# │Name  │Type│Count│Size   │
# ├──────┼────┼─────┼───────┤
# │[0]   │int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │[1][0]│int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │[1][1]│int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │[2]   │int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │Σ     │list│4    │112.00B│
# └──────┴────┴─────┴───────┘
```
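
The `Count` and `Size` columns above are consistent with counting leaves and summing `sys.getsizeof` over plain Python scalars (28 bytes per small `int` on 64-bit CPython). The sketch below reproduces the `depth=1` rows for nested lists under that assumption; `leaf_stats` and `summary_rows` are hypothetical, and the library may size array leaves differently (for example, by their number of bytes).

```python
import sys


def leaf_stats(node):
    """Return (leaf_count, total_bytes) for nested lists/tuples of scalars."""
    if isinstance(node, (list, tuple)):
        stats = [leaf_stats(child) for child in node]
        return sum(c for c, _ in stats), sum(b for _, b in stats)
    # A scalar leaf: count it once and use its in-memory size
    return 1, sys.getsizeof(node)


def summary_rows(tree):
    """One (name, type, count, bytes) row per top-level child, plus a total row."""
    rows = []
    for i, child in enumerate(tree):
        count, size = leaf_stats(child)
        rows.append((f"[{i}]", type(child).__name__, count, size))
    total_count, total_size = leaf_stats(tree)
    rows.append(("Σ", type(tree).__name__, total_count, total_size))
    return rows
```
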
</details>
<details><summary>Use PyTreeClass components with other libraries</summary>
```python
import jax
import pytreeclass as pytc
from flax import struct

@struct.dataclass
class FlaxTree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jax.numpy.array([4.,5.,6.])

    def __repr__(self) -> str:
        return pytc.tree_repr(self)

    def __str__(self) -> str:
        return pytc.tree_str(self)

    @property
    def at(self):
        return pytc.tree_indexer(self)

def pytc_flatten_rule(tree):
    names = ("a","b","c")
    types = map(type, (tree.a, tree.b, tree.c))
    indices = range(3)
    metadatas = (None, None, None)
    return [*zip(names, types, indices, metadatas)]

pytc.register_pytree_node_trace(FlaxTree, pytc_flatten_rule)

flax_tree = FlaxTree()

print(f"{flax_tree!r}")
# FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

print(f"{flax_tree!s}")
# FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])

print(pytc.tree_diagram(flax_tree))
# FlaxTree
#  ├── a=1
#  ├── b:tuple
#  │   ├── [0]=2.0
#  │   └── [1]=3.0
#  └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])

print(pytc.tree_summary(flax_tree))
# ┌────┬────────┬─────┬──────┐
# │Name│Type    │Count│Size  │
# ├────┼────────┼─────┼──────┤
# │a   │int     │1    │28.00B│
# ├────┼────────┼─────┼──────┤
# │b[0]│float   │1    │24.00B│
# ├────┼────────┼─────┼──────┤
# │b[1]│float   │1    │24.00B│
# ├────┼────────┼─────┼──────┤
# │c   │f32[3]  │3    │12.00B│
# ├────┼────────┼─────┼──────┤
# │Σ   │FlaxTree│6    │88.00B│
# └────┴────────┴─────┴──────┘

flax_tree.at[0].get()
# FlaxTree(a=1, b=(None, None), c=None)

flax_tree.at["a"].set(10)
# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```
</details>
<details>
<summary>Benchmark flatten/unflatten compared to Flax and Equinox </summary>
<a href="https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/benchmark_flatten_unflatten.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<table>
<tr><td align="center">CPU</td><td align="center">GPU</td></tr>
<tr>
<td><img src='assets/benchmark_cpu.png'></td>
<td><img src='assets/benchmark_gpu.png'></td>
</tr>
</table>
</details>
## Acknowledgements<a id="acknowledgements"></a>
- [Farid Talibli (for visualization link generation backend)](https://www.linkedin.com/in/frdt98)
- [Treex](https://github.com/cgarciae/treex), [Equinox](https://github.com/patrick-kidger/equinox), [tree-math](https://github.com/google/tree-math), [Flax](https://github.com/google/flax), [TensorFlow](https://www.tensorflow.org), [PyTorch](https://pytorch.org)
- [Lovely JAX](https://github.com/xl0/lovely-jax)
Raw data
{
"_id": null,
"home_page": "https://github.com/ASEM000/pytreeclass",
"name": "pytreeclass",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.8",
"maintainer_email": "",
"keywords": "python machine-learning pytorch jax",
"author": "Mahmoud Asem",
"author_email": "asem00@kaist.ac.kr",
"download_url": "https://files.pythonhosted.org/packages/0f/f7/e7c371fed186f64e1462054d60244c45435afd4cacb0662b1dfe8f4e5908/pytreeclass-0.2.1.tar.gz",
"platform": null,
"description": "<!-- <h1 align=\"center\" style=\"font-family:Monospace\" >Py\ud83c\udf32Class</h1> -->\n<h5 align=\"center\">\n<img width=\"250px\" src=\"assets/pytc%20logo.svg\"> <br>\n\n<br>\n\n[**Installation**](#installation)\n|[**Description**](#description)\n|[**Quick Example**](#quick_example)\n|[**StatefulComputation**](#stateful_computation)\n|[**More**](#more)\n|[**Acknowledgements**](#acknowledgements)\n\n\n\n\n\n[](https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/intro.ipynb)\n[](https://pepy.tech/project/pytreeclass)\n[](https://codecov.io/gh/ASEM000/pytreeclass)\n[](https://pytreeclass.readthedocs.io/en/latest/?badge=latest)\n\n[](https://zenodo.org/badge/latestdoi/512717921)\n\n\n</h5>\n\n**For previous `PyTreeClass` use v0.1 branch**\n\n## \ud83d\udee0\ufe0f Installation<a id=\"installation\"></a>\n\n```python\npip install pytreeclass\n```\n\n**Install development version**\n\n```python\npip install git+https://github.com/ASEM000/PyTreeClass\n```\n\n## \ud83d\udcd6 Description<a id=\"description\"></a>\n\n`PyTreeClass` is a JAX-compatible `dataclass`-like decorator to create and operate on stateful JAX PyTrees.\n\nThe package aims to achieve _two goals_:\n\n1. \ud83d\udd12 To maintain safe and correct behaviour by using _immutable_ modules with _functional_ API.\n2. 
To achieve the **most intuitive** user experience in the `JAX` ecosystem by :\n - \ud83c\udfd7\ufe0f Defining layers similar to `PyTorch` or `TensorFlow` subclassing style.\n - \u261d\ufe0f Filtering\\Indexing layer values similar to `jax.numpy.at[].{get,set,apply,...}`\n - \ud83c\udfa8 Visualize defined layers in plethora of ways.\n\n## \u23e9 Quick Example <a id=\"quick_example\">\n\n### \ud83c\udfd7\ufe0f Simple Tree example\n\n<div align=\"center\">\n<table>\n<tr><td align=\"center\">Code</td> <td align=\"center\">PyTree representation</td></tr>\n<tr>\n<td>\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport pytreeclass as pytc\n\n@pytc.treeclass\nclass Tree:\n a:int = 1\n b:tuple[float] = (2.,3.)\n c:jax.Array = jnp.array([4.,5.,6.])\n\n def __call__(self, x):\n return self.a + self.b[0] + self.c + x\n\ntree = Tree()\n```\n\n</td>\n\n<td>\n\n```python\n# leaves are parameters\n\nTree\n \u251c\u2500\u2500 a=1\n \u251c\u2500\u2500 b:tuple\n \u2502 \u251c\u2500\u2500 [0]=2.0\n \u2502 \u2514\u2500\u2500 [1]=3.0\n \u2514\u2500\u2500 c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00])\n```\n\n</td>\n\n</tr>\n</table>\n</div>\n\n### \ud83c\udfa8 Visualize<a id=\"Viz\">\n\n<details> <summary>Visualize PyTrees in five different ways</summary>\n\n<div align=\"center\">\n<table>\n<tr>\n <td align = \"center\"> tree_summary</td> \n <td align = \"center\">tree_diagram</td>\n <td align = \"center\">[tree_mermaid](https://mermaid.js.org)(Native support in Github/Notion)</td>\n <td align= \"center\"> tree_repr </td>\n <td align=\"center\" > tree_str </td>\n\n</tr>\n\n<tr>\n<td>\n\n```python\nprint(pytc.tree_summary(tree, depth=1))\n\u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n\u2502Name\u2502Type \u2502Count\u2502Size 
\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502a \u2502int \u25021 \u250228.00B\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502b \u2502tuple \u25022 \u250248.00B\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502c \u2502f32[3]\u25023 \u250212.00B\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502\u03a3 \u2502Tree \u25026 \u250288.00B\u2502\n\u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n```\n\n</td>\n\n<td>\n\n```python\n\nprint(pytc.tree_diagram(tree, depth=1))\nTree\n \u251c\u2500\u2500 a=1\n \u251c\u2500\u2500 b=(..., ...)\n \u2514\u2500\u2500 c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00])\n```\n\n </td>\n\n<td>\n\n```python\nprint(pytc.tree_mermaid(tree, depth=1))\n```\n\n```mermaid\n\nflowchart LR\n id15696277213149321320(<b>Tree</b>)\n id15696277213149321320--->|\"1 leaf<br>28.00B\"|id4205845433746830897(\"<b>a</b>:int=1\")\n id15696277213149321320--->|\"2 leaf<br>48.00B\"|id4682191244783855647(\"<b>b</b>:tuple=(..., ...)\")\n id15696277213149321320--->|\"3 leaf<br>12.00B\"|id14652085615030570957(\"<b>c</b>:ArrayImpl=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00])\")\n```\n\n</td>\n\n<td>\n\n```python\nprint(pytc.tree_repr(tree, depth=1))\nTree(a=1, b=(..., ...), c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00]))\n```\n\n</td>\n\n<td>\n\n```python\nprint(pytc.tree_str(tree, depth=1))\nTree(a=1, b=(..., ...), c=[4. 5. 
6.])\n```\n\n</td>\n\n</tr>\n\n<tr>\n\n<td>\n\n```python\nprint(pytc.tree_summary(tree, depth=2))\n\u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n\u2502Name\u2502Type \u2502Count\u2502Size \u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502a \u2502int \u25021 \u250228.00B\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502b[0]\u2502float \u25021 \u250224.00B\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502b[1]\u2502float \u25021 \u250224.00B\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502c \u2502f32[3]\u25023 \u250212.00B\u2502\n\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n\u2502\u03a3 \u2502Tree \u25026 \u250288.00B\u2502\n\u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n```\n\n</td>\n\n<td>\n\n```python\nprint(pytc.tree_diagram(tree, depth=2))\nTree\n \u251c\u2500\u2500 a=1\n \u251c\u2500\u2500 b:tuple\n \u2502 \u251c\u2500\u2500 [0]=2.0\n \u2502 \u2514\u2500\u2500 [1]=3.0\n \u2514\u2500\u2500 c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00])\n```\n\n</td>\n\n<td>\n\n```python\nprint(pytc.tree_mermaid(tree, depth=2))\n```\n\n```mermaid\nflowchart LR\n id15696277213149321320(<b>Tree</b>)\n 
id15696277213149321320--->id4205845433746830897(\"<b>a</b>:int=1\")\n id15696277213149321320--->|\"1 leaf<br>24.00B\"|id8168961130706115346(\"<b>b</b>:tuple\")\n id8168961130706115346--->|\"1 leaf<br>24.00B\"|id2766159651176208202(\"<b>[0]</b>:float=2.0\")\n id15696277213149321320--->|\"1 leaf<br>24.00B\"|id12408280303145007954(\"<b>b</b>:tuple\")\n id12408280303145007954--->|\"1 leaf<br>24.00B\"|id7897116322308127883(\"<b>[1]</b>:float=3.0\")\n id15696277213149321320--->id14652085615030570957(\"<b>c</b>:ArrayImpl=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00])\")\n```\n\n</td>\n\n<td>\n\n```python\nprint(pytc.tree_repr(tree, depth=2))\nTree(a=1, b=(2.0, 3.0), c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00]))\n```\n\n</td>\n\n<td>\n\n```python\nprint(pytc.tree_str(tree, depth=2))\nTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])\n```\n\n</td>\n\n</tr>\n\n </table>\n\n </div>\n\n</details>\n\n### \ud83c\udfc3 Working with `jax` transformation\n\n<details> <summary>Make arbitrary PyTrees work with jax transformations</summary>\n\nParameters are defined in `Tree` at the top of class definition similar to defining\n`dataclasses.dataclass` field.\nLets optimize our parameters\n\n```python\n\n@jax.grad\ndef loss_func(tree:Tree, x:jax.Array):\n preds = jax.vmap(tree)(x) # <--- vectorize the tree call over the leading axis\n return jnp.mean(preds**2) # <--- return the mean squared error\n\n@jax.jit\ndef train_step(tree:Tree, x:jax.Array):\n grads = loss_func(tree, x)\n # apply a small gradient step\n return jax.tree_util.tree_map(lambda x, g: x - 1e-3*g, tree, grads)\n\n# lets freeze the non-differentiable parts of the tree\n# in essence any non inexact type should be frozen to\n# make the tree differentiable and work with jax transformations\njaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)\n\nfor epoch in range(1_000):\n jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))\n\nprint(jaxable_tree)\n# **the `frozen` 
params have \"#\" prefix**\n#Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])\n\n\n# unfreeze the tree\ntree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)\nprint(tree)\n# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])\n```\n\n</details>\n\n#### \u261d\ufe0f Advanced Indexing with `.at[]` <a id=\"Indexing\">\n\n<details> <summary>Out-of-place updates using mask, attribute name or index</summary>\n\n`PyTreeClass` offers 3 means of indexing through `.at[]`\n\n1. Indexing by boolean mask.\n2. Indexing by attribute name.\n3. Indexing by Leaf index.\n\n**Since `treeclass` wrapped class are immutable, `.at[]` operations returns new instance of the tree**\n\n#### Index update by boolean mask\n\n```python\ntree = Tree()\n# Tree(a=1, b=(2, 3), c=i32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4,6]))\n\n# lets create a mask for values > 4\nmask = jax.tree_util.tree_map(lambda x: x>4, tree)\n\nprint(mask)\n# Tree(a=False, b=(False, False), c=[False True True])\n\nprint(tree.at[mask].get())\n# Tree(a=None, b=(None, None), c=[5 6])\n\nprint(tree.at[mask].set(10))\n# Tree(a=1, b=(2, 3), c=[ 4 10 10])\n\nprint(tree.at[mask].apply(lambda x: 10))\n# Tree(a=1, b=(2, 3), c=[ 4 10 10])\n```\n\n#### Index update by attribute name\n\n```python\ntree = Tree()\n# Tree(a=1, b=(2, 3), c=i32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4,6]))\n\nprint(tree.at[\"a\"].get())\n# Tree(a=1, b=(None, None), c=None)\n\nprint(tree.at[\"a\"].set(10))\n# Tree(a=10, b=(2, 3), c=[4 5 6])\n\nprint(tree.at[\"a\"].apply(lambda x: 10))\n# Tree(a=10, b=(2, 3), c=[4 5 6])\n```\n\n#### Index update by integer index\n\n```python\ntree = Tree()\n# Tree(a=1, b=(2, 3), c=i32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4,6]))\n\nprint(tree.at[1].at[0].get())\n# Tree(a=None, b=(2.0, None), c=None)\n\nprint(tree.at[1].at[0].set(10))\n# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])\n\nprint(tree.at[1].at[0].apply(lambda x: 10))\n# Tree(a=1, b=(10, 3.0), c=[4. 5. 
6.])\n```\n\n</details>\n\n<details>\n\n<summary>\n\n## \ud83d\udcdc Stateful computations<a id=\"stateful_computation\"></a> </summary>\n\n[Under `jax.jit`, JAX requires state to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state): for any class instance, the variables need to be separated from the class and passed explicitly. With `@pytc.treeclass` there is no need to separate the instance variables; instead, the whole instance is passed as the state.\n\nUsing the following pattern, state can be updated **functionally** under `jax.jit`:\n\n```python\nimport jax\nimport pytreeclass as pytc\n\n@pytc.treeclass\nclass Counter:\n    calls : int = 0\n\n    def increment(self):\n        self.calls += 1\n\ncounter = Counter() # Counter(calls=0)\n```\n\nNext, we define the update function. Since the `increment` method mutates the internal state, we need to update the state functionally using `.at`. To do this, we can use `.at[method_name].__call__(*args, **kwargs)`; this functional call returns the value of the method call and a _new_ model instance with the updated state.\n\n```python\n@jax.jit\ndef update(counter):\n    value, new_counter = counter.at[\"increment\"]()\n    return new_counter\n\nfor i in range(10):\n    counter = update(counter)\n\nprint(counter.calls) # 10\n```\n\n</details>\n\n## \u2795 More<a id=\"more\"></a>\n\n<details><summary>[Advanced] Register custom user-defined classes to work with visualization and indexing tools. 
</summary>\n\nSimilar to [`jax.tree_util.register_pytree_node`](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees), `PyTreeClass` registers common data structures and `treeclass`-wrapped classes so that it can determine the name, type, index, and metadata of each leaf along its path.\n\nHere is an example of registering a custom class:\n\n```python\nimport jax\nimport pytreeclass as pytc\n\nclass Tree:\n    def __init__(self, a, b):\n        self.a = a\n        self.b = b\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(a={self.a}, b={self.b})\"\n\n\n# jax flatten rule\ndef tree_flatten(tree):\n    return (tree.a, tree.b), None\n\n# jax unflatten rule\ndef tree_unflatten(_, children):\n    return Tree(*children)\n\n# PyTreeClass flatten rule\ndef pytc_tree_flatten(tree):\n    names = (\"a\", \"b\")\n    types = (type(tree.a), type(tree.b))\n    indices = (0,1)\n    metadatas = (None, None)\n    return [*zip(names, types, indices, metadatas)]\n\n\n# Register with `jax`\njax.tree_util.register_pytree_node(Tree, tree_flatten, tree_unflatten)\n\n# Register the `Tree` class trace function to support indexing\npytc.register_pytree_node_trace(Tree, pytc_tree_flatten)\n\ntree = Tree(1, 2)\n\n# works with jax\njax.tree_util.tree_leaves(tree) # [1, 2]\n\n# works with PyTreeClass viz tools\nprint(pytc.tree_summary(tree))\n\n# \u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name\u2502Type\u2502Count\u2502Size \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502a \u2502int \u25021 \u250228.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502b \u2502int \u25021 \u250228.00B\u2502\n# 
\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3 \u2502Tree\u25022 \u250256.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n\n```\n\nAfter registration, you can use internal tools such as:\n\n- `pytc.tree_map_with_trace`\n- `pytc.tree_leaves_with_trace`\n- `pytc.tree_flatten_with_trace`\n\nMore details on that soon.\n\n</details>\n\n<details> <summary>Validate or convert inputs using callbacks</summary>\n\n`PyTreeClass` supports `callbacks` in `field` to apply a sequence of functions to an input when the attribute is set. Callbacks are useful in several cases, for instance, to ensure that an input has a certain type or lies within a valid range. For example:\n\n```python\nimport jax\nimport pytreeclass as pytc\n\ndef positive_int_callback(value):\n    if not isinstance(value, int):\n        raise TypeError(\"Value must be an integer\")\n    if value <= 0:\n        raise ValueError(\"Value must be positive\")\n    return value\n\n\n@pytc.treeclass\nclass Tree:\n    in_features:int = pytc.field(callbacks=[positive_int_callback])\n\n\ntree = Tree(1)\n# no error\n\ntree = Tree(0)\n# ValueError: Error for field=`in_features`:\n# Value must be positive\n\ntree = Tree(1.0)\n# TypeError: Error for field=`in_features`:\n# Value must be an integer\n```\n\n</details>\n\n<details> <summary> Add leafwise math operations to a PyTreeClass-wrapped class</summary>\n\n```python\nimport functools as ft\nimport jax\nimport jax.numpy as jnp\nimport pytreeclass as pytc\n\n@ft.partial(pytc.treeclass, leafwise=True)\nclass Tree:\n    a:int = 1\n    b:tuple[float] = (2.,3.)\n    c:jax.Array = jnp.array([4.,5.,6.])\n\n    def __call__(self, x):\n        return self.a + self.b[0] + self.c + x\n\ntree = Tree()\n\ntree + 100\n# Tree(a=101, b=(102.0, 103.0), c=f32[3](\u03bc=105.00, \u03c3=0.82, \u2208[104.00,106.00]))\n\n@jax.grad\ndef loss_func(tree:Tree, 
x:jax.Array):\n    preds = jax.vmap(tree)(x) # <--- vectorize the tree call over the leading axis\n    return jnp.mean(preds**2) # <--- return the mean squared error\n\n@jax.jit\ndef train_step(tree:Tree, x:jax.Array):\n    grads = loss_func(tree, x)\n    return tree - grads*1e-3 # <--- eliminate `tree_map`\n\n# freeze the non-differentiable parts of the tree;\n# in essence, any leaf of a non-inexact type (e.g. `int`) should be frozen to\n# make the tree differentiable and compatible with jax transformations\njaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)\n\nfor epoch in range(1_000):\n    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))\n\nprint(jaxable_tree)\n# frozen params are shown with a \"#\" prefix\n# Tree(a=#1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783 3.024264 ])\n\n\n# unfreeze the tree\ntree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)\nprint(tree)\n# Tree(a=1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783 3.024264 ])\n```\n\n</details>\n\n<details> <summary>Eliminate tree_map using bcmap + treeclass(..., leafwise=True) </summary>\n\nTL;DR\n\n```python\nimport functools as ft\nimport jax\nimport jax.numpy as jnp\nimport pytreeclass as pytc\n\n@ft.partial(pytc.treeclass, leafwise=True)\nclass Tree:\n    a:int = 1\n    b:tuple[float] = (2.,3.)\n    c:jax.Array = jnp.array([4.,5.,6.])\n\ntree = Tree()\n\nprint(pytc.bcmap(jnp.where)(tree>2, tree+100, 0))\n# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 
106.])\n\n```\n\n`bcmap(func, is_leaf)` maps a function over the leaves of [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html) with automatic broadcasting for scalar arguments.\n\n`bcmap` is a function transformation that broadcasts scalar arguments to match the first argument of the function. This lets us convert a function like `jnp.where` to work with arbitrary tree structures without writing a specific function for each broadcasting case.\n\nFor example, say we want to use `jnp.where` to zero out all values less than 0 in an arbitrary tree structure. We can use `jax.tree_util.tree_map` to apply `jnp.where` to the tree, but we need to write a specific function that encodes the scalar:\n\n```python\nimport jax.numpy as jnp\nimport jax.tree_util as jtu\n\ntree = ([1], {\"a\":1, \"b\":2}, (1,), -1,)\n\ndef map_func(leaf):\n    # here we encoded the scalar `0` inside the function\n    return jnp.where(leaf>0, leaf, 0)\n\njtu.tree_map(map_func, tree)\n# ([Array(1, dtype=int32, weak_type=True)],\n# {'a': Array(1, dtype=int32, weak_type=True),\n# 'b': Array(2, dtype=int32, weak_type=True)},\n# (Array(1, dtype=int32, weak_type=True),),\n# Array(0, dtype=int32, weak_type=True))\n```\n\nHowever, if we want `jnp.where` to take the replacement value from the corresponding leaf of another tree, we need yet another function:\n\n```python\ndef map_func2(lhs_leaf, rhs_leaf):\n    # here the fallback value comes from the leaf of the second tree\n    return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)\n\ntree2 = jtu.tree_map(lambda x: 1000, tree)\n\njtu.tree_map(map_func2, tree, tree2)\n# ([Array(1, dtype=int32, weak_type=True)],\n# {'a': Array(1, dtype=int32, weak_type=True),\n# 'b': Array(2, dtype=int32, weak_type=True)},\n# (Array(1, dtype=int32, weak_type=True),),\n# Array(1000, dtype=int32, weak_type=True))\n```\n\n`bcmap` makes this easier by inferring the broadcasting case:\n\n```python\nbroadcastable_where = pytc.bcmap(jnp.where)\nmask = jtu.tree_map(lambda x: x>0, tree)\n```\n\nCase 1: broadcast a scalar to the tree\n\n```python\nbroadcastable_where(mask, tree, 
0)\n# ([Array(1, dtype=int32, weak_type=True)],\n# {'a': Array(1, dtype=int32, weak_type=True),\n# 'b': Array(2, dtype=int32, weak_type=True)},\n# (Array(1, dtype=int32, weak_type=True),),\n# Array(0, dtype=int32, weak_type=True))\n```\n\nCase 2: broadcast a tree to a tree\n\n```python\nbroadcastable_where(mask, tree, tree2)\n# ([Array(1, dtype=int32, weak_type=True)],\n# {'a': Array(1, dtype=int32, weak_type=True),\n# 'b': Array(2, dtype=int32, weak_type=True)},\n# (Array(1, dtype=int32, weak_type=True),),\n# Array(1000, dtype=int32, weak_type=True))\n```\n\nLet's take this a step further and eliminate `mask` from the equation by using `pytreeclass` with `leafwise=True`:\n\n```python\n@ft.partial(pytc.treeclass, leafwise=True)\nclass Tree:\n    tree : tuple = ([1], {\"a\":1, \"b\":2}, (1,), -1,)\n\ntree = Tree()\n# Tree(tree=([1], {a:1, b:2}, (1), -1))\n```\n\nCase 1: broadcast a scalar to the tree\n\n```python\nprint(broadcastable_where(tree>0, tree, 0))\n# Tree(tree=([1], {a:1, b:2}, (1), 0))\n```\n\nCase 2: broadcast a tree to a tree\n\n```python\nprint(broadcastable_where(tree>0, tree, tree+100))\n# Tree(tree=([1], {a:1, b:2}, (1), 99))\n```\n\n`bcmap` also works with keyword arguments of the wrapped function:\n\n```python\nprint(broadcastable_where(tree>0, x=tree, y=tree+100))\n# Tree(tree=([1], {a:1, b:2}, (1), 99))\n```\n\nIn conclusion, `bcmap` is a function transformation that makes functions work with arbitrary tree structures without the need to write a specific function for each broadcasting case.\n\nMoreover, `bcmap` becomes more powerful when used with `pytreeclass`, since arbitrary functions can then operate on `PyTree` objects without the need to use `tree_map`.\n\n</details>\n\n<details><summary>Use PyTreeClass visualization tools with arbitrary PyTrees </summary>\n\n```python\nimport jax\nimport pytreeclass as pytc\n\ntree = [1, [2,3], 4]\n\nprint(pytc.tree_diagram(tree,depth=1))\n# list\n# \u251c\u2500\u2500 [0]=1\n# \u251c\u2500\u2500 [1]=[..., ...]\n# 
\u2514\u2500\u2500 [2]=4\n\nprint(pytc.tree_diagram(tree,depth=2))\n# list\n# \u251c\u2500\u2500 [0]=1\n# \u251c\u2500\u2500 [1]:list\n# \u2502 \u251c\u2500\u2500 [0]=2\n# \u2502 \u2514\u2500\u2500 [1]=3\n# \u2514\u2500\u2500 [2]=4\n\n\nprint(pytc.tree_summary(tree,depth=1))\n# \u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name\u2502Type\u2502Count\u2502Size \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[0] \u2502int \u25021 \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[1] \u2502list\u25022 \u250256.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[2] \u2502int \u25021 \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3 \u2502list\u25024 \u2502112.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n\nprint(pytc.tree_summary(tree,depth=2))\n# \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name \u2502Type\u2502Count\u2502Size \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[0] \u2502int \u25021 \u250228.00B \u2502\n# 
\u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[1][0]\u2502int \u25021 \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[1][1]\u2502int \u25021 \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[2] \u2502int \u25021 \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3 \u2502list\u25024 \u2502112.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n```\n\n</details>\n\n<details><summary>Use PyTreeClass components with other libraries</summary>\n\n```python\nimport jax\nimport pytreeclass as pytc\nfrom flax import struct\n\n@struct.dataclass\nclass FlaxTree:\n a:int = 1\n b:tuple[float] = (2.,3.)\n c:jax.Array = jax.numpy.array([4.,5.,6.])\n\n def __repr__(self) -> str:\n return pytc.tree_repr(self)\n def __str__(self) -> str:\n return pytc.tree_str(self)\n @property\n def at(self):\n return pytc.tree_indexer(self)\n\ndef pytc_flatten_rule(tree):\n names =(\"a\",\"b\",\"c\")\n types = map(type, (tree.a, tree.b, tree.c))\n indices = range(3)\n metadatas= (None, None, None)\n return [*zip(names, types, indices, metadatas)]\n\npytc.register_pytree_node_trace(FlaxTree, pytc_flatten_rule)\n\nflax_tree = FlaxTree()\n\nprint(f\"{flax_tree!r}\")\n# FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00]))\n\nprint(f\"{flax_tree!s}\")\n# FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 
6.])\n\nprint(pytc.tree_diagram(flax_tree))\n# FlaxTree\n# \u251c\u2500\u2500 a=1\n# \u251c\u2500\u2500 b:tuple\n# \u2502 \u251c\u2500\u2500 [0]=2.0\n# \u2502 \u2514\u2500\u2500 [1]=3.0\n# \u2514\u2500\u2500 c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00])\n\nprint(pytc.tree_summary(flax_tree))\n# \u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name\u2502Type \u2502Count\u2502Size \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502a \u2502int \u25021 \u250228.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502b[0]\u2502float \u25021 \u250224.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502b[1]\u2502float \u25021 \u250224.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502c \u2502f32[3] \u25023 \u250212.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3 \u2502FlaxTree\u25026 \u250288.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n\nflax_tree.at[0].get()\n# FlaxTree(a=1, b=(None, None), c=None)\n\nflax_tree.at[\"a\"].set(10)\n# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](\u03bc=5.00, \u03c3=0.82, 
\u2208[4.00,6.00]))\n```\n\n</details>\n\n<details>\n<summary>Benchmark flatten/unflatten compared to Flax and Equinox </summary>\n\n<a href=\"https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/benchmark_flatten_unflatten.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n\n<table>\n\n<tr><td align=\"center\">CPU</td><td align=\"center\">GPU</td></tr>\n\n<tr>\n\n<td><img src='assets/benchmark_cpu.png'></td>\n<td><img src='assets/benchmark_gpu.png'></td>\n\n</tr>\n\n</table>\n\n</details>\n\n## \ud83d\udcd9 Acknowledgements<a id=\"acknowledgements\"></a>\n\n- [Farid Talibli (for visualization link generation backend)](https://www.linkedin.com/in/frdt98)\n- [Treex](https://github.com/cgarciae/treex), [Equinox](https://github.com/patrick-kidger/equinox), [tree-math](https://github.com/google/tree-math), [Flax](https://github.com/google/flax), [TensorFlow](https://www.tensorflow.org), [PyTorch](https://pytorch.org)\n- [Lovely JAX](https://github.com/xl0/lovely-jax)\n",
"bugtrack_url": null,
"license": "Apache-2.0",
"summary": "JAX compatible dataclass.",
"version": "0.2.1",
"split_keywords": [
"python",
"machine-learning",
"pytorch",
"jax"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "651df5b93409dde5fa7391ea39626fcab4bf25a19f7d50d22cde00e7403a99b9",
"md5": "72a8c8e285cc4201ea6e3dd31d29486e",
"sha256": "816cfee6cd856580c094565700d1744e8a7e67d8fbb9b9f3c0a2f71ce48ac4f8"
},
"downloads": -1,
"filename": "pytreeclass-0.2.1-py3-none-any.whl",
"has_sig": false,
"md5_digest": "72a8c8e285cc4201ea6e3dd31d29486e",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.8",
"size": 58384,
"upload_time": "2023-03-19T03:37:07",
"upload_time_iso_8601": "2023-03-19T03:37:07.636070Z",
"url": "https://files.pythonhosted.org/packages/65/1d/f5b93409dde5fa7391ea39626fcab4bf25a19f7d50d22cde00e7403a99b9/pytreeclass-0.2.1-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "0ff7e7c371fed186f64e1462054d60244c45435afd4cacb0662b1dfe8f4e5908",
"md5": "dda54f48459237246a6852b650d99802",
"sha256": "3de22bd6eb3931f2d30a749fba2f656e997058ffb6aba8b30a918d44d47e2db6"
},
"downloads": -1,
"filename": "pytreeclass-0.2.1.tar.gz",
"has_sig": false,
"md5_digest": "dda54f48459237246a6852b650d99802",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.8",
"size": 58332,
"upload_time": "2023-03-19T03:37:10",
"upload_time_iso_8601": "2023-03-19T03:37:10.141936Z",
"url": "https://files.pythonhosted.org/packages/0f/f7/e7c371fed186f64e1462054d60244c45435afd4cacb0662b1dfe8f4e5908/pytreeclass-0.2.1.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-03-19 03:37:10",
"github": true,
"gitlab": false,
"bitbucket": false,
"github_user": "ASEM000",
"github_project": "pytreeclass",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"lcname": "pytreeclass"
}