torchtyping

Name: torchtyping
Version: 0.1.5
Homepage: https://github.com/patrick-kidger/torchtyping
Summary: Runtime type annotations for the shape, dtype etc. of PyTorch Tensors.
Upload time: 2024-08-01 02:44:53
Author/maintainer: Patrick Kidger
Requires Python: >=3.7.0
License: Apache-2.0
            <h1 align='center'>torchtyping</h1>
<h2 align='center'>Type annotations for a tensor's shape, dtype, names, ...</h2>

*Welcome! For new projects I now **strongly** recommend using my newer [jaxtyping](https://github.com/google/jaxtyping) project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. :)*

---

Turn this:
```python
import torch

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
into this:
```python
from torchtyping import TensorType

def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
**with programmatic checking that the shape (dtype, ...) specification is met.**

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, **then this is for you.**

---

## Installation

```bash
pip install torchtyping
```

Requires Python >=3.7 and PyTorch >=1.7.0.

If using [`typeguard`](https://github.com/agronholm/typeguard) then it must be a version <3.0.0.
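
For example, to install both together with a compatible pin (standard `pip` version-specifier syntax):

```bash
pip install torchtyping "typeguard<3.0.0"
```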

## Usage

`torchtyping` allows for type annotating:

- **shape**: size, number of dimensions;
- **dtype** (float, integer, etc.);
- **layout** (dense, sparse);
- **names** of dimensions as per [named tensors](https://pytorch.org/docs/stable/named_tensor.html);
- **arbitrary number of batch dimensions** with `...`;
- **...plus anything else you like**, as `torchtyping` is highly extensible.

If [`typeguard`](https://github.com/agronholm/typeguard) is installed (it's optional), then **the types can be checked at runtime** to ensure that the tensors really are of the advertised shape, dtype, etc.

```python
# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
```

`typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.
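
A minimal sketch of that workflow (for `typeguard` 2.x; `your_package` is a hypothetical package name):

```python
# Sketch only: assumes typeguard < 3.0.0 and a package named "your_package".
from typeguard.importhook import install_import_hook
from torchtyping import patch_typeguard

patch_typeguard()                    # order relative to the hook doesn't matter
install_import_hook("your_package")  # hypothetical: check everything in this package
import your_package                  # its functions are now checked when called
```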

If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only.

## API

```python
torchtyping.TensorType[shape, dtype, layout, details]
```

The core of the library.

Each of `shape`, `dtype`, `layout`, and `details` is optional.

- The `shape` argument can be any of:
  - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
  - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors using this name are checked for consistent sizes.
  - A `...`: an arbitrary number of dimensions, each of any size.
  - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
  - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to _both_ names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
  - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments.
  - `None`, which when used in conjunction with `is_named` below, indicates a dimension that must _not_ have a name in the sense of [named tensors](https://pytorch.org/docs/stable/named_tensor.html).
  - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
  - A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
  - A `typing.Any`: any size is allowed for this dimension (equivalent to `-1`).
  - Any tuple of the above. For example, `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions then use, for example, `TensorType[-1, -1, -1]` for a three-dimensional tensor.
- The `dtype` argument can be any of:
  - `torch.float32`, `torch.float64` etc.
  - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
- The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
- The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built-in by default. `torchtyping.is_named` causes the [names of tensor dimensions](https://pytorch.org/docs/stable/named_tensor.html) to be checked, and `torchtyping.is_float` can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. `TensorType[torch.float32]`.) For discussion on how to customise `torchtyping` with your own `details`, see the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md#custom-extensions).
- Check multiple things at once by just putting them all together inside a single `[]`. For example `TensorType["batch": ..., "length", "channels", float, is_named]`. (A combined sketch follows this list.)
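
As an illustrative sketch of the specifications above (the dimension names and sizes are invented for demonstration):

```python
import torch
from torchtyping import TensorType, is_named

# An int fixes a size, a str binds a name, and -1 allows any size.
x: TensorType["batch", "channels": 3]  # 2-d: any batch size, exactly 3 channels

# `...` stands for an arbitrary number of leading dimensions; `float` means the
# default dtype (usually float32).
y: TensorType["batch": ..., "length": 10, float]

# With is_named, named-tensor dimension names are checked; None means unnamed.
z: TensorType[None, "width", torch.float32, is_named]
```

Annotations like these go on function parameters and return types exactly as in the earlier example.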

```python
torchtyping.patch_typeguard()
```

`torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s.

This function is safe to run multiple times. (It does nothing after the first run.)

- If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example you could call it at the start of each file using `torchtyping`.
- If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions you want checked. For example you could call `torchtyping.patch_typeguard()` just once, at the same time as the `typeguard` import hook. (The order of the hook and the patch doesn't matter.)
- If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes.

```bash
pytest --torchtyping-patch-typeguard
```

`torchtyping` offers a `pytest` plugin to automatically run `torchtyping.patch_typeguard()` before your tests. `pytest` will automatically discover the plugin; you just need to pass the `--torchtyping-patch-typeguard` flag to enable it. Packages can then be passed to `typeguard` as normal: via `@typeguard.typechecked`, `typeguard`'s import hook, or the `pytest` flag `--typeguard-packages="your_package_here"`.
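
For instance, to enable the patch and have `typeguard` check a package during the test run (assuming the two flags compose as described; `your_package_here` is a placeholder):

```bash
pytest --torchtyping-patch-typeguard --typeguard-packages="your_package_here"
```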

## Further documentation

See the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md) for:

- FAQ;
  - Including `flake8` and `mypy` compatibility;
- How to write custom extensions to `torchtyping`;
- Resources and links to other libraries and materials on this topic;
- More examples.
