# Jaxpr-Viz
JAX Computation Graph Visualisation Tool

JAX has built-in functionality to visualise the
HLO graph generated by JAX, but I've found this rather
low-level for some use-cases.
The intention of this package is to visualise how
sub-functions are connected in JAX programs. It does
this by converting the [JaxPr](https://jax.readthedocs.io/en/latest/jaxpr.html)
representation into a pydot graph. See [here](.github/docs/gallery.md)
for examples.
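
To get a feel for what is being converted, the Jaxpr of a function can be inspected directly with `jax.make_jaxpr`. The snippet below is a minimal sketch using only JAX itself; it does not use `jpviz`, and is not necessarily how the package obtains the Jaxpr internally. The functions `inner` and `outer` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inner(x):  # hypothetical example function
    return 2 * x


@jax.jit
def outer(x):  # hypothetical example function
    return inner(x) - 1


# Trace the function with a concrete argument to obtain its Jaxpr
closed_jaxpr = jax.make_jaxpr(outer)(jnp.arange(10))

# Printing shows the nested structure a visualisation has to unpack
print(closed_jaxpr)

# Each equation applies one primitive to some inputs and yields some outputs
for eqn in closed_jaxpr.jaxpr.eqns:
    print(eqn.primitive.name, len(eqn.invars), len(eqn.outvars))
```

The printed text is the same information a graph would show, just in a form that gets hard to follow as programs grow.
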
> **NOTE:** This project is still at an early stage and may not
> support all JAX functionality (or permutations thereof). If you spot
> some strange behaviour, please create a [GitHub issue](https://github.com/zombie-einstein/jaxpr-viz/issues).
## Installation
Install with pip:
```bash
pip install jpviz
```
Depending on your system, you may also need to install [Graphviz](https://www.graphviz.org/).
## Usage
Jaxpr-viz can be used to visualise jit-compiled (and nested)
functions. It wraps a jit-compiled function, and the wrapped function,
when called with concrete values, returns a [pydot](https://github.com/pydot/pydot)
graph.
For example, this simple computation graph
```python
import jax
import jax.numpy as jnp

import jpviz


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    x = foo(x)
    return x - 1

# Wrap function and call with concrete arguments
# here dot_graph is a pydot object
dot_graph = jpviz.draw(bar)(jnp.arange(10))
# This renders the graph to a png file
dot_graph.write_png("computation_graph.png")
```
produces this image
![bar computation graph](.github/images/bar_collapsed.png)
Pydot has a number of options for rendering graphs, see
[here](https://github.com/pydot/pydot#output).
> **NOTE:** For sub-functions to show as nodes/sub-graphs they
> need to be marked with `@jax.jit`, otherwise they will just
> be merged into their parent graph.
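
The reason for this can be seen in the Jaxpr itself: a plain Python sub-function is traced straight into its caller, while a jitted one is kept as a separate nested call. The following is a rough sketch using `jax.make_jaxpr` only (no `jpviz`); `foo_plain` and `bar_plain` are hypothetical un-jitted variants added for comparison.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


def foo_plain(x):  # hypothetical un-jitted twin of foo
    return 2 * x


def bar(x):
    return foo(x) - 1


def bar_plain(x):  # hypothetical: calls the un-jitted twin
    return foo_plain(x) - 1


x = jnp.arange(10)

# Names of the top-level primitives in each traced function
nested = [eqn.primitive.name for eqn in jax.make_jaxpr(bar)(x).jaxpr.eqns]
inlined = [eqn.primitive.name for eqn in jax.make_jaxpr(bar_plain)(x).jaxpr.eqns]

print(nested)   # a nested call for foo, followed by the subtraction
print(inlined)  # the multiplication appears directly in the parent
```

The name of the nested-call primitive differs between JAX versions, so it is best not to rely on it.
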
### Jupyter Notebook
To show the rendered graph in a Jupyter notebook, you can use the
helper function `view_pydot`
```python
...
dot_graph = jpviz.draw(bar)(jnp.arange(10))
jpviz.view_pydot(dot_graph)
```
### Visualisation Options
#### Collapse Nodes
By default, functions that are composed of only primitive functions
are collapsed into a single node (like `foo` in the above example).
The full computation graph can be rendered by setting the `collapse_primitives`
flag to `False`; in the above example
```python
...
dot_graph = jpviz.draw(bar, collapse_primitives=False)(jnp.arange(10))
...
```
produces
![bar computation graph](.github/images/bar_expanded.png)
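
The nesting that the expanded graph draws can also be listed as text by walking the Jaxpr recursively. This is only an illustrative sketch, not the traversal `jpviz` performs: `walk` is a hypothetical helper, and it assumes that nested calls store their sub-program as a closed Jaxpr (an object with a `.jaxpr` attribute) in the equation parameters.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    x = foo(x)
    return x - 1


def walk(jaxpr, depth=0):
    """Hypothetical helper: list primitive names, indented by nesting depth."""
    lines = []
    for eqn in jaxpr.eqns:
        lines.append("  " * depth + eqn.primitive.name)
        for value in eqn.params.values():
            # Assumption: sub-programs are closed Jaxprs wrapping a Jaxpr
            inner = getattr(value, "jaxpr", None)
            if inner is not None and hasattr(inner, "eqns"):
                lines.extend(walk(inner, depth + 1))
    return lines


lines = walk(jax.make_jaxpr(bar)(jnp.arange(10)).jaxpr)
print("\n".join(lines))
```

A function whose equations contain no nested sub-programs is the kind that is drawn as a single node by default.
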
#### Show Types
By default, type information is included in the node labels. This
can be hidden using the `show_avals` flag; setting it to `False`
```python
...
dot_graph = jpviz.draw(bar, show_avals=False)(jnp.arange(10))
...
```
produces
![bar computation graph](.github/images/bar_no_types.png "Title")
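
The name `show_avals` presumably refers to JAX's abstract values ("avals"), which record the shape and dtype of each array without its contents. As a small sketch using JAX only, these can be read from a traced function; the explicit `float32` input is a choice made here for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    x = foo(x)
    return x - 1


closed_jaxpr = jax.make_jaxpr(bar)(jnp.arange(10, dtype=jnp.float32))

# Abstract values of the inputs and outputs: shape and dtype only
for aval in closed_jaxpr.in_avals:
    print("input:", aval.shape, aval.dtype)
for aval in closed_jaxpr.out_avals:
    print("output:", aval.shape, aval.dtype)
```
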
> **NOTE:** The labels of the nodes don't currently correspond
> to argument/variable names in the original Python code. However, since
> JAX unpacks arguments/outputs to tuples, they do correspond
> to the positions of arguments and outputs.
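
To illustrate the positional unpacking, the sketch below traces a function that takes a dictionary argument and compares the number of Jaxpr inputs with the number of flattened leaves. It uses JAX only; `scale_and_shift` and its parameter names are hypothetical.

```python
import jax
import jax.numpy as jnp


def scale_and_shift(params, x):  # hypothetical example function
    return params["scale"] * x + params["shift"]


params = {"scale": jnp.ones(3), "shift": jnp.zeros(3)}
x = jnp.arange(3.0)

closed_jaxpr = jax.make_jaxpr(scale_and_shift)(params, x)

# Structured arguments are flattened into a flat sequence of arrays
leaves = jax.tree_util.tree_leaves((params, x))

# One Jaxpr input per leaf; the Python names are not retained
print(len(closed_jaxpr.jaxpr.invars), len(leaves))
```
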
## Examples
See [here](.github/docs/gallery.md) for more examples of rendered computation graphs.
## Developers
Developer notes can be found [here](.github/docs/developers.md).