torchtrail


- Name: torchtrail
- Version: 0.0.19
- Home page: https://github.com/arakhmati/torchtrail
- Summary: A library for tracing the execution of Pytorch operations and modules
- Upload time: 2024-04-10 16:30:28
- Author: Akhmed Rakhmati
- Keywords: pytorch, tracing, tracing library, tracing tool, tracing module, tracing operations, visualize, visualize pytorch, visualize pytorch operations, visualize pytorch modules

# torchtrail

[![PyPI version](https://badge.fury.io/py/torchtrail.svg)](https://badge.fury.io/py/torchtrail)
[![Build Status](https://github.com/arakhmati/torchtrail/actions/workflows/python-package.yml/badge.svg)](https://github.com/arakhmati/torchtrail/actions/workflows/python-package.yml)
[![GitHub license](https://img.shields.io/github/license/arakhmati/torchtrail)](https://github.com/arakhmati/torchtrail/blob/main/LICENSE)

`torchtrail` provides an external API for tracing PyTorch models and extracting the graph of the torch functions and modules that were executed. The graphs can then be visualized or processed further as `networkx` graphs.

## Installation Instructions

### On macOS
```bash
brew install graphviz
pip install torchtrail
```

### On Ubuntu
```bash
sudo apt-get install graphviz
pip install torchtrail
```
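Both installation routes install the Graphviz system package next to `torchtrail`, which suggests that rendering graphs such as `exp.svg` relies on Graphviz. If visualization fails, one quick check is whether Graphviz's `dot` executable is on your `PATH`. The sketch below assumes rendering goes through `dot`; `graphviz_available` is a hypothetical helper, not part of `torchtrail`.

```python
import shutil
import subprocess


def graphviz_available() -> bool:
    """Return True if Graphviz's `dot` executable can be found and run.

    Hypothetical helper, not part of torchtrail. It assumes visualization
    relies on the `dot` binary installed by the Graphviz system package.
    """
    dot_path = shutil.which("dot")
    if dot_path is None:
        return False
    try:
        # `dot -V` prints the Graphviz version and exits with status 0.
        subprocess.run([dot_path, "-V"], check=True, capture_output=True, timeout=10)
    except (subprocess.SubprocessError, OSError):
        # Covers a non-zero exit, a timeout, or a binary that cannot be executed.
        return False
    return True


if __name__ == "__main__":
    print("Graphviz found" if graphviz_available() else "Graphviz missing: install it first")
```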

## Examples

### Tracing a function
```python
import torch
import torchtrail

with torchtrail.trace():
    input_tensor = torch.rand(1, 64)
    output_tensor = torch.exp(input_tensor)
torchtrail.visualize(output_tensor, file_name="exp.svg")
```
![](https://raw.githubusercontent.com/arakhmati/torchtrail/main/docs/images/exp.svg)

The graph can be obtained as a `networkx.MultiDiGraph` using `torchtrail.get_graph`:
```python
graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output_tensor)
```
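Because `get_graph` returns a standard `networkx.MultiDiGraph`, any networkx algorithm can be applied to it. The sketch below reports a few structural facts using only generic networkx calls, so it makes no assumptions about the node and edge attributes `torchtrail` stores. `summarize_graph` is a hypothetical helper, and the demo uses a hand-built graph with made-up node names rather than real `torchtrail` output.

```python
import networkx as nx


def summarize_graph(graph: nx.MultiDiGraph) -> dict:
    """Collect basic structural facts about a traced graph.

    Hypothetical helper: only generic networkx calls are used, so nothing is
    assumed about the attributes torchtrail attaches to nodes or edges.
    """
    # Nodes with no incoming edges are where data enters the graph;
    # nodes with no outgoing edges are where it leaves.
    sources = [node for node, degree in graph.in_degree() if degree == 0]
    sinks = [node for node, degree in graph.out_degree() if degree == 0]
    return {
        "num_nodes": graph.number_of_nodes(),
        "num_edges": graph.number_of_edges(),  # parallel edges are counted separately
        "is_dag": nx.is_directed_acyclic_graph(graph),
        "sources": sources,
        "sinks": sinks,
    }


if __name__ == "__main__":
    # Hand-built stand-in graph; the node names are invented for illustration.
    demo = nx.MultiDiGraph()
    demo.add_edge("input_tensor", "exp")
    print(summarize_graph(demo))
```

With a real trace, the call would be `summarize_graph(torchtrail.get_graph(output_tensor))`.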


### Tracing a module

```python
import torch
import transformers

import torchtrail

model_name = "google/bert_uncased_L-4_H-256_A-4"
config = transformers.BertConfig.from_pretrained(model_name)
config.num_hidden_layers = 1
model = transformers.BertModel.from_pretrained(model_name, config=config).eval()

with torchtrail.trace():
    input_tensor = torch.randint(0, model.config.vocab_size, (1, 64))
    output = model(input_tensor).last_hidden_state

torchtrail.visualize(output, max_depth=1, file_name="bert_max_depth_1.svg")
```

![](https://raw.githubusercontent.com/arakhmati/torchtrail/main/docs/images/bert_max_depth_1.svg)


```python
torchtrail.visualize(output, max_depth=2, file_name="bert_max_depth_2.svg")
```

![](https://raw.githubusercontent.com/arakhmati/torchtrail/main/docs/images/bert_max_depth_2.svg)
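From the examples above, `max_depth` appears to control how many levels of nested modules are expanded. Before picking a value, it can help to know how deeply a model's submodules are nested. The sketch below assumes, without having checked `torchtrail`'s implementation, that depth roughly tracks the number of components in a submodule's dotted name as reported by `torch.nn.Module.named_modules()`; `torchtrail` may count levels differently, for example around containers such as `ModuleList`. `module_nesting_depth` is a hypothetical helper.

```python
from typing import Iterable


def module_nesting_depth(module_names: Iterable[str]) -> int:
    """Return the deepest nesting level among dotted submodule names.

    Hypothetical helper. The root module (empty name) counts as depth 0,
    "encoder" as depth 1, "encoder.layer" as depth 2, and so on.
    """
    depth = 0
    for name in module_names:
        if name:  # named_modules() reports the root module with the name ""
            depth = max(depth, name.count(".") + 1)
    return depth


if __name__ == "__main__":
    # Invented names in the style of named_modules() output for a BERT model.
    names = ["", "embeddings", "encoder", "encoder.layer", "encoder.layer.0", "encoder.layer.0.attention"]
    print(module_nesting_depth(names))
```

For a real model, pass `(name for name, _ in model.named_modules())`.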

To visualize the graph of the full module, omit the `max_depth` argument:

```python
torchtrail.visualize(output, file_name="bert.svg")
```

![](https://raw.githubusercontent.com/arakhmati/torchtrail/main/docs/images/bert.svg)

The graph can be obtained as a `networkx.MultiDiGraph` using `torchtrail.get_graph`:
```python
graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output)
```

Alternatively, module visualization can be turned off completely with `show_modules=False`:

```python
torchtrail.visualize(output, show_modules=False, file_name="bert_show_modules_False.svg")
```

![](https://raw.githubusercontent.com/arakhmati/torchtrail/main/docs/images/bert_show_modules_False.svg)

The flattened graph can be obtained as a `networkx.MultiDiGraph` by passing `flatten=True` to `torchtrail.get_graph`:
```python
graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output, flatten=True)
```
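Presumably the flattened graph drops the module grouping and keeps only the traced operations, which makes it convenient for whole-graph analysis. As one example, the sketch below finds the longest chain of dependent nodes, a rough proxy for how sequential the computation is. It uses only standard networkx functions and assumes the traced graph is acyclic (typical for a forward pass, but an assumption here). `longest_chain` is a hypothetical helper, not part of `torchtrail`.

```python
import networkx as nx


def longest_chain(graph: nx.MultiDiGraph) -> list:
    """Return the nodes on the longest directed path of an acyclic graph.

    Hypothetical helper. Raises networkx.NetworkXUnfeasible if the graph
    contains a cycle.
    """
    # Copy only the connectivity: parallel edges collapse into one and edge
    # attributes (which dag_longest_path could mistake for weights) are dropped,
    # so every hop counts as length 1.
    structure = nx.DiGraph()
    structure.add_nodes_from(graph.nodes)
    structure.add_edges_from(graph.edges())
    return nx.dag_longest_path(structure)


if __name__ == "__main__":
    # Hand-built stand-in graph; the node names are invented for illustration.
    demo = nx.MultiDiGraph()
    demo.add_edges_from([("input", "matmul"), ("matmul", "add"), ("bias", "add"), ("add", "gelu")])
    print(longest_chain(demo))
```

With a real trace, the call would be `longest_chain(torchtrail.get_graph(output, flatten=True))`.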


## Reference
- `torchtrail` was inspired by [torchview](https://github.com/mert-kurttutan/torchview). [mert-kurttutan](https://github.com/mert-kurttutan) did an amazing job of displaying torch graphs. However, one of the goals of `torchtrail` was to produce [networkx](https://networkx.org)-compatible graphs, which is why `torchtrail` was written.
- The idea of using a persistent `MultiDiGraph` to trace torch operations was taken from [composit](https://github.com/arakhmati/composit).
