- Name: torchinfo
- Version: 1.8.0
- Home page: https://github.com/tyleryep/torchinfo
- Summary: Model summary in PyTorch, based off of the original torchsummary.
- Upload time: 2023-05-14 19:23:26
- Author: Tyler Yep @tyleryep (tyep@cs.stanford.edu)
- Requires Python: >=3.7
- License: MIT
- Keywords: torch pytorch torchsummary torch-summary summary keras deep-learning ml torchinfo torch-info visualize model statistics layer stats
- Requirements: torch torchvision numpy
- Source distribution: https://files.pythonhosted.org/packages/53/d9/2b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996/torchinfo-1.8.0.tar.gz

# torchinfo

[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
[![PyPI version](https://badge.fury.io/py/torchinfo.svg)](https://badge.fury.io/py/torchinfo)
[![Conda version](https://img.shields.io/conda/vn/conda-forge/torchinfo)](https://anaconda.org/conda-forge/torchinfo)
[![Build Status](https://github.com/TylerYep/torchinfo/actions/workflows/test.yml/badge.svg)](https://github.com/TylerYep/torchinfo/actions/workflows/test.yml)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TylerYep/torchinfo/main.svg)](https://results.pre-commit.ci/latest/github/TylerYep/torchinfo/main)
[![GitHub license](https://img.shields.io/github/license/TylerYep/torchinfo)](https://github.com/TylerYep/torchinfo/blob/main/LICENSE)
[![codecov](https://codecov.io/gh/TylerYep/torchinfo/branch/main/graph/badge.svg)](https://codecov.io/gh/TylerYep/torchinfo)
[![Downloads](https://pepy.tech/badge/torchinfo)](https://pepy.tech/project/torchinfo)

(formerly torch-summary)

Torchinfo provides information complementary to what `print(your_model)` shows in PyTorch, similar to TensorFlow's `model.summary()` API: a layer-by-layer view of the model that is helpful while debugging your network. This project implements that functionality in PyTorch behind a clean, simple interface you can use in your own projects.

This is a completely rewritten version of the original torchsummary and torchsummaryX projects by @sksq96 and @nmhkahn. This project addresses all of the issues and pull requests left on the original projects by introducing a completely new API.

Supports PyTorch versions 1.4.0+.

# Usage

```
pip install torchinfo
```

Alternatively, via conda:

```
conda install -c conda-forge torchinfo
```

# How To Use

```python
from torchinfo import summary

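model = ConvNet()  # placeholder: any nn.Module that accepts 1x28x28 images
batch_size = 7
summary(
    model,
    input_size=(batch_size, 1, 28, 28),
    col_names=("input_size", "output_size", "num_params", "mult_adds"),
)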
```

```
================================================================================================================
Layer (type:depth-idx)          Input Shape          Output Shape         Param #            Mult-Adds
================================================================================================================
SingleInputNet                  [7, 1, 28, 28]       [7, 10]              --                 --
├─Conv2d: 1-1                   [7, 1, 28, 28]       [7, 10, 24, 24]      260                1,048,320
├─Conv2d: 1-2                   [7, 10, 12, 12]      [7, 20, 8, 8]        5,020              2,248,960
├─Dropout2d: 1-3                [7, 20, 8, 8]        [7, 20, 8, 8]        --                 --
├─Linear: 1-4                   [7, 320]             [7, 50]              16,050             112,350
├─Linear: 1-5                   [7, 50]              [7, 10]              510                3,570
================================================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51
================================================================================================================
```

<!-- single_input_all_cols.out -->
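The Mult-Adds and size figures in the table above can be reproduced by hand. The sketch below is plain arithmetic, not torchinfo's code, and rests on two assumptions that happen to match every printed number: a layer's mult-adds equal its parameter count (weights plus bias) times its number of output positions times the batch size, and sizes use 4-byte float32 values with 1 MB = 1,000,000 bytes.

```python
# Worked arithmetic for the SingleInputNet table (batch size 7).
# Assumption: mult-adds = params (weights + bias) * output positions * batch.
# This reproduces the printed figures; it is not copied from torchinfo.
BATCH = 7
BYTES_PER_FLOAT32 = 4
MB = 1_000_000  # the table's MB figures line up with decimal megabytes


def conv2d_params(in_ch: int, out_ch: int, k: int) -> int:
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_f: int, out_f: int) -> int:
    """Weights (out_f * in_f) plus one bias per output feature."""
    return out_f * in_f + out_f


conv1 = conv2d_params(1, 10, 5)   # 260
conv2 = conv2d_params(10, 20, 5)  # 5,020
fc1 = linear_params(320, 50)      # 16,050
fc2 = linear_params(50, 10)       # 510
total_params = conv1 + conv2 + fc1 + fc2

mult_adds = {
    "conv1": conv1 * 24 * 24 * BATCH,  # each output channel is 24x24
    "conv2": conv2 * 8 * 8 * BATCH,    # each output channel is 8x8
    "fc1": fc1 * BATCH,                # one output vector per sample
    "fc2": fc2 * BATCH,
}
total_mult_adds = sum(mult_adds.values())

input_mb = BATCH * 1 * 28 * 28 * BYTES_PER_FLOAT32 / MB
params_mb = total_params * BYTES_PER_FLOAT32 / MB

print(f"Total params: {total_params:,}")
print(f"Total mult-adds (M): {total_mult_adds / 1e6:.2f}")
print(f"Input size (MB): {input_mb:.2f}")
print(f"Params size (MB): {params_mb:.2f}")
```

The forward/backward pass size is not reproduced here because it depends on which layer outputs torchinfo chooses to count.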

Note: if you are using a Jupyter Notebook or Google Colab, `summary(model, ...)` must be the last expression in the cell so that its return value is displayed.
Otherwise, wrap it in `print()`, e.g. `print(summary(model, ...))`.
See `tests/jupyter_test.ipynb` for examples.
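The table in this section was generated with a model named `SingleInputNet` on a batch of 7 images. The class below is a hypothetical reconstruction, not copied from torchinfo's tests, whose layers reproduce the shapes and parameter counts in that table; the activations, pooling, and dropout rate are assumptions chosen to fit those shapes. It only needs PyTorch.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SingleInputNet(nn.Module):
    """Hypothetical reconstruction matching the layer shapes in the table."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # [N,1,28,28] -> [N,10,24,24]
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # [N,10,12,12] -> [N,20,8,8]
        self.conv2_drop = nn.Dropout2d(0.3)            # dropout rate is an assumption
        self.fc1 = nn.Linear(320, 50)                  # 20 * 4 * 4 = 320 features
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # -> [N,10,12,12]
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # -> [N,20,4,4]
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


model = SingleInputNet()
out = model(torch.randn(7, 1, 28, 28))
num_params = sum(p.numel() for p in model.parameters())
print(out.shape, num_params)
```

Because the ReLU and max-pooling steps are functional calls rather than submodules, only the two `Conv2d` layers, the `Dropout2d`, and the two `Linear` layers would appear as rows if this model were passed to `summary`.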

**This version now supports:**

- RNNs, LSTMs, and other recursive layers
- Branching output used to explore model layers using specified depths
- Returns a `ModelStatistics` object containing all summary data fields
- Configurable rows/columns
- Jupyter Notebook / Google Colab

**Other new features:**

- Verbose mode to show weight and bias layers
- Accepts either input data or simply the input shape!
- Customizable line widths and batch dimension
- Comprehensive unit/output testing, linting, and code coverage testing

**Community Contributions:**

- Sequentials & ModuleLists (thanks to @roym899)
- Improved Mult-Add calculations (thanks to @TE-StefanUhlich, @zmzhang2000)
- Dict/Misc input data (thanks to @e-dorigatti)
- Pruned layer support (thanks to @MajorCarrot)

# Documentation

```python
def summary(
    model: nn.Module,
    input_size: Optional[INPUT_SIZE_TYPE] = None,
    input_data: Optional[INPUT_DATA_TYPE] = None,
    batch_dim: Optional[int] = None,
    cache_forward_pass: Optional[bool] = None,
    col_names: Optional[Iterable[str]] = None,
    col_width: int = 25,
    depth: int = 3,
    device: Optional[torch.device] = None,
    dtypes: Optional[List[torch.dtype]] = None,
    mode: str | None = None,
    row_settings: Optional[Iterable[str]] = None,
    verbose: int = 1,
    **kwargs: Any,
) -> ModelStatistics:
"""
Summarize the given PyTorch model. Summarized information includes:
    1) Layer names,
    2) input/output shapes,
    3) kernel shape,
    4) # of parameters,
    5) # of operations (Mult-Adds),
    6) whether layer is trainable

NOTE: If neither input_data nor input_size is provided, no forward pass through the
network is performed, and the provided model information is limited to layer names.

Args:
    model (nn.Module):
            PyTorch model to summarize. The model should be fully in either train()
            or eval() mode. If layers are not all in the same mode, running summary
            may have side effects on batchnorm or dropout statistics. If you
            encounter an issue with this, please open a GitHub issue.

    input_size (Sequence of Sizes):
            Shape of input data as a List/Tuple/torch.Size
            (dtypes must match model input, default is FloatTensors).
            You should include batch size in the tuple.
            Default: None

    input_data (Sequence of Tensors):
            Arguments for the model's forward pass (dtypes inferred).
            If the forward() function takes several parameters, pass in a list of
            args or a dict of kwargs (if your forward() function takes in a dict
            as its only argument, wrap it in a list).
            Default: None

    batch_dim (int):
            Batch dimension of input data. If batch_dim is None, assume
            input_data / input_size contains the batch dimension, which is used
            in all calculations. Else, expand all tensors to contain the batch_dim.
            Specifying batch_dim can be a runtime optimization, since if batch_dim
            is specified, torchinfo uses a batch size of 1 for the forward pass.
            Default: None

    cache_forward_pass (bool):
            If True, cache the run of the forward() function using the model
            class name as the key. If the forward pass is an expensive operation,
            this can make it easier to modify the formatting of your model
            summary, e.g. changing the depth or enabled column types, especially
            in Jupyter Notebooks.
            WARNING: Modifying the model architecture or input data/input size when
            this feature is enabled does not invalidate the cache or re-run the
            forward pass, and can cause incorrect summaries as a result.
            Default: False

    col_names (Iterable[str]):
            Specify which columns to show in the output. Currently supported: (
                "input_size",
                "output_size",
                "num_params",
                "params_percent",
                "kernel_size",
                "mult_adds",
                "trainable",
            )
            Default: ("output_size", "num_params")
            If input_data / input_size are not provided, only "num_params" is used.

    col_width (int):
            Width of each column.
            Default: 25

    depth (int):
            Depth of nested layers to display (e.g. Sequentials).
            Nested layers below this depth will not be displayed in the summary.
            Default: 3

    device (torch.device):
            Uses this torch device for model and input_data.
            If not specified, uses the device of input_data if given, or of the
            parameters of the model. Otherwise, uses CUDA if
            torch.cuda.is_available() is True, and the CPU if not.
            Default: None

    dtypes (List[torch.dtype]):
            If you use input_size, torchinfo assumes your input uses FloatTensors.
            If your model uses a different data type, specify that dtype.
            For multiple inputs, specify the size of each input in input_size,
            and the dtype of each input here, in the same order.
            Default: None

    mode (str):
            Either "train" or "eval", which determines whether we call
            model.train() or model.eval() before running the forward pass.
            Default: "eval".

    row_settings (Iterable[str]):
            Specify which features to show in a row. Currently supported: (
                "ascii_only",
                "depth",
                "var_names",
            )
            Default: ("depth",)

    verbose (int):
            0 (quiet): No output
            1 (default): Print model summary
            2 (verbose): Show weight and bias layers in full detail
            Default: 1
            If using a Jupyter Notebook or Google Colab, the default is 0.

    **kwargs:
            Other keyword arguments passed to the model's forward() function.
            Passing *args is no longer supported.

Return:
    ModelStatistics object
            See torchinfo/model_statistics.py for more information.
"""
```

# Examples

## Get Model Summary as String

```python
from torchinfo import summary

model_stats = summary(your_model, (1, 3, 28, 28), verbose=0)
summary_str = str(model_stats)
# summary_str contains the string representation of the summary!
```

## Explore Different Configurations

```python
import torch
from torch import nn
from torchinfo import summary


class LSTMNet(nn.Module):
    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden

summary(
    LSTMNet(),
    (1, 100),
    dtypes=[torch.long],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)
```

```
========================================================================================================================
Layer (type (var_name))                  Kernel Shape         Output Shape         Param #              Mult-Adds
========================================================================================================================
LSTMNet (LSTMNet)                        --                   [100, 20]            --                   --
├─Embedding (embedding)                  --                   [1, 100, 300]        6,000                6,000
│    └─weight                            [300, 20]                                 └─6,000
├─LSTM (encoder)                         --                   [1, 100, 512]        3,768,320            376,832,000
│    └─weight_ih_l0                      [2048, 300]                               ├─614,400
│    └─weight_hh_l0                      [2048, 512]                               ├─1,048,576
│    └─bias_ih_l0                        [2048]                                    ├─2,048
│    └─bias_hh_l0                        [2048]                                    ├─2,048
│    └─weight_ih_l1                      [2048, 512]                               ├─1,048,576
│    └─weight_hh_l1                      [2048, 512]                               ├─1,048,576
│    └─bias_ih_l1                        [2048]                                    ├─2,048
│    └─bias_hh_l1                        [2048]                                    └─2,048
├─Linear (decoder)                       --                   [1, 100, 20]         10,260               10,260
│    └─weight                            [512, 20]                                 ├─10,240
│    └─bias                              [20]                                      └─20
========================================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 376.85
========================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.67
Params size (MB): 15.14
Estimated Total Size (MB): 15.80
========================================================================================================================

```

<!-- lstm.out -->
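The LSTM parameter counts above follow from PyTorch's layer definitions: each `nn.LSTM` layer has four gates, so `weight_ih_l{k}` and `weight_hh_l{k}` have `4 * hidden_dim` rows, and each layer has two bias vectors of length `4 * hidden_dim`. The pure-Python sketch below recomputes the printed counts and the 15.14 MB parameter size, assuming 4-byte float32 parameters and 1 MB = 1,000,000 bytes.

```python
# Reproduce the parameter counts printed for LSTMNet(vocab_size=20,
# embed_dim=300, hidden_dim=512, num_layers=2).
VOCAB, EMBED, HIDDEN, NUM_LAYERS = 20, 300, 512, 2
GATES = 4  # input, forget, cell, and output gates


def lstm_layer_params(input_size: int, hidden: int) -> int:
    """weight_ih + weight_hh + bias_ih + bias_hh for one LSTM layer."""
    weight_ih = GATES * hidden * input_size
    weight_hh = GATES * hidden * hidden
    biases = 2 * GATES * hidden
    return weight_ih + weight_hh + biases


embedding = VOCAB * EMBED  # one EMBED-sized vector per token
# The first layer reads embeddings; later layers read the previous hidden state.
lstm = sum(
    lstm_layer_params(EMBED if layer == 0 else HIDDEN, HIDDEN)
    for layer in range(NUM_LAYERS)
)
decoder = HIDDEN * VOCAB + VOCAB  # Linear weight + bias
total = embedding + lstm + decoder

print(f"{embedding:,} {lstm:,} {decoder:,} {total:,}")
print(f"Params size (MB): {total * 4 / 1_000_000:.2f}")
```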

## ResNet

```python
import torchvision
from torchinfo import summary

model = torchvision.models.resnet152()
summary(model, (1, 3, 224, 224), depth=3)
```

```
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 256, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-8             [1, 256, 56, 56]          512
│    │    └─Sequential: 3-9              [1, 256, 56, 56]          16,896
│    │    └─ReLU: 3-10                   [1, 256, 56, 56]          --
│    └─Bottleneck: 2-2                   [1, 256, 56, 56]          --

  ...
  ...
  ...

├─AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --
├─Linear: 1-10                           [1, 1000]                 2,049,000
==========================================================================================
Total params: 60,192,808
Trainable params: 60,192,808
Non-trainable params: 0
Total mult-adds (G): 11.51
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 360.87
Params size (MB): 240.77
Estimated Total Size (MB): 602.25
==========================================================================================
```

<!-- resnet152.out -->
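The first rows of the ResNet-152 table can be checked by hand. The sketch assumes, as in torchvision's ResNet, that convolutions are created without a bias term and that each `BatchNorm2d` has a learnable weight and bias per channel; sizes assume 4-byte float32 values and 1 MB = 1,000,000 bytes. Under those assumptions it reproduces the counts for the stem, the first `Bottleneck`, and the final `Linear`, plus the input and parameter sizes.

```python
# Check the first ResNet-152 rows, assuming bias-free convolutions
# (as in torchvision's ResNet) and BatchNorm2d with affine weight + bias.
def conv_params(in_ch: int, out_ch: int, k: int) -> int:
    return out_ch * in_ch * k * k  # no bias term


def bn_params(channels: int) -> int:
    return 2 * channels  # learnable scale (weight) and shift (bias)


stem_conv = conv_params(3, 64, 7)  # Conv2d: 1-1
stem_bn = bn_params(64)            # BatchNorm2d: 1-2
bottleneck_convs = [
    conv_params(64, 64, 1),        # Conv2d: 3-1 (1x1 reduce)
    conv_params(64, 64, 3),        # Conv2d: 3-4 (3x3)
    conv_params(64, 256, 1),       # Conv2d: 3-7 (1x1 expand)
]
downsample = conv_params(64, 256, 1) + bn_params(256)  # Sequential: 3-9
fc = 2048 * 1000 + 1000            # Linear: 1-10 (with bias)

# Sizes, assuming float32 (4 bytes) and 1 MB = 1,000,000 bytes.
input_mb = 1 * 3 * 224 * 224 * 4 / 1_000_000
params_mb = 60_192_808 * 4 / 1_000_000  # "Total params" from the table

print(stem_conv, stem_bn, bottleneck_convs, downsample, fc)
print(f"{input_mb:.2f} {params_mb:.2f}")
```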

## Multiple Inputs w/ Different Data Types

```python
import torch
import torch.nn.functional as F
from torch import nn
from torchinfo import summary


class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)

model = MultipleInputNetDifferentDtypes()
summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])
```

Alternatively, you can pass in the input data itself, and
torchinfo will infer the data types automatically.

```python
input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
model = MultipleInputNetDifferentDtypes()

summary(model, input_data=[input_data, other_input_data])
```

## Sequentials & ModuleLists

```python
from torch import nn
from torchinfo import summary


class ContainerModule(nn.Module):

    def __init__(self):
        super().__init__()
        self._layers = nn.ModuleList()
        self._layers.append(nn.Linear(5, 5))
        self._layers.append(ContainerChildModule())
        self._layers.append(nn.Linear(5, 5))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


class ContainerChildModule(nn.Module):

    def __init__(self):
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x):
        out = self._sequential(x)
        out = self._between(out)
        for layer in self._sequential:
            out = layer(out)

        out = self._sequential(x)
        for layer in self._sequential:
            out = layer(out)
        return out

summary(ContainerModule(), (1, 5))
```

```
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ContainerModule                          [1, 5]                    --
├─ModuleList: 1-1                        --                        --
│    └─Linear: 2-1                       [1, 5]                    30
│    └─ContainerChildModule: 2-2         [1, 5]                    --
│    │    └─Sequential: 3-1              [1, 5]                    --
│    │    │    └─Linear: 4-1             [1, 5]                    30
│    │    │    └─Linear: 4-2             [1, 5]                    30
│    │    └─Linear: 3-2                  [1, 5]                    30
│    │    └─Sequential: 3-3              --                        (recursive)
│    │    │    └─Linear: 4-3             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-4             [1, 5]                    (recursive)
│    │    └─Sequential: 3-4              [1, 5]                    (recursive)
│    │    │    └─Linear: 4-5             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-6             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-7             [1, 5]                    (recursive)
│    │    │    └─Linear: 4-8             [1, 5]                    (recursive)
│    └─Linear: 2-3                       [1, 5]                    30
==========================================================================================
Total params: 150
Trainable params: 150
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
```

<!-- container.out -->
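The `(recursive)` rows appear because `ContainerChildModule` calls the same two `Linear` layers inside `_sequential` several times, while their parameters are counted only once, which is why the total stays at 150. One way to see this with plain PyTorch, no torchinfo needed, is to count forward calls with `register_forward_hook`. The two classes are restated from the example above so the block runs on its own.

```python
from collections import Counter

import torch
from torch import nn


class ContainerChildModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self._sequential(x)
        out = self._between(out)
        for layer in self._sequential:
            out = layer(out)
        out = self._sequential(x)
        for layer in self._sequential:
            out = layer(out)
        return out


class ContainerModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._layers = nn.ModuleList(
            [nn.Linear(5, 5), ContainerChildModule(), nn.Linear(5, 5)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self._layers:
            x = layer(x)
        return x


model = ContainerModule()
calls = Counter()

# Record how many times each Linear layer's forward() runs.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.register_forward_hook(
            lambda mod, inp, out, name=name: calls.update([name])
        )

model(torch.randn(1, 5))
num_params = sum(p.numel() for p in model.parameters())  # shared weights counted once

print(dict(calls))
print(num_params)
```

The two layers inside `_sequential` each run four times, and the 11 `Linear` invocations in total line up with the 11 `Linear` rows in the table.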

# Contributing

All issues and pull requests are much appreciated! If you are wondering how to build the project:

- torchinfo is actively developed using the latest version of Python.
  - Changes should be backward compatible to Python 3.7, and will follow Python's End-of-Life guidance for old versions.
  - Run `pip install -r requirements-dev.txt`. We use the latest versions of all dev packages.
  - Run `pre-commit install`.
  - To use auto-formatting tools, use `pre-commit run -a`.
  - To run unit tests, run `pytest`.
  - To update the expected output files, run `pytest --overwrite`.
  - To skip output file tests, use `pytest --no-output`.

# References

- Thanks to @sksq96, @nmhkahn, and @sangyx for providing the inspiration for this project.
- Thanks to @jacobkimmel for model size estimation ([details here](https://github.com/sksq96/pytorch-summary/pull/21)).

            

\u2514\u2500ReLU: 3-10                   [1, 256, 56, 56]          --\n\u2502    \u2514\u2500Bottleneck: 2-2                   [1, 256, 56, 56]          --\n\n  ...\n  ...\n  ...\n\n\u251c\u2500AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --\n\u251c\u2500Linear: 1-10                           [1, 1000]                 2,049,000\n==========================================================================================\nTotal params: 60,192,808\nTrainable params: 60,192,808\nNon-trainable params: 0\nTotal mult-adds (G): 11.51\n==========================================================================================\nInput size (MB): 0.60\nForward/backward pass size (MB): 360.87\nParams size (MB): 240.77\nEstimated Total Size (MB): 602.25\n==========================================================================================\n```\n\n<!-- resnet152.out -->\n\n## Multiple Inputs w/ Different Data Types\n\n```python\nclass MultipleInputNetDifferentDtypes(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.fc1a = nn.Linear(300, 50)\n        self.fc1b = nn.Linear(50, 10)\n\n        self.fc2a = nn.Linear(300, 50)\n        self.fc2b = nn.Linear(50, 10)\n\n    def forward(self, x1, x2):\n        x1 = F.relu(self.fc1a(x1))\n        x1 = self.fc1b(x1)\n        x2 = x2.type(torch.float)\n        x2 = F.relu(self.fc2a(x2))\n        x2 = self.fc2b(x2)\n        x = torch.cat((x1, x2), 0)\n        return F.log_softmax(x, dim=1)\n\nsummary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])\n```\n\nAlternatively, you can also pass in the input_data itself, and\ntorchinfo will automatically infer the data types.\n\n```python\ninput_data = torch.randn(1, 300)\nother_input_data = torch.randn(1, 300).long()\nmodel = MultipleInputNetDifferentDtypes()\n\nsummary(model, input_data=[input_data, other_input_data, ...])\n```\n\n## Sequentials & ModuleLists\n\n```python\nclass ContainerModule(nn.Module):\n\n    def 
__init__(self):\n        super().__init__()\n        self._layers = nn.ModuleList()\n        self._layers.append(nn.Linear(5, 5))\n        self._layers.append(ContainerChildModule())\n        self._layers.append(nn.Linear(5, 5))\n\n    def forward(self, x):\n        for layer in self._layers:\n            x = layer(x)\n        return x\n\n\nclass ContainerChildModule(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))\n        self._between = nn.Linear(5, 5)\n\n    def forward(self, x):\n        out = self._sequential(x)\n        out = self._between(out)\n        for l in self._sequential:\n            out = l(out)\n\n        out = self._sequential(x)\n        for l in self._sequential:\n            out = l(out)\n        return out\n\nsummary(ContainerModule(), (1, 5))\n```\n\n```\n==========================================================================================\nLayer (type:depth-idx)                   Output Shape              Param #\n==========================================================================================\nContainerModule                          [1, 5]                    --\n\u251c\u2500ModuleList: 1-1                        --                        --\n\u2502    \u2514\u2500Linear: 2-1                       [1, 5]                    30\n\u2502    \u2514\u2500ContainerChildModule: 2-2         [1, 5]                    --\n\u2502    \u2502    \u2514\u2500Sequential: 3-1              [1, 5]                    --\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-1             [1, 5]                    30\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-2             [1, 5]                    30\n\u2502    \u2502    \u2514\u2500Linear: 3-2                  [1, 5]                    30\n\u2502    \u2502    \u2514\u2500Sequential: 3-3              --                        (recursive)\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-3     
        [1, 5]                    (recursive)\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-4             [1, 5]                    (recursive)\n\u2502    \u2502    \u2514\u2500Sequential: 3-4              [1, 5]                    (recursive)\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-5             [1, 5]                    (recursive)\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-6             [1, 5]                    (recursive)\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-7             [1, 5]                    (recursive)\n\u2502    \u2502    \u2502    \u2514\u2500Linear: 4-8             [1, 5]                    (recursive)\n\u2502    \u2514\u2500Linear: 2-3                       [1, 5]                    30\n==========================================================================================\nTotal params: 150\nTrainable params: 150\nNon-trainable params: 0\nTotal mult-adds (M): 0.00\n==========================================================================================\nInput size (MB): 0.00\nForward/backward pass size (MB): 0.00\nParams size (MB): 0.00\nEstimated Total Size (MB): 0.00\n==========================================================================================\n```\n\n<!-- container.out -->\n\n# Contributing\n\nAll issues and pull requests are much appreciated! If you are wondering how to build the project:\n\n- torchinfo is actively developed using the lastest version of Python.\n  - Changes should be backward compatible to Python 3.7, and will follow Python's End-of-Life guidance for old versions.\n  - Run `pip install -r requirements-dev.txt`. 
We use the latest versions of all dev packages.\n  - Run `pre-commit install`.\n  - To use auto-formatting tools, use `pre-commit run -a`.\n  - To run unit tests, run `pytest`.\n  - To update the expected output files, run `pytest --overwrite`.\n  - To skip output file tests, use `pytest --no-output`\n\n# References\n\n- Thanks to @sksq96, @nmhkahn, and @sangyx for providing the inspiration for this project.\n- For Model Size Estimation @jacobkimmel ([details here](https://github.com/sksq96/pytorch-summary/pull/21))\n",
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "Model summary in PyTorch, based off of the original torchsummary.",
    "version": "1.8.0",
    "project_urls": {
        "Homepage": "https://github.com/tyleryep/torchinfo"
    },
    "split_keywords": [
        "torch",
        "pytorch",
        "torchsummary",
        "torch-summary",
        "summary",
        "keras",
        "deep-learning",
        "ml",
        "torchinfo",
        "torch-info",
        "visualize",
        "model",
        "statistics",
        "layer",
        "stats"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "7225973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1",
                "md5": "62ab1041f930012f5a50d0f95c764b15",
                "sha256": "2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46"
            },
            "downloads": -1,
            "filename": "torchinfo-1.8.0-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "62ab1041f930012f5a50d0f95c764b15",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.7",
            "size": 23377,
            "upload_time": "2023-05-14T19:23:24",
            "upload_time_iso_8601": "2023-05-14T19:23:24.141215Z",
            "url": "https://files.pythonhosted.org/packages/72/25/973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1/torchinfo-1.8.0-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "53d92b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996",
                "md5": "9e55abc36fa0ce929beefde5e4153cf1",
                "sha256": "72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9"
            },
            "downloads": -1,
            "filename": "torchinfo-1.8.0.tar.gz",
            "has_sig": false,
            "md5_digest": "9e55abc36fa0ce929beefde5e4153cf1",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.7",
            "size": 25880,
            "upload_time": "2023-05-14T19:23:26",
            "upload_time_iso_8601": "2023-05-14T19:23:26.377864Z",
            "url": "https://files.pythonhosted.org/packages/53/d9/2b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996/torchinfo-1.8.0.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-05-14 19:23:26",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "tyleryep",
    "github_project": "torchinfo",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [
        {
            "name": "torch",
            "specs": []
        },
        {
            "name": "torchvision",
            "specs": []
        },
        {
            "name": "numpy",
            "specs": []
        }
    ],
    "lcname": "torchinfo"
}
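The `digests` recorded for each file under `urls` can be used to check a downloaded artifact before installing it. The following is a minimal sketch, not part of torchinfo or of PyPI's tooling: it assumes the wheel or sdist has already been saved locally under its published filename, and the helper names `file_digests` and `verify` are hypothetical. It recomputes the three digest types listed above with Python's standard `hashlib` (`blake2b_256` is taken to mean BLAKE2b with a 32-byte digest) and compares the SHA-256 value against the published one.

```python
import hashlib
from pathlib import Path

# SHA-256 digests copied from the "urls" entries of the torchinfo 1.8.0 metadata.
EXPECTED_SHA256 = {
    "torchinfo-1.8.0-py3-none-any.whl": "2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46",
    "torchinfo-1.8.0.tar.gz": "72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9",
}


def file_digests(path: Path, chunk_size: int = 1 << 16) -> dict:
    """Hypothetical helper: compute md5, sha256 and blake2b_256 hex digests of a file."""
    md5 = hashlib.md5()
    sha256 = hashlib.sha256()
    blake2b_256 = hashlib.blake2b(digest_size=32)  # 32 bytes = 256-bit BLAKE2b
    with path.open("rb") as f:
        # Read in chunks so large archives are not loaded into memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
            sha256.update(chunk)
            blake2b_256.update(chunk)
    return {
        "md5": md5.hexdigest(),
        "sha256": sha256.hexdigest(),
        "blake2b_256": blake2b_256.hexdigest(),
    }


def verify(path: Path) -> bool:
    """Hypothetical helper: True if the file's SHA-256 matches the published value."""
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"no published digest for {path.name}")
    return file_digests(path)["sha256"] == expected


if __name__ == "__main__":
    wheel = Path("torchinfo-1.8.0-py3-none-any.whl")  # assumed local download
    if wheel.exists():
        print(wheel.name, "OK" if verify(wheel) else "MISMATCH")
```

For installs, pip's hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file) performs the same SHA-256 comparison automatically.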