torchextractor

Name: torchextractor
Version: 0.3.0
Home page: https://github.com/antoinebrl/torchextractor
Summary: Pytorch feature extraction made simple
Upload time: 2021-03-07 21:26:58
Author: Antoine Broyelle
Requires Python: >=3.6
Keywords: pytorch torch feature extraction
Requirements: No requirements were recorded.
# `torchextractor`: PyTorch Intermediate Feature Extraction

[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchextractor)](https://pypi.org/project/torchextractor/)
[![PyPI](https://img.shields.io/pypi/v/torchextractor)](https://pypi.org/project/torchextractor/)
[![Read the Docs](https://img.shields.io/readthedocs/torchextractor)](https://torchextractor.readthedocs.io/en/latest/)
[![Upload Python Package](https://github.com/antoinebrl/torchextractor/actions/workflows/publish.yml/badge.svg)](https://github.com/antoinebrl/torchextractor/actions/workflows/publish.yml)
[![GitHub](https://img.shields.io/github/license/antoinebrl/torchextractor)](https://github.com/antoinebrl/torchextractor/blob/main/LICENSE)


## Introduction

Too many times, model definitions get remorselessly copy-pasted just because the
`forward` function does not return what the person expects. You provide module names
and `torchextractor` takes care of the extraction for you. It's never been easier to
extract features, add an extra loss or plug another head into a network.
Let us know what amazing things you build with `torchextractor`!

## Installation

```shell
pip install torchextractor  # stable
pip install git+https://github.com/antoinebrl/torchextractor.git  # latest
```

Requirements:
- Python >= 3.6
- torch >= 1.4.0

## Usage

```python
import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(1, 3, 224, 224)  # batch of one image, matching the shapes below
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

# {
#   'layer1': torch.Size([1, 64, 56, 56]),
#   'layer2': torch.Size([1, 128, 28, 28]),
#   'layer3': torch.Size([1, 256, 14, 14]),
#   'layer4': torch.Size([1, 512, 7, 7]),
# }
```
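Under the hood, this kind of extraction relies on PyTorch forward hooks (`register_forward_hook`). The sketch below illustrates the idea with a hypothetical `MiniExtractor` class and a toy `backbone`; it assumes one hook per requested module name and is a simplified illustration, not `torchextractor`'s actual implementation.

```python
import torch
import torch.nn as nn


class MiniExtractor(nn.Module):
    """Hypothetical, simplified stand-in for tx.Extractor (not the library's code)."""

    def __init__(self, model: nn.Module, module_names):
        super().__init__()
        self.model = model
        self.feature_maps = {}
        modules = dict(model.named_modules())
        for name in module_names:
            # Each hook stores the latest output of its module under the module's name
            modules[name].register_forward_hook(self._make_hook(name))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            self.feature_maps[name] = output
        return hook

    def forward(self, x):
        self.feature_maps.clear()  # drop features captured by the previous call
        output = self.model(x)
        return output, dict(self.feature_maps)  # shallow copy for the caller


# A tiny model with named submodules, so the sketch runs without torchvision
backbone = nn.Sequential()
backbone.add_module("stem", nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1))
backbone.add_module("act", nn.ReLU())
backbone.add_module("pool", nn.AdaptiveAvgPool2d(1))

extractor = MiniExtractor(backbone, ["stem", "act"])
output, features = extractor(torch.rand(1, 3, 32, 32))
print({name: tuple(f.shape) for name, f in features.items()})
# {'stem': (1, 8, 16, 16), 'act': (1, 8, 16, 16)}
```

Returning a copy of the dictionary means the features from one call survive the `clear()` at the start of the next call.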

[See more examples](docs/source/examples.ipynb)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/antoinebrl/torchextractor/HEAD?filepath=docs/source/examples.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/antoinebrl/torchextractor/blob/master/docs/source/examples.ipynb)

[Read the documentation](https://torchextractor.readthedocs.io/en/latest/)

## FAQ

**• How do I know the names of the modules?**

You can print all module names like this:
```python
print(tx.list_module_names(model))

# OR

for name, module in model.named_modules():
    print(name)
```
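Module names are the dot-separated paths produced by `named_modules()`: the root module is the empty string, and children of an `nn.Sequential` are named by their index. In ResNet-18, for instance, the first convolution of the first block is `layer1.0.conv1`. A small hypothetical model shows the pattern:

```python
import torch.nn as nn

# Hypothetical nested model: "layer1" is itself an nn.Sequential with indexed children
model = nn.Sequential()
model.add_module("layer1", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
model.add_module("head", nn.Linear(8, 2))

names = [name for name, _ in model.named_modules()]
print(names)
# ['', 'layer1', 'layer1.0', 'layer1.1', 'head']
```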

**• Why do some operations not get listed?**

Hooks can only be attached to modules, so operations that are not defined as modules cannot be captured.
For example, `F.relu` cannot be captured, but `nn.ReLU()` can.
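The difference is easy to observe with forward hooks in plain PyTorch. The toy `Block` below is made up for illustration: its `nn.ReLU` module shows up in `named_modules()` and fires a hook, while the functional `F.relu` call leaves no trace.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=1)
        self.act = nn.ReLU()  # a module: listed by named_modules() and hookable

    def forward(self, x):
        x = self.act(self.conv(x))
        return F.relu(x - 0.5)  # a plain function call: no module, so no hook point


block = Block()
print([name for name, _ in block.named_modules()])
# ['', 'conv', 'act']

captured = []
for name, module in block.named_modules():
    if name:  # skip the root module
        module.register_forward_hook(lambda m, i, o, name=name: captured.append(name))

block(torch.rand(1, 3, 8, 8))
print(captured)
# ['conv', 'act']
```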

**• How can I avoid listing all relevant modules?**

You can specify a custom filtering function to hook the relevant modules:
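```python
import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)

# Pick one of the filters below (each assignment replaces the previous one)

# Hook everything!
module_filter_fn = lambda module, name: True

# Capture all modules inside the first layer
module_filter_fn = lambda module, name: name.startswith("layer1")

# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)

model = tx.Extractor(model, module_filter_fn=module_filter_fn)
```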
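To check what a filter selects before wrapping a model, you can evaluate it yourself over `named_modules()`. The `preview` helper and the toy model below are hypothetical, and the sketch assumes the filter is called with each `(module, name)` pair that `named_modules()` yields:

```python
import torch
import torch.nn as nn


def preview(model: nn.Module, module_filter_fn):
    """Hypothetical helper: names of the modules a filter would select,
    assuming the filter is evaluated on every pair from named_modules()."""
    return [name for name, module in model.named_modules() if module_filter_fn(module, name)]


# Toy model whose top-level names mimic the ResNet layout
model = nn.Sequential()
model.add_module("layer1", nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
model.add_module("layer2", nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU()))

print(preview(model, lambda module, name: name.startswith("layer1")))
# ['layer1', 'layer1.0', 'layer1.1']
print(preview(model, lambda module, name: isinstance(module, torch.nn.Conv2d)))
# ['layer1.0', 'layer2.0']
```

Note that `name.startswith("layer1")` would also match a module called `layer10`; `name == "layer1" or name.startswith("layer1.")` is stricter. Under the same assumption, a filter that always returns `True` also selects the root module (`''`).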

**• Is it compatible with ONNX?**

`tx.Extractor` is compatible with ONNX! This means you can also access intermediate feature maps after the export.

Pro-tip: name the output nodes by using `output_names` when calling `torch.onnx.export`.

**• Is it compatible with TorchScript?**

Bad news: TorchScript cannot compile functions that take a variable number of arguments or keyword-only arguments with defaults.

Good news: there is a workaround! Override the `forward` method of `tx.Extractor`
so that it replicates the interface of the wrapped model.

```python
import torch
import torchvision
import torchextractor as tx

class MyExtractor(tx.Extractor):
    # Replicate the interface of the wrapped model: resnet18 takes a single tensor.
    # For a model taking x1, x2 and x3, declare `forward(self, x1, x2, x3)`
    # and call `self.model(x1, x2, x3)` instead.
    def forward(self, x: torch.Tensor):
        output = self.model(x)
        return output, self.feature_maps

model = torchvision.models.resnet18(pretrained=True)
model = MyExtractor(model, ["layer1", "layer2", "layer3", "layer4"])
model_scripted = torch.jit.script(model)
```
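The limitation itself can be reproduced in plain PyTorch. In this sketch, the hypothetical `VarArgsWrapper` declares a generic `forward(self, *args)` and fails to compile, while the hypothetical `ExplicitWrapper`, whose signature mirrors the wrapped `nn.Linear`, scripts fine:

```python
import torch
import torch.nn as nn


class VarArgsWrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, *args):  # generic signature: rejected by TorchScript
        return self.model(*args)


class ExplicitWrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # mirrors nn.Linear's interface
        return self.model(x)


try:
    torch.jit.script(VarArgsWrapper(nn.Linear(4, 2)))
    varargs_scripted = True
except Exception:  # TorchScript refuses to compile the *args signature
    varargs_scripted = False

scripted = torch.jit.script(ExplicitWrapper(nn.Linear(4, 2)))
print(varargs_scripted, scripted(torch.rand(3, 4)).shape)
# False torch.Size([3, 2])
```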

**• "One more thing!" :wink:**

By default, `tx.Extractor` stores the latest output of each relevant module,
but you can provide your own capture function instead.

For example, to accumulate features over 10 forward passes you
can do the following:
```python
import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)

def capture_fn(module, input, output, module_name, feature_maps):
    # Append each output instead of overwriting the previous one
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(output)

extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)

for _ in range(20):
    for _ in range(10):
        x = torch.rand(7, 3, 224, 224)
        # The hooks are registered on the wrapped model's modules,
        # so calling the model directly still captures features
        model(x)
    feature_maps = extractor.collect()  # 10 outputs per module name

    # Do your stuff here

    # Discard collected elements
    extractor.clear_placeholder()
```
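Accumulating raw outputs keeps every autograd graph and device buffer alive until the placeholder is cleared. A variant worth considering, sketched below as an assumption rather than a library feature, is a hypothetical `detached_capture_fn` with the same signature that stores detached CPU copies. To check it without `torchextractor`, the sketch wires it to a single `nn.Linear` by hand with `functools.partial`; the library's own wiring may differ.

```python
import functools

import torch
import torch.nn as nn


def detached_capture_fn(module, input, output, module_name, feature_maps):
    # Same signature as capture_fn above; detach() drops the autograd graph
    # and cpu() moves the copy off the device, so many passes stay cheap to keep.
    feature_maps.setdefault(module_name, []).append(output.detach().cpu())


# Hand-wired hook on a single layer, standing in for the extractor's wiring
feature_maps = {}
layer = nn.Linear(4, 2)
layer.register_forward_hook(
    functools.partial(detached_capture_fn, module_name="fc", feature_maps=feature_maps)
)

for _ in range(3):
    layer(torch.rand(5, 4))

stacked = torch.cat(feature_maps["fc"])  # concatenate along the batch dimension
print(stacked.shape, stacked.requires_grad)
# torch.Size([15, 2]) False
```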

## Contributing

All feedback and contributions are welcome. Feel free to report an issue or to create a pull request!

If you want to get hands-on:
1. (Fork and) clone the repo.
2. Create a virtual environment: `virtualenv -p python3 .venv && source .venv/bin/activate`
3. Install dependencies: `pip install -r requirements.txt && pip install -r requirements-dev.txt`
4. Hook auto-formatting tools: `pre-commit install`
5. Hack as much as you want!
6. Run tests: `python -m unittest discover -vs ./tests/`
7. Share your work and create a pull request.

To build the documentation:
```shell
cd docs
pip install -r requirements.txt
make html
```


            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/antoinebrl/torchextractor",
    "name": "torchextractor",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.6",
    "maintainer_email": "",
    "keywords": "pytorch torch feature extraction",
    "author": "Antoine Broyelle",
    "author_email": "antoine.broyelle@pm.me",
    "download_url": "https://files.pythonhosted.org/packages/6b/07/9b4811b9571a35a021beae83d8abee2e669ad37056584cf24408de7c3ea0/torchextractor-0.3.0.tar.gz",
    "platform": "",
    "bugtrack_url": null,
    "license": "",
    "summary": "Pytorch feature extraction made simple",
    "version": "0.3.0",
    "project_urls": {
        "Bug Tracker": "https://github.com/antoinebrl/torchextractor/issues",
        "Homepage": "https://github.com/antoinebrl/torchextractor"
    },
    "split_keywords": [
        "pytorch",
        "torch",
        "feature",
        "extraction"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "cc94f14591882d0459a626d6aa8ed3699b08e6b79192c26cae87cbd6081cb835",
                "md5": "f036ed73387b252fe70c8daaba81410f",
                "sha256": "1bfd90eea59f69e375240326304d0091f77a0e536b997d3c64aba564890d4fa1"
            },
            "downloads": -1,
            "filename": "torchextractor-0.3.0-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "f036ed73387b252fe70c8daaba81410f",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.6",
            "size": 10823,
            "upload_time": "2021-03-07T21:26:57",
            "upload_time_iso_8601": "2021-03-07T21:26:57.444613Z",
            "url": "https://files.pythonhosted.org/packages/cc/94/f14591882d0459a626d6aa8ed3699b08e6b79192c26cae87cbd6081cb835/torchextractor-0.3.0-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "6b079b4811b9571a35a021beae83d8abee2e669ad37056584cf24408de7c3ea0",
                "md5": "35b55d4bf448822c1139b447863cac53",
                "sha256": "fd1bbc1f32c7db25aaee7e3c0fff7abbff48f22bf43acae95bb3e55efd0282f3"
            },
            "downloads": -1,
            "filename": "torchextractor-0.3.0.tar.gz",
            "has_sig": false,
            "md5_digest": "35b55d4bf448822c1139b447863cac53",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.6",
            "size": 6250,
            "upload_time": "2021-03-07T21:26:58",
            "upload_time_iso_8601": "2021-03-07T21:26:58.665765Z",
            "url": "https://files.pythonhosted.org/packages/6b/07/9b4811b9571a35a021beae83d8abee2e669ad37056584cf24408de7c3ea0/torchextractor-0.3.0.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2021-03-07 21:26:58",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "antoinebrl",
    "github_project": "torchextractor",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [],
    "lcname": "torchextractor"
}
        