jpviz

- Name: jpviz
- Version: 0.1.7
- Home page: https://github.com/zombie-einstein/jaxpr-viz
- Summary: Jaxpr Visualisation Tool
- Upload time: 2024-12-20 00:31:04
- Author: Zombie-Einstein
- Requires Python: <3.13,>=3.10
- License: MIT
- Keywords: jax, computation graph

# Jaxpr-Viz

JAX Computation Graph Visualisation Tool

JAX has built-in functionality to visualise the
HLO graph generated by JAX, but I've found this rather
low-level for some use-cases.

The intention of this package is to visualise how
sub-functions are connected in JAX programs. It does
this by converting the [JaxPr](https://jax.readthedocs.io/en/latest/jaxpr.html)
representation into a pydot graph. See [here](.github/docs/gallery.md)
for examples.
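To get a feel for the structure jpviz works from, the sketch below uses JAX's own `jax.make_jaxpr` to trace a small nested program (the same `foo`/`bar` pair used in the Usage section) and walks the resulting jaxpr, collecting nested calls into a dict. This is only an illustration of the idea, not jpviz's implementation. `call_tree` is a hypothetical helper, and it assumes that a nested jit call stores its traced body as a sub-jaxpr (and its function name under `name`) in the equation's `params`, which is an internal JAX detail that may change between versions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


def call_tree(jaxpr):
    """Map each nested call in a (Closed)Jaxpr to the call tree of its body.

    Hypothetical helper: any equation parameter that itself has ``eqns``
    is treated as a sub-jaxpr and labelled by the equation's ``name``
    parameter if present, otherwise by its primitive name. Repeated calls
    to the same function overwrite each other in this simple sketch.
    """
    tree = {}
    for eqn in jaxpr.eqns:
        for value in eqn.params.values():
            if hasattr(value, "eqns"):
                label = eqn.params.get("name", eqn.primitive.name)
                tree[label] = call_tree(value)
    return tree


closed_jaxpr = jax.make_jaxpr(bar)(jnp.arange(10))
print(closed_jaxpr)  # the textual jaxpr representation
print(call_tree(closed_jaxpr))  # nested calls only; primitives are skipped
```

For this example the tree should come out as `{'bar': {'foo': {}}}`, which is the kind of sub-function nesting jpviz draws as nodes and sub-graphs.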

> **NOTE:** This project is still at an early stage and may not
> support all JAX functionality (or combinations thereof). If you spot
> any unexpected behaviour, please open a [GitHub issue](https://github.com/zombie-einstein/jaxpr-viz/issues).

## Installation

Install with pip:

```bash
pip install jpviz
```

Depending on your system, you may also need to install [Graphviz](https://www.graphviz.org/), which pydot uses to render graphs to image files.
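pydot renders images by running the Graphviz `dot` program, so a quick way to check whether Graphviz is already installed is to look for that executable on your `PATH`. The helper name `graphviz_available` below is hypothetical; this is just a convenience check, not part of jpviz.

```python
import shutil


def graphviz_available():
    """Return True if the Graphviz ``dot`` executable is on the PATH."""
    return shutil.which("dot") is not None


if not graphviz_available():
    print("Graphviz 'dot' not found; install Graphviz to render PNG/SVG output")
```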

## Usage

Jaxpr-viz can be used to visualise jit-compiled (and nested)
functions. It wraps a jit-compiled function; calling the wrapped
function with concrete values returns a [pydot](https://github.com/pydot/pydot)
graph.

For example, this simple computation graph

```python
import jax
import jax.numpy as jnp

import jpviz

@jax.jit
def foo(x):
    return 2 * x

@jax.jit
def bar(x):
    x = foo(x)
    return x - 1

# Wrap the function and call it with concrete arguments;
# dot_graph is a pydot graph object
dot_graph = jpviz.draw(bar)(jnp.arange(10))
# Render the graph to a PNG file (this step needs Graphviz installed)
dot_graph.write_png("computation_graph.png")
```

produces this image

![bar computation graph](.github/images/bar_collapsed.png)

Pydot has a number of options for rendering graphs, see
[here](https://github.com/pydot/pydot#output).

> **NOTE:** For sub-functions to show as nodes/sub-graphs they
> need to be marked with `@jax.jit`; otherwise they will simply be
> merged into their parent graph.
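The merging comes from how JAX traces code: a plain Python helper is traced straight into its caller, while a `@jax.jit` helper is recorded as a single nested call equation. The sketch below shows this with `jax.make_jaxpr`; the function names are hypothetical, and the name of the nested call primitive (`pjit` or `jit`) depends on the JAX version, so it only checks for the primitives that should or should not appear at the top level.

```python
import jax
import jax.numpy as jnp


def double_plain(x):
    # Not jitted: its operations are traced inline into the caller
    return 2 * x


@jax.jit
def double_jitted(x):
    # Jitted: appears in the caller as one nested call equation
    return 2 * x


def top_level_primitives(fn, *args):
    """Names of the primitives at the outermost level of fn's jaxpr."""
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).eqns]


x = jnp.arange(10)
plain = top_level_primitives(lambda v: double_plain(v) - 1, x)
jitted = top_level_primitives(lambda v: double_jitted(v) - 1, x)
print(plain)   # the multiplication shows up directly
print(jitted)  # the multiplication is hidden inside a nested call
```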

### Jupyter Notebook

To show the rendered graph in a Jupyter notebook, you can use the
helper function `view_pydot`:

```python
...
dot_graph = jpviz.draw(bar)(jnp.arange(10))
jpviz.view_pydot(dot_graph)
```

### Visualisation Options

#### Collapse Nodes
By default, functions that are composed only of primitive operations
are collapsed into a single node (like `foo` in the above example).
To render the full computation graph instead, set the `collapse_primitives`
flag to `False`. Applied to the example above,

```python
...
dot_graph = jpviz.draw(bar, collapse_primitives=False)(jnp.arange(10))
...
```

produces

![bar computation graph](.github/images/bar_expanded.png)
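One plausible reading of "composed only of primitive operations" is a function whose traced body contains no nested sub-jaxprs. The sketch below checks that for `foo` and `bar`; `body_jaxpr` and `is_primitive_only` are hypothetical helpers, and this is not necessarily the exact rule jpviz applies. It assumes, as in current JAX, that tracing a jitted function yields one call equation whose parameters hold the traced body.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


def body_jaxpr(jitted_fn, *args):
    """Trace a jitted function and return the jaxpr of its body."""
    (call_eqn,) = jax.make_jaxpr(jitted_fn)(*args).eqns
    # The body is the parameter that is itself a jaxpr (has ``eqns``)
    return next(v for v in call_eqn.params.values() if hasattr(v, "eqns"))


def is_primitive_only(jaxpr):
    """True if no equation in ``jaxpr`` carries a nested sub-jaxpr."""
    return not any(
        hasattr(value, "eqns")
        for eqn in jaxpr.eqns
        for value in eqn.params.values()
    )


x = jnp.arange(10)
print(is_primitive_only(body_jaxpr(foo, x)))  # foo only multiplies
print(is_primitive_only(body_jaxpr(bar, x)))  # bar calls the jitted foo
```

Under this reading `foo` would be collapsed into a single node, while `bar` stays a sub-graph because it contains a call to `foo`.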

#### Show Types

By default, type information is included in the node labels. It
can be hidden by setting the `show_avals` flag to `False`, which for the example above

```python
...
dot_graph = jpviz.draw(bar, show_avals=False)(jnp.arange(10))
...
```

produces

![bar computation graph without types](.github/images/bar_no_types.png)
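The type information in the labels corresponds to JAX's abstract values ("avals"), which describe each array's shape and dtype rather than its contents. As a sketch, the avals of the traced `bar` can be inspected directly through JAX's `ClosedJaxpr.in_avals` and `out_avals`; this only illustrates what the flag hides and does not use jpviz.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


closed_jaxpr = jax.make_jaxpr(bar)(jnp.arange(10))
# Abstract values record shape and dtype, not concrete data
for aval in [*closed_jaxpr.in_avals, *closed_jaxpr.out_avals]:
    print(aval.shape, aval.dtype)
```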

> **NOTE:** Node labels do not currently correspond to argument or
> variable names in the original Python code. However, since JAX
> flattens arguments and outputs into tuples, the labels do reflect
> the positions of arguments and outputs.
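To see why only positions survive: JAX flattens pytree arguments (tuples, dicts, and so on) into a flat list of leaves before tracing, so the jaxpr just has positional inputs. In the hypothetical `scale` function below, the dict keys `"w"` and `"b"` do not appear in the jaxpr; dict leaves are ordered by sorted key, so `b` comes before `w`.

```python
import jax
import jax.numpy as jnp


def scale(params, x):
    # params is a pytree (a dict); JAX flattens it into positional inputs
    return params["w"] * x + params["b"]


params = {"w": jnp.float32(2.0), "b": jnp.float32(1.0)}
x = jnp.ones(3)

leaves, treedef = jax.tree_util.tree_flatten((params, x))
closed_jaxpr = jax.make_jaxpr(scale)(params, x)

print(len(leaves))                     # leaves of (params, x)
print(len(closed_jaxpr.jaxpr.invars))  # one jaxpr input per leaf
```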

## Examples

See [here](.github/docs/gallery.md) for more examples of rendered computation graphs.

## Developers

Developer notes can be found [here](.github/docs/developers.md).

            
