pytreeclass

- Name: pytreeclass
- Version: 0.2.1
- Summary: JAX compatible dataclass.
- Home page: https://github.com/ASEM000/pytreeclass
- Author: Mahmoud Asem (asem00@kaist.ac.kr)
- License: Apache-2.0
- Requires Python: >=3.8
- Keywords: python, machine-learning, pytorch, jax
- Upload time: 2023-03-19 03:37:10

<!-- <h1 align="center" style="font-family:Monospace" >Py🌲Class</h1> -->
<h5 align="center">
<img width="250px" src="assets/pytc%20logo.svg"> <br>

<br>

[**Installation**](#installation)
|[**Description**](#description)
|[**Quick Example**](#quick_example)
|[**StatefulComputation**](#stateful_computation)
|[**More**](#more)
|[**Acknowledgements**](#acknowledgements)

![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/tests.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11_-red)
![pyver](https://img.shields.io/badge/jax-0.4+-red)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/intro.ipynb)
[![Downloads](https://pepy.tech/badge/pytreeclass)](https://pepy.tech/project/pytreeclass)
[![codecov](https://codecov.io/gh/ASEM000/pytreeclass/branch/main/graph/badge.svg?token=TZBRMO0UQH)](https://codecov.io/gh/ASEM000/pytreeclass)
[![Documentation Status](https://readthedocs.org/projects/pytreeclass/badge/?version=latest)](https://pytreeclass.readthedocs.io/en/latest/?badge=latest)
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/ASEM000/pytreeclass)
[![DOI](https://zenodo.org/badge/512717921.svg)](https://zenodo.org/badge/latestdoi/512717921)
![PyPI](https://img.shields.io/pypi/v/pytreeclass)

</h5>

**For the previous `PyTreeClass` API, use the v0.1 branch.**

## 🛠️ Installation<a id="installation"></a>

```bash
pip install pytreeclass
```

**Install development version**

```bash
pip install git+https://github.com/ASEM000/PyTreeClass
```

## 📖 Description<a id="description"></a>

`PyTreeClass` is a JAX-compatible `dataclass`-like decorator to create and operate on stateful JAX PyTrees.

The package aims to achieve _two goals_:

1. 🔒 To maintain safe and correct behaviour by using _immutable_ modules with a _functional_ API.
2. To achieve the **most intuitive** user experience in the `JAX` ecosystem by:
   - 🏗️ Defining layers in a style similar to `PyTorch` or `TensorFlow` subclassing.
   - ☝️ Filtering/indexing layer values similar to `jax.numpy.at[].{get,set,apply,...}`.
   - 🎨 Visualizing defined layers in a plethora of ways.
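
The first goal, immutable modules with a functional API, can be approximated in plain Python and JAX. The sketch below is an illustration under assumptions, not how `pytreeclass` is implemented: a frozen standard-library dataclass (the hypothetical `Params` class) is registered as a JAX PyTree, so in-place assignment raises an error and every update produces a new instance.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Params:
    # Hypothetical example class; `frozen=True` forbids attribute assignment.
    w: jax.Array
    b: float


# Register Params as a PyTree: children are (w, b), no static aux data.
jax.tree_util.register_pytree_node(
    Params,
    lambda p: ((p.w, p.b), None),
    lambda _, children: Params(*children),
)

params = Params(w=jnp.array([1.0, 2.0]), b=0.5)

# Out-of-place update of a single field: returns a new instance.
params2 = dataclasses.replace(params, b=1.5)

# Out-of-place update of every leaf via tree_map.
params3 = jax.tree_util.tree_map(lambda x: x * 2, params)

try:
    params.b = 10.0  # in-place mutation is rejected
except dataclasses.FrozenInstanceError:
    print("params is immutable")
```

Here `dataclasses.replace` plays the role of an out-of-place field update, and `tree_map` the role of a whole-tree update; the original `params` is never modified.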

## ⏩ Quick Example <a id="quick_example"></a>

### 🏗️ Simple Tree example

<div align="center">
<table>
<tr><td align="center">Code</td> <td align="center">PyTree representation</td></tr>
<tr>
<td>

```python
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
```

</td>

<td>

```python
# leaves are parameters

Tree
    ├── a=1
    ├── b:tuple
    │   ├── [0]=2.0
    │   └── [1]=3.0
    └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

</td>

</tr>
</table>
</div>

### 🎨 Visualize<a id="Viz"></a>

<details> <summary>Visualize PyTrees in five different ways</summary>

<div align="center">
<table>
<tr>
 <td align = "center"> tree_summary</td> 
 <td align = "center">tree_diagram</td>
 <td align = "center">[tree_mermaid](https://mermaid.js.org) (natively supported by GitHub/Notion)</td>
 <td align= "center"> tree_repr </td>
 <td align="center" > tree_str </td>

</tr>

<tr>
<td>

```python
print(pytc.tree_summary(tree, depth=1))
┌────┬──────┬─────┬──────┐
│Name│Type  │Count│Size  │
├────┼──────┼─────┼──────┤
│a   │int   │1    │28.00B│
├────┼──────┼─────┼──────┤
│b   │tuple │2    │48.00B│
├────┼──────┼─────┼──────┤
│c   │f32[3]│3    │12.00B│
├────┼──────┼─────┼──────┤
│Σ   │Tree  │6    │88.00B│
└────┴──────┴─────┴──────┘
```

</td>

<td>

```python

print(pytc.tree_diagram(tree, depth=1))
Tree
    ├── a=1
    ├── b=(..., ...)
    └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

 </td>

<td>

```python
print(pytc.tree_mermaid(tree, depth=1))
```

```mermaid

flowchart LR
    id15696277213149321320(<b>Tree</b>)
    id15696277213149321320--->|"1 leaf<br>28.00B"|id4205845433746830897("<b>a</b>:int=1")
    id15696277213149321320--->|"2 leaf<br>48.00B"|id4682191244783855647("<b>b</b>:tuple=(..., ...)")
    id15696277213149321320--->|"3 leaf<br>12.00B"|id14652085615030570957("<b>c</b>:ArrayImpl=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])")
```

</td>

<td>

```python
print(pytc.tree_repr(tree, depth=1))
Tree(a=1, b=(..., ...), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```

</td>

<td>

```python
print(pytc.tree_str(tree, depth=1))
Tree(a=1, b=(..., ...), c=[4. 5. 6.])
```

</td>

</tr>

<tr>

<td>

```python
print(pytc.tree_summary(tree, depth=2))
┌────┬──────┬─────┬──────┐
│Name│Type  │Count│Size  │
├────┼──────┼─────┼──────┤
│a   │int   │1    │28.00B│
├────┼──────┼─────┼──────┤
│b[0]│float │1    │24.00B│
├────┼──────┼─────┼──────┤
│b[1]│float │1    │24.00B│
├────┼──────┼─────┼──────┤
│c   │f32[3]│3    │12.00B│
├────┼──────┼─────┼──────┤
│Σ   │Tree  │6    │88.00B│
└────┴──────┴─────┴──────┘
```

</td>

<td>

```python
print(pytc.tree_diagram(tree, depth=2))
Tree
    ├── a=1
    ├── b:tuple
    │   ├── [0]=2.0
    │   └── [1]=3.0
    └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

</td>

<td>

```python
print(pytc.tree_mermaid(tree, depth=2))
```

```mermaid
flowchart LR
    id15696277213149321320(<b>Tree</b>)
    id15696277213149321320--->id4205845433746830897("<b>a</b>:int=1")
    id15696277213149321320--->|"1 leaf<br>24.00B"|id8168961130706115346("<b>b</b>:tuple")
    id8168961130706115346--->|"1 leaf<br>24.00B"|id2766159651176208202("<b>[0]</b>:float=2.0")
    id15696277213149321320--->|"1 leaf<br>24.00B"|id12408280303145007954("<b>b</b>:tuple")
    id12408280303145007954--->|"1 leaf<br>24.00B"|id7897116322308127883("<b>[1]</b>:float=3.0")
    id15696277213149321320--->id14652085615030570957("<b>c</b>:ArrayImpl=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])")
```

</td>

<td>

```python
print(pytc.tree_repr(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```

</td>

<td>

```python
print(pytc.tree_str(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
```

</td>

</tr>

 </table>

 </div>
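
The Count and Size columns above are consistent with counting array elements and `nbytes` for arrays, and `sys.getsizeof` for Python scalars (28 bytes for an `int` and 24 bytes for a `float` on a 64-bit CPython build). The following sketch reproduces the depth-1 numbers under that assumption; `summarize` and `count_and_size` are hypothetical helpers, not `pytc.tree_summary`.

```python
import sys

import jax
import jax.numpy as jnp


def count_and_size(leaf):
    # Assumption: arrays report element count and raw bytes (nbytes);
    # Python scalars count as one element and use sys.getsizeof.
    if hasattr(leaf, "nbytes"):
        return int(leaf.size), int(leaf.nbytes)
    return 1, sys.getsizeof(leaf)


def summarize(tree):
    # One (count, size) row per top-level field, plus a total row.
    rows = {}
    for name, value in tree.items():
        stats = [count_and_size(leaf) for leaf in jax.tree_util.tree_leaves(value)]
        rows[name] = (sum(c for c, _ in stats), sum(s for _, s in stats))
    rows["Σ"] = tuple(map(sum, zip(*rows.values())))
    return rows


tree = {"a": 1, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32)}
for name, (count, size) in summarize(tree).items():
    print(f"{name:<4}{count:<6}{size:.2f}B")
```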

</details>

### 🏃 Working with `jax` transformations

<details> <summary>Make arbitrary PyTrees work with jax transformations</summary>

Parameters are declared at the top of the `Tree` class definition, just like
`dataclasses.dataclass` fields.
Let's optimize these parameters:

```python

@jax.grad
def loss_func(tree:Tree, x:jax.Array):
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error

@jax.jit
def train_step(tree:Tree, x:jax.Array):
    grads = loss_func(tree, x)
    # apply a small gradient step
    return jax.tree_util.tree_map(lambda x, g: x - 1e-3*g, tree, grads)

# Freeze the non-differentiable parts of the tree: any leaf that is not of an
# inexact (floating-point) type should be frozen so that the tree can be
# differentiated and passed through jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))

print(jaxable_tree)
# frozen leaves are displayed with a "#" prefix
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])


# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])
```
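
One way to picture the freeze/unfreeze step above in plain JAX is sketched below. It is an illustration under assumptions, not pytreeclass's implementation: a hypothetical `Frozen` class is registered as a PyTree node with no children, so the wrapped value is carried as static auxiliary data and is skipped by `jax.grad`, `jax.jit`, and `tree_map`. `is_inexact`, `freeze_nondiff`, and `unfreeze` are likewise hypothetical helpers, and a plain dict stands in for `Tree`.

```python
import jax
import jax.numpy as jnp


class Frozen:
    """Hypothetical wrapper that hides a value from JAX transformations."""

    def __init__(self, value):
        self.value = value


# A Frozen node has no children; its value is stored as (hashable) aux data,
# so tree_map, grad and jit never see it as a leaf.
jax.tree_util.register_pytree_node(
    Frozen,
    lambda f: ((), f.value),
    lambda value, _: Frozen(value),
)


def is_inexact(x):
    # True for floating-point (and complex) leaves, including Python floats.
    return jnp.issubdtype(jnp.result_type(x), jnp.inexact)


def freeze_nondiff(tree):
    return jax.tree_util.tree_map(lambda x: x if is_inexact(x) else Frozen(x), tree)


def unfreeze(tree):
    is_frozen = lambda x: isinstance(x, Frozen)
    return jax.tree_util.tree_map(
        lambda x: x.value if is_frozen(x) else x, tree, is_leaf=is_frozen
    )


def loss(params, x):
    a = params["a"].value  # frozen leaf: read the static value
    preds = a + params["b"][0] + params["c"] + x
    return jnp.mean(preds**2)


@jax.jit
def train_step(params, x):
    grads = jax.grad(loss)(params, x)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)


params = {"a": 1, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}
trainable = freeze_nondiff(params)
for _ in range(100):
    trainable = train_step(trainable, jnp.ones([10, 1]))
params = unfreeze(trainable)
```

Because the frozen integer lives in the tree structure rather than in the leaves, `jax.jit` treats it as a compile-time constant; changing it would trigger recompilation.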

</details>

#### ☝️ Advanced Indexing with `.at[]` <a id="Indexing"></a>

<details> <summary>Out-of-place updates using mask, attribute name or index</summary>

`PyTreeClass` offers three ways of indexing through `.at[]`:

1. Indexing by boolean mask.
2. Indexing by attribute name.
3. Indexing by leaf index.

**Since `treeclass`-wrapped classes are immutable, `.at[]` operations return a new instance of the tree.**

#### Index update by boolean mask

```python
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))

# lets create a mask for values > 4
mask = jax.tree_util.tree_map(lambda x: x>4, tree)

print(mask)
# Tree(a=False, b=(False, False), c=[False  True  True])

print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5 6])

print(tree.at[mask].set(10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])

print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
```
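
For intuition, the boolean-mask operations above behave like a `tree_map` of `jnp.where` over the tree and its mask. The sketch below uses hypothetical helpers (`masked_get`, `masked_set`, `masked_apply`) on a plain dict; it is an assumption about the semantics, not pytreeclass's code, and unlike `.at[]` it turns Python scalars into JAX arrays.

```python
import jax
import jax.numpy as jnp

tree = {"a": 1, "b": (2, 3), "c": jnp.array([4, 5, 6])}
mask = jax.tree_util.tree_map(lambda x: x > 4, tree)


def masked_get(tree, mask):
    # Keep selected entries; unselected scalar leaves become None.
    def get(x, m):
        x, m = jnp.asarray(x), jnp.asarray(m)
        if m.ndim == 0:
            return x if bool(m) else None
        return x[m]  # boolean indexing: not usable under jax.jit

    return jax.tree_util.tree_map(get, tree, mask)


def masked_set(tree, mask, value):
    # Replace selected entries with `value`; leave the rest unchanged.
    return jax.tree_util.tree_map(lambda x, m: jnp.where(m, value, x), tree, mask)


def masked_apply(tree, mask, func):
    # Apply `func` to selected entries only.
    return jax.tree_util.tree_map(lambda x, m: jnp.where(m, func(x), x), tree, mask)


new_tree = masked_set(tree, mask, 10)
print(new_tree["c"])  # the original `tree` is left untouched
```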

#### Index update by attribute name

```python
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))

print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)

print(tree.at["a"].set(10))
# Tree(a=10, b=(2, 3), c=[4 5 6])

print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
```
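
Name-based updates can be thought of as `dataclasses.replace` on an immutable class. The following is a minimal sketch under that assumption, with hypothetical helpers `get_by_name`, `set_by_name`, and `apply_by_name` on a frozen standard-library dataclass; it is not the pytreeclass implementation.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Tree:
    # Hypothetical stand-in for the treeclass `Tree` above.
    a: int = 1
    b: tuple = (2.0, 3.0)
    c: jax.Array = dataclasses.field(default_factory=lambda: jnp.array([4.0, 5.0, 6.0]))


def get_by_name(tree, name):
    # Keep field `name`; replace the leaves of every other field with None,
    # preserving container structure (e.g. b -> (None, None)).
    return dataclasses.replace(
        tree,
        **{
            f.name: jax.tree_util.tree_map(lambda _: None, getattr(tree, f.name))
            for f in dataclasses.fields(tree)
            if f.name != name
        },
    )


def set_by_name(tree, name, value):
    # Out-of-place: returns a new instance with one field replaced.
    return dataclasses.replace(tree, **{name: value})


def apply_by_name(tree, name, func):
    return set_by_name(tree, name, func(getattr(tree, name)))


tree = Tree()
print(set_by_name(tree, "a", 10).a, tree.a)  # the original keeps a=1
```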

#### Index update by integer index

```python
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](μ=5.00, σ=0.82, ∈[4,6]))

print(tree.at[1].at[0].get())
# Tree(a=None, b=(2.0, None), c=None)

print(tree.at[1].at[0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])

print(tree.at[1].at[0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
```

</details>

<details>

<summary>

## 📜 Stateful computations<a id="stateful_computation"></a> </summary>

First, [under `jax.jit`, JAX requires state to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state). This means that for any class instance, the variables need to be separated from the class and passed explicitly. With `@pytc.treeclass`, however, there is no need to separate the instance variables; instead, the whole instance is passed as the state.

Using the following pattern, state can be updated **functionally** under `jax.jit`:

```python
import jax
import pytreeclass as pytc

@pytc.treeclass
class Counter:
    calls : int = 0

    def increment(self):
        self.calls += 1
counter = Counter() # Counter(calls=0)
```

Next, we define the update function. Since the `increment` method mutates the internal state, we need a functional way to call it, which `.at` provides: `.at[method_name].__call__(*args, **kwargs)` returns the value of the call together with a _new_ instance carrying the updated state.

```python
@jax.jit
def update(counter):
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls) # 10
```
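
The `.at["increment"]()` pattern can be sketched without pytreeclass: copy the instance, run the mutating method on the copy, and return both the method's result and the copy. The `functional_call` helper and the plain `Counter` class below are hypothetical, and the copy-then-call strategy is an assumption about how such a call can stay functional, not a description of pytreeclass internals.

```python
import copy

import jax


class Counter:
    # Plain (non-pytreeclass) stand-in for the `Counter` treeclass above.
    def __init__(self, calls=0):
        self.calls = calls

    def increment(self):
        self.calls += 1  # mutates the instance it is called on
        return self.calls


# Register Counter as a PyTree whose only leaf is `calls`.
jax.tree_util.register_pytree_node(
    Counter,
    lambda c: ((c.calls,), None),
    lambda _, children: Counter(*children),
)


def functional_call(obj, method_name, *args, **kwargs):
    """Hypothetical helper: run a (possibly mutating) method on a shallow copy
    and return (method result, updated copy); `obj` itself is not modified."""
    new_obj = copy.copy(obj)
    value = getattr(new_obj, method_name)(*args, **kwargs)
    return value, new_obj


@jax.jit
def update(counter):
    _, new_counter = functional_call(counter, "increment")
    return new_counter


counter = Counter()
for _ in range(10):
    counter = update(counter)
print(int(counter.calls))  # 10
```

A shallow copy is enough here because `increment` rebinds an attribute; a method that mutated a nested list or dict in place would also need a deeper copy to keep the original untouched.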

</details>

## ➕ More<a id="more"></a>

<details><summary>[Advanced] Register custom user-defined classes to work with visualization and indexing tools. </summary>

Similar to [`jax.tree_util.register_pytree_node`](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees), `PyTreeClass` registers common data structures and `treeclass`-wrapped classes so that it can determine the name, type, index, and metadata of each leaf along its path.

Here is an example of registering a custom class:

```python
import jax
import pytreeclass as pytc


class Tree:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(a={self.a}, b={self.b})"


# jax flatten rule
def tree_flatten(tree):
    return (tree.a, tree.b), None

# jax unflatten rule
def tree_unflatten(_, children):
    return Tree(*children)

# PyTreeClass flatten rule: one (name, type, index, metadata) entry per child
def pytc_tree_flatten(tree):
    names = ("a", "b")
    types = (type(tree.a), type(tree.b))
    indices = (0,1)
    metadatas = (None, None)
    return [*zip(names, types, indices, metadatas)]


# Register with `jax`
jax.tree_util.register_pytree_node(Tree, tree_flatten, tree_unflatten)

# Register the `Tree` class trace function to support indexing
pytc.register_pytree_node_trace(Tree, pytc_tree_flatten)

tree = Tree(1, 2)

# works with jax
jax.tree_util.tree_leaves(tree)  # [1, 2]

# works with PyTreeClass viz tools
print(pytc.tree_summary(tree))

# ┌────┬────┬─────┬──────┐
# │Name│Type│Count│Size  │
# ├────┼────┼─────┼──────┤
# │a   │int │1    │28.00B│
# ├────┼────┼─────┼──────┤
# │b   │int │1    │28.00B│
# ├────┼────┼─────┼──────┤
# │Σ   │Tree│2    │56.00B│
# └────┴────┴─────┴──────┘
```

After registration, you can use internal tools such as:

- `pytc.tree_map_with_trace`
- `pytc.tree_leaves_with_trace`
- `pytc.tree_flatten_with_trace`

More details on that soon.
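
Recent JAX releases provide a related, built-in way to attach names to the path of each leaf: `jax.tree_util.register_pytree_with_keys` together with `tree_flatten_with_path`. The sketch below is a JAX-only analogue of the trace registration above, not a replacement for `pytc.register_pytree_node_trace`; it assumes a JAX version that ships these key-path APIs.

```python
import jax
from jax.tree_util import GetAttrKey, keystr


class Tree:
    def __init__(self, a, b):
        self.a = a
        self.b = b


# Flatten rule that also names each child with an attribute key.
def flatten_with_keys(tree):
    return ((GetAttrKey("a"), tree.a), (GetAttrKey("b"), tree.b)), None


def unflatten(_, children):
    return Tree(*children)


jax.tree_util.register_pytree_with_keys(Tree, flatten_with_keys, unflatten)

tree = Tree(1, (2.0, 3.0))
path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in path_leaves:
    print(keystr(path), type(leaf).__name__, leaf)
# .a int 1
# .b[0] float 2.0
# .b[1] float 3.0
```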

</details>

<details> <summary>Validate or convert inputs using callbacks</summary>

`PyTreeClass` supports `callbacks` in `field`: a sequence of functions that is applied to the input when the attribute is set. Callbacks are useful in several cases, for instance, to ensure that an input has a certain type and lies within a valid range. For example:

```python
import jax
import pytreeclass as pytc

def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


@pytc.treeclass
class Tree:
    in_features:int = pytc.field(callbacks=[positive_int_callback])


tree = Tree(1)
# no error

tree = Tree(0)
# ValueError: Error for field=`in_features`:
# Value must be positive

tree = Tree(1.0)
# TypeError: Error for field=`in_features`:
# Value must be an integer
```
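
Callbacks of this kind can be emulated with the standard library, which may help clarify when they run. In this sketch (an assumption, not pytreeclass's implementation), a hypothetical `field_with_callbacks` helper stores the callbacks in `dataclasses.field` metadata, and `__post_init__` applies them in order when the instance is created.

```python
import dataclasses


def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


def field_with_callbacks(*, callbacks=(), **kwargs):
    # Hypothetical helper: keep the callbacks in the field metadata.
    return dataclasses.field(metadata={"callbacks": tuple(callbacks)}, **kwargs)


def run_callbacks(obj):
    # Apply each field's callbacks in order and store the returned value.
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        for callback in f.metadata.get("callbacks", ()):
            try:
                value = callback(value)
            except (TypeError, ValueError) as e:
                raise type(e)(f"Error for field=`{f.name}`:\n{e}") from e
        object.__setattr__(obj, f.name, value)


@dataclasses.dataclass(frozen=True)
class Tree:
    in_features: int = field_with_callbacks(callbacks=[positive_int_callback])

    def __post_init__(self):
        run_callbacks(self)


print(Tree(1))  # Tree(in_features=1)
```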

</details>

<details>  <summary> Add leafwise math operations to PyTreeClass wrapped class</summary>

```python
import functools as ft

import jax
import jax.numpy as jnp
import pytreeclass as pytc

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()

tree + 100
# Tree(a=101, b=(102.0, 103.0), c=f32[3](μ=105.00, σ=0.82, ∈[104.00,106.00]))

@jax.grad
def loss_func(tree:Tree, x:jax.Array):
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error

@jax.jit
def train_step(tree:Tree, x:jax.Array):
    grads = loss_func(tree, x)
    return tree - grads*1e-3  # <--- no explicit `tree_map` needed

# Freeze the non-differentiable parts of the tree: any leaf that is not of an
# inexact (floating-point) type should be frozen so that the tree can be
# differentiated and passed through jax transformations
jaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))

print(jaxable_tree)
# frozen leaves are displayed with a "#" prefix
# Tree(a=#1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783  3.024264 ])


# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783  3.024264 ])
```
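
The `leafwise=True` option can be pictured as operator overloading that forwards each operator to `jax.tree_util.tree_map`. The sketch below shows that idea with a hypothetical `LeafwiseMixin` and `Pair` class; it covers only a handful of operators and is an illustration under that assumption, not the code pytreeclass generates.

```python
import operator

import jax
import jax.numpy as jnp


class LeafwiseMixin:
    """Hypothetical mixin: operators act leaf by leaf on a registered PyTree."""

    def _leafwise(self, other, op):
        if isinstance(other, type(self)):
            # Same class: combine matching leaves of both trees.
            return jax.tree_util.tree_map(op, self, other)
        # Otherwise broadcast `other` (e.g. a scalar) to every leaf.
        return jax.tree_util.tree_map(lambda x: op(x, other), self)

    def __add__(self, other):
        return self._leafwise(other, operator.add)

    def __sub__(self, other):
        return self._leafwise(other, operator.sub)

    def __mul__(self, other):
        return self._leafwise(other, operator.mul)

    def __gt__(self, other):
        return self._leafwise(other, operator.gt)


class Pair(LeafwiseMixin):
    def __init__(self, a, b):
        self.a, self.b = a, b


jax.tree_util.register_pytree_node(
    Pair, lambda p: ((p.a, p.b), None), lambda _, children: Pair(*children)
)

p = Pair(1.0, jnp.array([2.0, 3.0]))
shifted = p + 100   # scalar broadcast to every leaf
step = p - p * 0.5  # tree - tree * scalar, no explicit tree_map
gt = p > 2          # leafwise comparison
```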

</details>

<details> <summary>Eliminate tree_map using bcmap + treeclass(..., leafwise=True) </summary>

TL;DR

```python
import functools as ft

import jax
import jax.numpy as jnp
import pytreeclass as pytc

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jnp.array([4.,5.,6.])

tree = Tree()

print(pytc.bcmap(jnp.where)(tree>2, tree+100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

```

`bcmap(func, is_leaf)` maps a function over [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) leaves with automatic broadcasting of scalar arguments.

`bcmap` is a function transformation that broadcasts scalar arguments to match the tree structure of the function's first argument. This makes it possible to convert a function like `jnp.where` to work with arbitrary tree structures without writing a specific function for each broadcasting case.

For example, suppose we want to use `jnp.where` to zero out all negative values in the following arbitrary tree structure:

```python
import jax.numpy as jnp
import jax.tree_util as jtu

tree = ([1], {"a": 1, "b": 2}, (1,), -1)
```

We can use `jax.tree_util.tree_map` to apply `jnp.where` to the tree, but we need to write a specific function that hard-codes the scalar replacement value:

```python
def map_func(leaf):
    # here we encoded the scalar `0` inside the function
    return jnp.where(leaf>0, leaf, 0)

jtu.tree_map(map_func, tree)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))
```

Now suppose we want `jnp.where` to take the replacement values from the corresponding leaves of another tree instead:

```python
def map_func2(lhs_leaf, rhs_leaf):
    # the replacement value now comes from the matching leaf of a second tree
    return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)

tree2 = jtu.tree_map(lambda x: 1000, tree)

jtu.tree_map(map_func2, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))
```

Now, `bcmap` makes this easier by figuring out the broadcasting case.

```python
broadcastable_where = pytc.bcmap(jnp.where)
mask = jtu.tree_map(lambda x: x>0, tree)
```

Case 1: broadcast a scalar to the tree

```python
broadcastable_where(mask, tree, 0)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))
```

Case 2: broadcast a tree to a tree

```python
broadcastable_where(mask, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))
```

Let's take this a step further and eliminate `mask` altogether
by using `pytreeclass` with `leafwise=True`:

```python
@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    tree : tuple = ([1], {"a":1, "b":2}, (1,), -1,)

tree = Tree()
# Tree(tree=([1], {a:1, b:2}, (1), -1))
```

Case 1: broadcast a scalar to the tree

```python
print(broadcastable_where(tree>0, tree, 0))
# Tree(tree=([1], {a:1, b:2}, (1), 0))
```

Case 2: broadcast a tree to a tree

```python
print(broadcastable_where(tree>0, tree, tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
```

`bcmap` also works with both positional and keyword arguments of the wrapped function:

```python
print(broadcastable_where(tree>0, x=tree, y=tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
```

In conclusion, `bcmap` is a function transformation that makes functions work
with arbitrary tree structures without the need to write a specific function
for each broadcasting case.

Moreover, `bcmap` becomes even more powerful when combined with `pytreeclass`,
since it lets arbitrary functions operate on `PyTree` objects without the need
for `tree_map`.
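
To make the broadcasting rule concrete, here is a minimal re-implementation of the idea for illustration. `simple_bcmap` is a hypothetical name, and this is a sketch of the stated rule (the first argument fixes the tree structure; every other argument either has that same structure or is passed unchanged to each leaf call), not pytreeclass's `bcmap`; for instance, it does not support an `is_leaf` argument.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def simple_bcmap(func):
    """Hypothetical minimal bcmap: map `func` over the leaves of its first
    argument. Arguments with the same tree structure as the first argument are
    matched leaf by leaf; any other argument (e.g. a scalar) is broadcast."""

    def wrapper(tree, *args, **kwargs):
        leaves, treedef = jtu.tree_flatten(tree)

        def per_leaf(arg):
            arg_leaves, arg_treedef = jtu.tree_flatten(arg)
            if arg_treedef == treedef:
                return arg_leaves
            return [arg] * len(leaves)  # broadcast to every leaf

        args_leaves = [per_leaf(a) for a in args]
        kwargs_leaves = {k: per_leaf(v) for k, v in kwargs.items()}
        out = [
            func(
                leaf,
                *(a[i] for a in args_leaves),
                **{k: v[i] for k, v in kwargs_leaves.items()},
            )
            for i, leaf in enumerate(leaves)
        ]
        return jtu.tree_unflatten(treedef, out)

    return wrapper


tree = ([1], {"a": 1, "b": 2}, (1,), -1)
mask = jtu.tree_map(lambda x: x > 0, tree)
where = simple_bcmap(jnp.where)
case1 = where(mask, tree, 0)  # scalar broadcast to the tree
case2 = where(mask, tree, jtu.tree_map(lambda x: 1000, tree))  # tree to tree
```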

</details>

<details><summary>Use PyTreeClass visualization tools with arbitrary PyTrees </summary>

```python
import jax
import pytreeclass as pytc

tree = [1, [2,3], 4]

print(pytc.tree_diagram(tree,depth=1))
# list
#     ├── [0]=1
#     ├── [1]=[..., ...]
#     └── [2]=4

print(pytc.tree_diagram(tree,depth=2))
# list
#     ├── [0]=1
#     ├── [1]:list
#     │   ├── [0]=2
#     │   └── [1]=3
#     └── [2]=4


print(pytc.tree_summary(tree,depth=1))
# ┌────┬────┬─────┬───────┐
# │Name│Type│Count│Size   │
# ├────┼────┼─────┼───────┤
# │[0] │int │1    │28.00B │
# ├────┼────┼─────┼───────┤
# │[1] │list│2    │56.00B │
# ├────┼────┼─────┼───────┤
# │[2] │int │1    │28.00B │
# ├────┼────┼─────┼───────┤
# │Σ   │list│4    │112.00B│
# └────┴────┴─────┴───────┘

print(pytc.tree_summary(tree,depth=2))
# ┌──────┬────┬─────┬───────┐
# │Name  │Type│Count│Size   │
# ├──────┼────┼─────┼───────┤
# │[0]   │int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │[1][0]│int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │[1][1]│int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │[2]   │int │1    │28.00B │
# ├──────┼────┼─────┼───────┤
# │Σ     │list│4    │112.00B│
# └──────┴────┴─────┴───────┘
```

</details>

<details><summary>Use PyTreeClass components with other libraries</summary>

```python
import jax
import pytreeclass as pytc
from flax import struct

@struct.dataclass
class FlaxTree:
    a:int = 1
    b:tuple[float] = (2.,3.)
    c:jax.Array = jax.numpy.array([4.,5.,6.])

    def __repr__(self) -> str:
        return pytc.tree_repr(self)
    def __str__(self) -> str:
        return pytc.tree_str(self)
    @property
    def at(self):
        return pytc.tree_indexer(self)

def pytc_flatten_rule(tree):
    names =("a","b","c")
    types = map(type, (tree.a, tree.b, tree.c))
    indices = range(3)
    metadatas= (None, None, None)
    return [*zip(names, types, indices, metadatas)]

pytc.register_pytree_node_trace(FlaxTree, pytc_flatten_rule)

flax_tree = FlaxTree()

print(f"{flax_tree!r}")
# FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

print(f"{flax_tree!s}")
# FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])

print(pytc.tree_diagram(flax_tree))
# FlaxTree
#     ├── a=1
#     ├── b:tuple
#     │   ├── [0]=2.0
#     │   └── [1]=3.0
#     └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])

print(pytc.tree_summary(flax_tree))
# ┌────┬────────┬─────┬──────┐
# │Name│Type    │Count│Size  │
# ├────┼────────┼─────┼──────┤
# │a   │int     │1    │28.00B│
# ├────┼────────┼─────┼──────┤
# │b[0]│float   │1    │24.00B│
# ├────┼────────┼─────┼──────┤
# │b[1]│float   │1    │24.00B│
# ├────┼────────┼─────┼──────┤
# │c   │f32[3]  │3    │12.00B│
# ├────┼────────┼─────┼──────┤
# │Σ   │FlaxTree│6    │88.00B│
# └────┴────────┴─────┴──────┘

flax_tree.at[0].get()
# FlaxTree(a=1, b=(None, None), c=None)

flax_tree.at["a"].set(10)
# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```

</details>

<details>
<summary>Benchmark flatten/unflatten compared to Flax and Equinox </summary>

<a href="https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/benchmark_flatten_unflatten.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<table>

<tr><td align="center">CPU</td><td align="center">GPU</td></tr>

<tr>

<td><img src='assets/benchmark_cpu.png'></td>
<td><img src='assets/benchmark_gpu.png'></td>

</tr>

</table>

</details>

## 📙 Acknowledgements<a id="acknowledgements"></a>

- [Farid Talibli (for visualization link generation backend)](https://www.linkedin.com/in/frdt98)
- [Treex](https://github.com/cgarciae/treex), [Equinox](https://github.com/patrick-kidger/equinox), [tree-math](https://github.com/google/tree-math), [Flax](https://github.com/google/flax), [TensorFlow](https://www.tensorflow.org), [PyTorch](https://pytorch.org)
- [Lovely JAX](https://github.com/xl0/lovely-jax)

            

Indexing by Leaf index.\n\n**Since `treeclass` wrapped class are immutable, `.at[]` operations returns new instance of the tree**\n\n#### Index update by boolean mask\n\n```python\ntree = Tree()\n# Tree(a=1, b=(2, 3), c=i32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4,6]))\n\n# lets create a mask for values > 4\nmask = jax.tree_util.tree_map(lambda x: x>4, tree)\n\nprint(mask)\n# Tree(a=False, b=(False, False), c=[False  True  True])\n\nprint(tree.at[mask].get())\n# Tree(a=None, b=(None, None), c=[5 6])\n\nprint(tree.at[mask].set(10))\n# Tree(a=1, b=(2, 3), c=[ 4 10 10])\n\nprint(tree.at[mask].apply(lambda x: 10))\n# Tree(a=1, b=(2, 3), c=[ 4 10 10])\n```\n\n#### Index update by attribute name\n\n```python\ntree = Tree()\n# Tree(a=1, b=(2, 3), c=i32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4,6]))\n\nprint(tree.at[\"a\"].get())\n# Tree(a=1, b=(None, None), c=None)\n\nprint(tree.at[\"a\"].set(10))\n# Tree(a=10, b=(2, 3), c=[4 5 6])\n\nprint(tree.at[\"a\"].apply(lambda x: 10))\n# Tree(a=10, b=(2, 3), c=[4 5 6])\n```\n\n#### Index update by integer index\n\n```python\ntree = Tree()\n# Tree(a=1, b=(2, 3), c=i32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4,6]))\n\nprint(tree.at[1].at[0].get())\n# Tree(a=None, b=(2.0, None), c=None)\n\nprint(tree.at[1].at[0].set(10))\n# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])\n\nprint(tree.at[1].at[0].apply(lambda x: 10))\n# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])\n```\n\n</details>\n\n<details>\n\n<summary>\n\n## \ud83d\udcdc Stateful computations<a id=\"stateful_computation\"></a> </summary>\n\nFirst, [Under jax.jit jax requires states to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state), this means that for any class instance; variables needs to be separated from the class and be passed explictly. 
However when using @pytc.treeclass no need to separate the instance variables ; instead the whole instance is passed as a state.\n\nUsing the following pattern,Updating state **functionally** can be achieved under `jax.jit`\n\n```python\nimport jax\nimport pytreeclass as pytc\n\n@pytc.treeclass\nclass Counter:\n    calls : int = 0\n\n    def increment(self):\n        self.calls += 1\ncounter = Counter() # Counter(calls=0)\n```\n\nHere, we define the update function. Since the increment method mutate the internal state, thus we need to use the functional approach to update the state by using `.at`. To achieve this we can use `.at[method_name].__call__(*args,**kwargs)`, this functional call will return the value of this call and a _new_ model instance with the update state.\n\n```python\n@jax.jit\ndef update(counter):\n    value, new_counter = counter.at[\"increment\"]()\n    return new_counter\n\nfor i in range(10):\n    counter = update(counter)\n\nprint(counter.calls) # 10\n```\n\n</details>\n\n## \u2795 More<a id=\"more\"></a>\n\n<details><summary>[Advanced] Register custom user-defined classes to work with visualization and indexing tools. 
</summary>\n\nSimilar to [`jax.tree_util.register_pytree_node`](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees), `PyTreeClass` register common data structures and `treeclass` wrapped classes to figure out how to define the names, types, index, and metadatas of certain leaf along its path.\n\nHere is an example of registering\n\n```python\n\nclass Tree:\n    def __init__(self, a, b):\n        self.a = a\n        self.b = b\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(a={self.a}, b={self.b})\"\n\n\n# jax flatten rule\ndef tree_flatten(tree):\n    return (tree.a, tree.b), None\n\n# jax unflatten rule\ndef tree_unflatten(_, children):\n    return Tree(*children)\n\n# PyTreeClass flatten rule\ndef pytc_tree_flatten(tree):\n    names = (\"a\", \"b\")\n    types = (type(tree.a), type(tree.b))\n    indices = (0,1)\n    metadatas = (None, None)\n    return [*zip(names, types, indices, metadatas)]\n\n\n# Register with `jax`\njax.tree_util.register_pytree_node(Tree, tree_flatten, tree_unflatten)\n\n# Register the `Tree` class trace function to support indexing\npytc.register_pytree_node_trace(Tree, pytc_tree_flatten)\n\ntree = Tree(1, 2)\n\n# works with jax\njax.tree_util.tree_leaves(tree)  # [1, 2]\n\n# works with PyTreeClass viz tools\nprint(pytc.tree_summary(tree))\n\n# \u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name\u2502Type\u2502Count\u2502Size  \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502a   \u2502int \u25021    \u250228.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502b   \u2502int \u25021    \u250228.00B\u2502\n# 
\u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3   \u2502Tree\u25022    \u250256.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n\n```\n\nAfter registeration, you can use internal tools like\n\n- `pytc.tree_map_with_trace`\n- `pytc.tree_leaves_with_trace`\n- `pytc.tree_flatten_with_trace`\n\nMore details on that soon.\n\n</details>\n\n<details> <summary>Validate or convert inputs using callbacks</summary>\n\n`PyTreeClass` includes `callbacks` in the `field` to apply a sequence of functions on input at setting the attribute stage. The callback is quite useful in several cases, for instance, to ensure a certain input type within a valid range. See example:\n\n```python\nimport jax\nimport pytreeclass as pytc\n\ndef positive_int_callback(value):\n    if not isinstance(value, int):\n        raise TypeError(\"Value must be an integer\")\n    if value <= 0:\n        raise ValueError(\"Value must be positive\")\n    return value\n\n\n@pytc.treeclass\nclass Tree:\n    in_features:int = pytc.field(callbacks=[positive_int_callback])\n\n\ntree = Tree(1)\n# no error\n\ntree = Tree(0)\n# ValueError: Error for field=`in_features`:\n# Value must be positive\n\ntree = Tree(1.0)\n# TypeError: Error for field=`in_features`:\n# Value must be an integer\n```\n\n</details>\n\n<details>  <summary> Add leafwise math operations to PyTreeClass wrapped class</summary>\n\n```python\nimport functools as ft\nimport pytreeclass as pytc\n\n@ft.partial(pytc.treeclass, leafwise=True)\nclass Tree:\n    a:int = 1\n    b:tuple[float] = (2.,3.)\n    c:jax.Array = jnp.array([4.,5.,6.])\n\n    def __call__(self, x):\n        return self.a + self.b[0] + self.c + x\n\ntree = Tree()\n\ntree + 100\n# Tree(a=101, b=(102.0, 103.0), c=f32[3](\u03bc=105.00, \u03c3=0.82, 
\u2208[104.00,106.00]))\n\n@jax.grad\ndef loss_func(tree:Tree, x:jax.Array):\n    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis\n    return jnp.mean(preds**2)  # <--- return the mean squared error\n\n@jax.jit\ndef train_step(tree:Tree, x:jax.Array):\n    grads = loss_func(tree, x)\n    return tree - grads*1e-3  # <--- eliminate `tree_map`\n\n# lets freeze the non-differentiable parts of the tree\n# in essence any non inexact type should be frozen to\n# make the tree differentiable and work with jax transformations\njaxable_tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)\n\nfor epoch in range(1_000):\n    jaxable_tree = train_step(jaxable_tree, jnp.ones([10,1]))\n\nprint(jaxable_tree)\n# **the `frozen` params have \"#\" prefix**\n# Tree(a=#1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783  3.024264 ])\n\n\n# unfreeze the tree\ntree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)\nprint(tree)\n# Tree(a=1, b=(-4.7176366, 3.0), c=[2.4973059 2.760783  3.024264 ])\n```\n\n</details>\n\n<details> <summary>Eliminate tree_map using bcmap + treeclass(..., leafwise=True) </summary>\n\nTDLR\n\n```python\nimport functools as ft\nimport pytreeclass as pytc\nimport jax.numpy as jnp\n\n@ft.partial(pytc.treeclass, leafwise=True)\nclass Tree:\n    a:int = 1\n    b:tuple[float] = (2.,3.)\n    c:jax.Array = jnp.array([4.,5.,6.])\n\ntree = Tree()\n\nprint(pytc.bcmap(jnp.where)(tree>2, tree+100, 0))\n# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 
106.])\n\n```\n\n`bcmap(func, is_leaf)` maps a function over [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html) leaves with automatic broadcasting for scalar arguments.\n\n`bcmap` is function transformation that broadcast a scalar to match the first argument of the function this enables us to convert a function like `jnp.where` to work with arbitrary tree structures without the need to write a specific function for each broadcasting case\n\nFor example, lets say we want to use `jnp.where` to zeros out all values in an arbitrary tree structure that are less than 0\n\ntree = ([1], {\"a\":1, \"b\":2}, (1,), -1,)\n\nwe can use `jax.tree_util.tree_map` to apply `jnp.where` to the tree but we need to write a specific function for broadcasting the scalar to the tree\n\n```python\ndef map_func(leaf):\n    # here we encoded the scalar `0` inside the function\n    return jnp.where(leaf>0, leaf, 0)\n\njtu.tree_map(map_func, tree)\n# ([Array(1, dtype=int32, weak_type=True)],\n#  {'a': Array(1, dtype=int32, weak_type=True),\n#   'b': Array(2, dtype=int32, weak_type=True)},\n#  (Array(1, dtype=int32, weak_type=True),),\n#  Array(0, dtype=int32, weak_type=True))\n```\n\nHowever, lets say we want to use `jnp.where` to set a value to a leaf value from another tree that looks like this\n\n```python\ndef map_func2(lhs_leaf, rhs_leaf):\n    # here we encoded the scalar `0` inside the function\n    return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)\n\ntree2 = jtu.tree_map(lambda x: 1000, tree)\n\njtu.tree_map(map_func2, tree, tree2)\n# ([Array(1, dtype=int32, weak_type=True)],\n#  {'a': Array(1, dtype=int32, weak_type=True),\n#   'b': Array(2, dtype=int32, weak_type=True)},\n#  (Array(1, dtype=int32, weak_type=True),),\n#  Array(1000, dtype=int32, weak_type=True))\n```\n\nNow, `bcmap` makes this easier by figuring out the broadcasting case.\n\n```python\nbroadcastable_where = pytc.bcmap(jnp.where)\nmask = jtu.tree_map(lambda x: x>0, tree)\n```\n\ncase 
1\n\n```python\nbroadcastable_where(mask, tree, 0)\n# ([Array(1, dtype=int32, weak_type=True)],\n#  {'a': Array(1, dtype=int32, weak_type=True),\n#   'b': Array(2, dtype=int32, weak_type=True)},\n#  (Array(1, dtype=int32, weak_type=True),),\n#  Array(0, dtype=int32, weak_type=True))\n```\n\ncase 2\n\n```python\nbroadcastable_where(mask, tree, tree2)\n# ([Array(1, dtype=int32, weak_type=True)],\n#  {'a': Array(1, dtype=int32, weak_type=True),\n#   'b': Array(2, dtype=int32, weak_type=True)},\n#  (Array(1, dtype=int32, weak_type=True),),\n#  Array(1000, dtype=int32, weak_type=True))\n```\n\nlets then take this a step further to eliminate `mask` from the equation\nby using `pytreeclass` with `leafwise=True `\n\n```python\n@ft.partial(pytc.treeclass, leafwise=True)\nclass Tree:\n    tree : tuple = ([1], {\"a\":1, \"b\":2}, (1,), -1,)\n\ntree = Tree()\n# Tree(tree=([1], {a:1, b:2}, (1), -1))\n```\n\ncase 1: broadcast scalar to tree\n\n````python\nprint(broadcastable_where(tree>0, tree, 0))\n# Tree(tree=([1], {a:1, b:2}, (1), 0))\n\ncase 2: broadcast tree to tree\n```python\nprint(broadcastable_where(tree>0, tree, tree+100))\n# Tree(tree=([1], {a:1, b:2}, (1), 99))\n````\n\n`bcmap` also works with all kind of arguments in the wrapped function\n\n```python\nprint(broadcastable_where(tree>0, x=tree, y=tree+100))\n# Tree(tree=([1], {a:1, b:2}, (1), 99))\n```\n\nin concolusion, `bcmap` is a function transformation that can be used to\nto make functions work with arbitrary tree structures without the need to write\na specific function for each broadcasting case\n\nMoreover, `bcmap` can be more powerful when used with `pytreeclass` to\nfacilitate operation of arbitrary functions on `PyTree` objects\nwithout the need to use `tree_map`\n\n</details>\n\n<details><summary>Use PyTreeClass vizualization tools with arbitrary PyTrees </summary>\n\n```python\nimport jax\nimport pytreeclass as pytc\n\ntree = [1, [2,3], 4]\n\nprint(pytc.tree_diagram(tree,depth=1))\n# list\n#     
\u251c\u2500\u2500 [0]=1\n#     \u251c\u2500\u2500 [1]=[..., ...]\n#     \u2514\u2500\u2500 [2]=4\n\nprint(pytc.tree_diagram(tree,depth=2))\n# list\n#     \u251c\u2500\u2500 [0]=1\n#     \u251c\u2500\u2500 [1]:list\n#     \u2502   \u251c\u2500\u2500 [0]=2\n#     \u2502   \u2514\u2500\u2500 [1]=3\n#     \u2514\u2500\u2500 [2]=4\n\n\nprint(pytc.tree_summary(tree,depth=1))\n# \u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name\u2502Type\u2502Count\u2502Size   \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[0] \u2502int \u25021    \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[1] \u2502list\u25022    \u250256.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[2] \u2502int \u25021    \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3   \u2502list\u25024    \u2502112.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n\nprint(pytc.tree_summary(tree,depth=2))\n# \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name  \u2502Type\u2502Count\u2502Size   \u2502\n# 
\u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[0]   \u2502int \u25021    \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[1][0]\u2502int \u25021    \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[1][1]\u2502int \u25021    \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502[2]   \u2502int \u25021    \u250228.00B \u2502\n# \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3     \u2502list\u25024    \u2502112.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n```\n\n</details>\n\n<details><summary>Use PyTreeClass components with other libraries</summary>\n\n```python\nimport jax\nimport pytreeclass as pytc\nfrom flax import struct\n\n@struct.dataclass\nclass FlaxTree:\n    a:int = 1\n    b:tuple[float] = (2.,3.)\n    c:jax.Array = jax.numpy.array([4.,5.,6.])\n\n    def __repr__(self) -> str:\n        return pytc.tree_repr(self)\n    def __str__(self) -> str:\n        return pytc.tree_str(self)\n    @property\n    def at(self):\n        return pytc.tree_indexer(self)\n\ndef pytc_flatten_rule(tree):\n    names =(\"a\",\"b\",\"c\")\n    types = map(type, (tree.a, tree.b, tree.c))\n    indices = range(3)\n    metadatas= (None, None, None)\n    return [*zip(names, types, 
indices, metadatas)]\n\npytc.register_pytree_node_trace(FlaxTree, pytc_flatten_rule)\n\nflax_tree = FlaxTree()\n\nprint(f\"{flax_tree!r}\")\n# FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00]))\n\nprint(f\"{flax_tree!s}\")\n# FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])\n\nprint(pytc.tree_diagram(flax_tree))\n# FlaxTree\n#     \u251c\u2500\u2500 a=1\n#     \u251c\u2500\u2500 b:tuple\n#     \u2502   \u251c\u2500\u2500 [0]=2.0\n#     \u2502   \u2514\u2500\u2500 [1]=3.0\n#     \u2514\u2500\u2500 c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00])\n\nprint(pytc.tree_summary(flax_tree))\n# \u250c\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u252c\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n# \u2502Name\u2502Type    \u2502Count\u2502Size  \u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502a   \u2502int     \u25021    \u250228.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502b[0]\u2502float   \u25021    \u250224.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502b[1]\u2502float   \u25021    \u250224.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502c   \u2502f32[3]  \u25023    \u250212.00B\u2502\n# \u251c\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u253c\u2500\u2500\u2500\u2500\u2500\u2500\u2524\n# \u2502\u03a3   \u2502FlaxTree\u25026    
\u250288.00B\u2502\n# \u2514\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n\nflax_tree.at[0].get()\n# FlaxTree(a=1, b=(None, None), c=None)\n\nflax_tree.at[\"a\"].set(10)\n# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](\u03bc=5.00, \u03c3=0.82, \u2208[4.00,6.00]))\n```\n\n</details>\n\n<details>\n<summary>Benchmark flatten/unflatten compared to Flax and Equinox </summary>\n\n<a href=\"https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/benchmark_flatten_unflatten.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n\n<table>\n\n<tr><td align=\"center\">CPU</td><td align=\"center\">GPU</td></tr>\n\n<tr>\n\n<td><img src='assets/benchmark_cpu.png'></td>\n<td><img src='assets/benchmark_gpu.png'></td>\n\n</tr>\n\n</table>\n\n</details>\n\n## \ud83d\udcd9 Acknowledgements<a id=\"acknowledgements\"></a>\n\n- [Farid Talibli (for visualization link generation backend)](https://www.linkedin.com/in/frdt98)\n- [Treex](https://github.com/cgarciae/treex), [Equinox](https://github.com/patrick-kidger/equinox), [tree-math](https://github.com/google/tree-math), [Flax](https://github.com/google/flax), [TensorFlow](https://www.tensorflow.org), [PyTorch](https://pytorch.org)\n- [Lovely JAX](https://github.com/xl0/lovely-jax)\n",
    "bugtrack_url": null,
    "license": "Apache-2.0",
    "summary": "JAX compatible dataclass.",
    "version": "0.2.1",
    "split_keywords": [
        "python",
        "machine-learning",
        "pytorch",
        "jax"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "651df5b93409dde5fa7391ea39626fcab4bf25a19f7d50d22cde00e7403a99b9",
                "md5": "72a8c8e285cc4201ea6e3dd31d29486e",
                "sha256": "816cfee6cd856580c094565700d1744e8a7e67d8fbb9b9f3c0a2f71ce48ac4f8"
            },
            "downloads": -1,
            "filename": "pytreeclass-0.2.1-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "72a8c8e285cc4201ea6e3dd31d29486e",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.8",
            "size": 58384,
            "upload_time": "2023-03-19T03:37:07",
            "upload_time_iso_8601": "2023-03-19T03:37:07.636070Z",
            "url": "https://files.pythonhosted.org/packages/65/1d/f5b93409dde5fa7391ea39626fcab4bf25a19f7d50d22cde00e7403a99b9/pytreeclass-0.2.1-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "0ff7e7c371fed186f64e1462054d60244c45435afd4cacb0662b1dfe8f4e5908",
                "md5": "dda54f48459237246a6852b650d99802",
                "sha256": "3de22bd6eb3931f2d30a749fba2f656e997058ffb6aba8b30a918d44d47e2db6"
            },
            "downloads": -1,
            "filename": "pytreeclass-0.2.1.tar.gz",
            "has_sig": false,
            "md5_digest": "dda54f48459237246a6852b650d99802",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.8",
            "size": 58332,
            "upload_time": "2023-03-19T03:37:10",
            "upload_time_iso_8601": "2023-03-19T03:37:10.141936Z",
            "url": "https://files.pythonhosted.org/packages/0f/f7/e7c371fed186f64e1462054d60244c45435afd4cacb0662b1dfe8f4e5908/pytreeclass-0.2.1.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-03-19 03:37:10",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "github_user": "ASEM000",
    "github_project": "pytreeclass",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "pytreeclass"
}
        
Elapsed time: 0.12592s