# RSProduction MachineLearning
This project provides some useful machine learning functionality.
# Table of Contents
- [1 dataset](#1-dataset)
- [1.1 HMDB51 : torch.utils.data.dataset.Dataset](#11-hmdb51--torchutilsdatadatasetdataset)
- [1.1.1 \_\_init\_\_](#111-\_\_init\_\_)
- [1.2 Kinetics : torch.utils.data.dataset.Dataset](#12-kinetics--torchutilsdatadatasetdataset)
- [1.2.1 \_\_init\_\_](#121-\_\_init\_\_)
- [1.3 TUCHRI : torch.utils.data.dataset.Dataset](#13-tuchri--torchutilsdatadatasetdataset)
- [1.3.1 \_\_init\_\_](#131-\_\_init\_\_)
- [1.3.2 get\_uniform\_sampler](#132-get\_uniform\_sampler)
- [1.3.3 load\_backgrounds](#133-load\_backgrounds)
- [1.4 TUCRID : torch.utils.data.dataset.Dataset](#14-tucrid--torchutilsdatadatasetdataset)
- [1.4.1 \_\_init\_\_](#141-\_\_init\_\_)
- [1.4.2 get\_uniform\_sampler](#142-get\_uniform\_sampler)
- [1.4.3 load\_backgrounds](#143-load\_backgrounds)
- [1.5 UCF101 : torch.utils.data.dataset.Dataset](#15-ucf101--torchutilsdatadatasetdataset)
- [1.5.1 \_\_init\_\_](#151-\_\_init\_\_)
- [1.6 UTKinectAction3D : torch.utils.data.dataset.Dataset](#16-utkinectaction3d--torchutilsdatadatasetdataset)
- [1.6.1 \_\_init\_\_](#161-\_\_init\_\_)
- [2 metrics](#2-metrics)
- [2.1 AUROC](#21-auroc)
- [2.2 F1\_Score](#22-f1\_score)
- [2.3 FN](#23-fn)
- [2.4 FP](#24-fp)
- [2.5 FPR](#25-fpr)
- [2.6 ROC](#26-roc)
- [2.7 TN](#27-tn)
- [2.8 TP](#28-tp)
- [2.9 TPR](#29-tpr)
- [2.10 confusion\_matrix](#210-confusion\_matrix)
- [2.11 plot\_ROC](#211-plot\_roc)
- [2.12 plot\_confusion\_matrix](#212-plot\_confusion\_matrix)
- [2.13 precision](#213-precision)
- [2.14 recall](#214-recall)
- [2.15 top\_10\_accuracy](#215-top\_10\_accuracy)
- [2.16 top\_1\_accuracy](#216-top\_1\_accuracy)
- [2.17 top\_2\_accuracy](#217-top\_2\_accuracy)
- [2.18 top\_3\_accuracy](#218-top\_3\_accuracy)
- [2.19 top\_5\_accuracy](#219-top\_5\_accuracy)
- [2.20 top\_k\_accuracy](#220-top\_k\_accuracy)
- [3 model](#3-model)
- [3.1 MODELS : enum.Enum](#31-models--enumenum)
- [3.2 WEIGHTS : enum.Enum](#32-weights--enumenum)
- [3.3 list\_model\_weights](#33-list\_model\_weights)
- [3.4 load\_model](#34-load\_model)
- [3.5 publish\_model](#35-publish\_model)
- [4 module](#4-module)
- [4.1 MultiHeadSelfAttention : torch.nn.modules.module.Module](#41-multiheadselfattention--torchnnmodulesmodulemodule)
- [4.1.1 \_wrapped\_call\_impl](#411-\_wrapped\_call\_impl)
- [4.1.2 \_\_init\_\_](#412-\_\_init\_\_)
- [4.1.3 \_apply](#413-\_apply)
- [4.1.4 \_call\_impl](#414-\_call\_impl)
- [4.1.5 \_get\_backward\_hooks](#415-\_get\_backward\_hooks)
- [4.1.6 \_get\_backward\_pre\_hooks](#416-\_get\_backward\_pre\_hooks)
- [4.1.7 \_get\_name](#417-\_get\_name)
- [4.1.8 \_load\_from\_state\_dict](#418-\_load\_from\_state\_dict)
- [4.1.9 \_maybe\_warn\_non\_full\_backward\_hook](#419-\_maybe\_warn\_non\_full\_backward\_hook)
- [4.1.10 \_named\_members](#4110-\_named\_members)
- [4.1.11 \_register\_load\_state\_dict\_pre\_hook](#4111-\_register\_load\_state\_dict\_pre\_hook)
- [4.1.12 \_register\_state\_dict\_hook](#4112-\_register\_state\_dict\_hook)
- [4.1.13 \_replicate\_for\_data\_parallel](#4113-\_replicate\_for\_data\_parallel)
- [4.1.14 \_save\_to\_state\_dict](#4114-\_save\_to\_state\_dict)
- [4.1.15 \_slow\_forward](#4115-\_slow\_forward)
- [4.1.16 \_wrapped\_call\_impl](#4116-\_wrapped\_call\_impl)
- [4.1.17 add\_module](#4117-add\_module)
- [4.1.18 apply](#4118-apply)
- [4.1.19 bfloat16](#4119-bfloat16)
- [4.1.20 buffers](#4120-buffers)
- [4.1.21 children](#4121-children)
- [4.1.22 compile](#4122-compile)
- [4.1.23 cpu](#4123-cpu)
- [4.1.24 cuda](#4124-cuda)
- [4.1.25 double](#4125-double)
- [4.1.26 eval](#4126-eval)
- [4.1.27 extra\_repr](#4127-extra\_repr)
- [4.1.28 float](#4128-float)
- [4.1.29 forward](#4129-forward)
- [4.1.30 get\_buffer](#4130-get\_buffer)
- [4.1.31 get\_extra\_state](#4131-get\_extra\_state)
- [4.1.32 get\_parameter](#4132-get\_parameter)
- [4.1.33 get\_submodule](#4133-get\_submodule)
- [4.1.34 half](#4134-half)
- [4.1.35 ipu](#4135-ipu)
- [4.1.36 load\_state\_dict](#4136-load\_state\_dict)
- [4.1.37 modules](#4137-modules)
- [4.1.38 mtia](#4138-mtia)
- [4.1.39 named\_buffers](#4139-named\_buffers)
- [4.1.40 named\_children](#4140-named\_children)
- [4.1.41 named\_modules](#4141-named\_modules)
- [4.1.42 named\_parameters](#4142-named\_parameters)
- [4.1.43 parameters](#4143-parameters)
- [4.1.44 register\_backward\_hook](#4144-register\_backward\_hook)
- [4.1.45 register\_buffer](#4145-register\_buffer)
- [4.1.46 register\_forward\_hook](#4146-register\_forward\_hook)
- [4.1.47 register\_forward\_pre\_hook](#4147-register\_forward\_pre\_hook)
- [4.1.48 register\_full\_backward\_hook](#4148-register\_full\_backward\_hook)
- [4.1.49 register\_full\_backward\_pre\_hook](#4149-register\_full\_backward\_pre\_hook)
- [4.1.50 register\_load\_state\_dict\_post\_hook](#4150-register\_load\_state\_dict\_post\_hook)
- [4.1.51 register\_load\_state\_dict\_pre\_hook](#4151-register\_load\_state\_dict\_pre\_hook)
- [4.1.52 register\_module](#4152-register\_module)
- [4.1.53 register\_parameter](#4153-register\_parameter)
- [4.1.54 register\_state\_dict\_post\_hook](#4154-register\_state\_dict\_post\_hook)
- [4.1.55 register\_state\_dict\_pre\_hook](#4155-register\_state\_dict\_pre\_hook)
- [4.1.56 requires\_grad\_](#4156-requires\_grad\_)
- [4.1.57 set\_extra\_state](#4157-set\_extra\_state)
- [4.1.58 set\_submodule](#4158-set\_submodule)
- [4.1.59 share\_memory](#4159-share\_memory)
- [4.1.60 state\_dict](#4160-state\_dict)
- [4.1.61 to](#4161-to)
- [4.1.62 to\_empty](#4162-to\_empty)
- [4.1.63 train](#4163-train)
- [4.1.64 type](#4164-type)
- [4.1.65 xpu](#4165-xpu)
- [4.1.66 zero\_grad](#4166-zero\_grad)
- [4.2 SelfAttention : torch.nn.modules.module.Module](#42-selfattention--torchnnmodulesmodulemodule)
- [4.2.1 \_wrapped\_call\_impl](#421-\_wrapped\_call\_impl)
- [4.2.2 \_\_init\_\_](#422-\_\_init\_\_)
- [4.2.3 \_apply](#423-\_apply)
- [4.2.4 \_call\_impl](#424-\_call\_impl)
- [4.2.5 \_get\_backward\_hooks](#425-\_get\_backward\_hooks)
- [4.2.6 \_get\_backward\_pre\_hooks](#426-\_get\_backward\_pre\_hooks)
- [4.2.7 \_get\_name](#427-\_get\_name)
- [4.2.8 \_load\_from\_state\_dict](#428-\_load\_from\_state\_dict)
- [4.2.9 \_maybe\_warn\_non\_full\_backward\_hook](#429-\_maybe\_warn\_non\_full\_backward\_hook)
- [4.2.10 \_named\_members](#4210-\_named\_members)
- [4.2.11 \_register\_load\_state\_dict\_pre\_hook](#4211-\_register\_load\_state\_dict\_pre\_hook)
- [4.2.12 \_register\_state\_dict\_hook](#4212-\_register\_state\_dict\_hook)
- [4.2.13 \_replicate\_for\_data\_parallel](#4213-\_replicate\_for\_data\_parallel)
- [4.2.14 \_save\_to\_state\_dict](#4214-\_save\_to\_state\_dict)
- [4.2.15 \_slow\_forward](#4215-\_slow\_forward)
- [4.2.16 \_wrapped\_call\_impl](#4216-\_wrapped\_call\_impl)
- [4.2.17 add\_module](#4217-add\_module)
- [4.2.18 apply](#4218-apply)
- [4.2.19 bfloat16](#4219-bfloat16)
- [4.2.20 buffers](#4220-buffers)
- [4.2.21 children](#4221-children)
- [4.2.22 compile](#4222-compile)
- [4.2.23 cpu](#4223-cpu)
- [4.2.24 cuda](#4224-cuda)
- [4.2.25 double](#4225-double)
- [4.2.26 eval](#4226-eval)
- [4.2.27 extra\_repr](#4227-extra\_repr)
- [4.2.28 float](#4228-float)
- [4.2.29 forward](#4229-forward)
- [4.2.30 get\_buffer](#4230-get\_buffer)
- [4.2.31 get\_extra\_state](#4231-get\_extra\_state)
- [4.2.32 get\_parameter](#4232-get\_parameter)
- [4.2.33 get\_submodule](#4233-get\_submodule)
- [4.2.34 half](#4234-half)
- [4.2.35 ipu](#4235-ipu)
- [4.2.36 load\_state\_dict](#4236-load\_state\_dict)
- [4.2.37 modules](#4237-modules)
- [4.2.38 mtia](#4238-mtia)
- [4.2.39 named\_buffers](#4239-named\_buffers)
- [4.2.40 named\_children](#4240-named\_children)
- [4.2.41 named\_modules](#4241-named\_modules)
- [4.2.42 named\_parameters](#4242-named\_parameters)
- [4.2.43 parameters](#4243-parameters)
- [4.2.44 register\_backward\_hook](#4244-register\_backward\_hook)
- [4.2.45 register\_buffer](#4245-register\_buffer)
- [4.2.46 register\_forward\_hook](#4246-register\_forward\_hook)
- [4.2.47 register\_forward\_pre\_hook](#4247-register\_forward\_pre\_hook)
- [4.2.48 register\_full\_backward\_hook](#4248-register\_full\_backward\_hook)
- [4.2.49 register\_full\_backward\_pre\_hook](#4249-register\_full\_backward\_pre\_hook)
- [4.2.50 register\_load\_state\_dict\_post\_hook](#4250-register\_load\_state\_dict\_post\_hook)
- [4.2.51 register\_load\_state\_dict\_pre\_hook](#4251-register\_load\_state\_dict\_pre\_hook)
- [4.2.52 register\_module](#4252-register\_module)
- [4.2.53 register\_parameter](#4253-register\_parameter)
- [4.2.54 register\_state\_dict\_post\_hook](#4254-register\_state\_dict\_post\_hook)
- [4.2.55 register\_state\_dict\_pre\_hook](#4255-register\_state\_dict\_pre\_hook)
- [4.2.56 requires\_grad\_](#4256-requires\_grad\_)
- [4.2.57 set\_extra\_state](#4257-set\_extra\_state)
- [4.2.58 set\_submodule](#4258-set\_submodule)
- [4.2.59 share\_memory](#4259-share\_memory)
- [4.2.60 state\_dict](#4260-state\_dict)
- [4.2.61 to](#4261-to)
- [4.2.62 to\_empty](#4262-to\_empty)
- [4.2.63 train](#4263-train)
- [4.2.64 type](#4264-type)
- [4.2.65 xpu](#4265-xpu)
- [4.2.66 zero\_grad](#4266-zero\_grad)
- [5 multi\_transforms](#5-multi\_transforms)
- [5.1 BGR2GRAY : MultiTransform](#51-bgr2gray--multitransform)
- [5.1.1 \_\_call\_\_](#511-\_\_call\_\_)
- [5.1.2 \_\_init\_\_](#512-\_\_init\_\_)
- [5.2 BGR2RGB : MultiTransform](#52-bgr2rgb--multitransform)
- [5.2.1 \_\_call\_\_](#521-\_\_call\_\_)
- [5.2.2 \_\_init\_\_](#522-\_\_init\_\_)
- [5.3 Brightness : MultiTransform](#53-brightness--multitransform)
- [5.3.1 \_\_call\_\_](#531-\_\_call\_\_)
- [5.3.2 \_\_init\_\_](#532-\_\_init\_\_)
- [5.4 CenterCrop : MultiTransform](#54-centercrop--multitransform)
- [5.4.1 \_\_call\_\_](#541-\_\_call\_\_)
- [5.4.2 \_\_init\_\_](#542-\_\_init\_\_)
- [5.5 Color : MultiTransform](#55-color--multitransform)
- [5.5.1 \_\_call\_\_](#551-\_\_call\_\_)
- [5.5.2 \_\_init\_\_](#552-\_\_init\_\_)
- [5.6 Compose : builtins.object](#56-compose--builtinsobject)
- [5.6.1 \_\_call\_\_](#561-\_\_call\_\_)
- [5.6.2 \_\_init\_\_](#562-\_\_init\_\_)
- [5.7 GaussianNoise : MultiTransform](#57-gaussiannoise--multitransform)
- [5.7.1 \_\_call\_\_](#571-\_\_call\_\_)
- [5.7.2 \_\_init\_\_](#572-\_\_init\_\_)
- [5.8 MultiTransform : builtins.object](#58-multitransform--builtinsobject)
- [5.8.1 \_\_call\_\_](#581-\_\_call\_\_)
- [5.8.2 \_\_init\_\_](#582-\_\_init\_\_)
- [5.9 Normalize : MultiTransform](#59-normalize--multitransform)
- [5.9.1 \_\_call\_\_](#591-\_\_call\_\_)
- [5.9.2 \_\_init\_\_](#592-\_\_init\_\_)
- [5.10 RGB2BGR : BGR2RGB](#510-rgb2bgr--bgr2rgb)
- [5.10.1 \_\_call\_\_](#5101-\_\_call\_\_)
- [5.10.2 \_\_init\_\_](#5102-\_\_init\_\_)
- [5.11 RandomCrop : MultiTransform](#511-randomcrop--multitransform)
- [5.11.1 \_\_call\_\_](#5111-\_\_call\_\_)
- [5.11.2 \_\_init\_\_](#5112-\_\_init\_\_)
- [5.12 RandomHorizontalFlip : MultiTransform](#512-randomhorizontalflip--multitransform)
- [5.12.1 \_\_call\_\_](#5121-\_\_call\_\_)
- [5.12.2 \_\_init\_\_](#5122-\_\_init\_\_)
- [5.13 RandomVerticalFlip : MultiTransform](#513-randomverticalflip--multitransform)
- [5.13.1 \_\_call\_\_](#5131-\_\_call\_\_)
- [5.13.2 \_\_init\_\_](#5132-\_\_init\_\_)
- [5.14 RemoveBackgroundAI : MultiTransform](#514-removebackgroundai--multitransform)
- [5.14.1 \_\_call\_\_](#5141-\_\_call\_\_)
- [5.14.2 \_\_init\_\_](#5142-\_\_init\_\_)
- [5.15 ReplaceBackground : MultiTransform](#515-replacebackground--multitransform)
- [5.15.1 \_\_call\_\_](#5151-\_\_call\_\_)
- [5.15.2 \_\_init\_\_](#5152-\_\_init\_\_)
- [5.16 Resize : MultiTransform](#516-resize--multitransform)
- [5.16.1 \_\_call\_\_](#5161-\_\_call\_\_)
- [5.16.2 \_\_init\_\_](#5162-\_\_init\_\_)
- [5.17 Rotate : MultiTransform](#517-rotate--multitransform)
- [5.17.1 \_\_call\_\_](#5171-\_\_call\_\_)
- [5.17.2 \_\_init\_\_](#5172-\_\_init\_\_)
- [5.18 Satturation : MultiTransform](#518-satturation--multitransform)
- [5.18.1 \_\_call\_\_](#5181-\_\_call\_\_)
- [5.18.2 \_\_init\_\_](#5182-\_\_init\_\_)
- [5.19 Scale : MultiTransform](#519-scale--multitransform)
- [5.19.1 \_\_call\_\_](#5191-\_\_call\_\_)
- [5.19.2 \_\_init\_\_](#5192-\_\_init\_\_)
- [5.20 Stack : MultiTransform](#520-stack--multitransform)
- [5.20.1 \_\_call\_\_](#5201-\_\_call\_\_)
- [5.20.2 \_\_init\_\_](#5202-\_\_init\_\_)
- [5.21 ToCVImage : MultiTransform](#521-tocvimage--multitransform)
- [5.21.1 \_\_call\_\_](#5211-\_\_call\_\_)
- [5.21.2 \_\_init\_\_](#5212-\_\_init\_\_)
- [5.22 ToNumpy : MultiTransform](#522-tonumpy--multitransform)
- [5.22.1 \_\_call\_\_](#5221-\_\_call\_\_)
- [5.22.2 \_\_init\_\_](#5222-\_\_init\_\_)
- [5.23 ToPILImage : MultiTransform](#523-topilimage--multitransform)
- [5.23.1 \_\_call\_\_](#5231-\_\_call\_\_)
- [5.23.2 \_\_init\_\_](#5232-\_\_init\_\_)
- [5.24 ToTensor : MultiTransform](#524-totensor--multitransform)
- [5.24.1 \_\_call\_\_](#5241-\_\_call\_\_)
- [5.24.2 \_\_init\_\_](#5242-\_\_init\_\_)
- [6 run](#6-run)
- [6.1 Run : builtins.object](#61-run--builtinsobject)
- [6.1.1 \_\_init\_\_](#611-\_\_init\_\_)
- [6.1.2 append](#612-append)
- [6.1.3 get\_avg](#613-get\_avg)
- [6.1.4 get\_val](#614-get\_val)
- [6.1.5 len](#615-len)
- [6.1.6 load\_best\_state\_dict](#616-load\_best\_state\_dict)
- [6.1.7 load\_state\_dict](#617-load\_state\_dict)
- [6.1.8 pickle\_dump](#618-pickle\_dump)
- [6.1.9 pickle\_load](#619-pickle\_load)
- [6.1.10 plot](#6110-plot)
- [6.1.11 recalculate\_moving\_average](#6111-recalculate\_moving\_average)
- [6.1.12 save](#6112-save)
- [6.1.13 save\_best\_state\_dict](#6113-save\_best\_state\_dict)
- [6.1.14 save\_state\_dict](#6114-save\_state\_dict)
- [6.1.15 train\_epoch](#6115-train\_epoch)
- [6.1.16 validate\_epoch](#6116-validate\_epoch)
# 1 dataset
[TOC](#table-of-contents)
## 1.1 HMDB51 : torch.utils.data.dataset.Dataset
[TOC](#table-of-contents)
**Description**
Dataset class for HMDB51.
**Example**
```python
from rsp.ml.dataset import HMDB51
import rsp.ml.multi_transforms as multi_transforms
import cv2 as cv

transforms = multi_transforms.Compose([
    multi_transforms.Color(1.5, p=0.5),
    multi_transforms.Stack()
])

ds = HMDB51('train', fold=1, transforms=transforms)

for X, T in ds:
    # Move the channel axis last so that OpenCV can display each frame
    for x in X.permute(0, 2, 3, 1):
        # HMDB51 contains RGB videos only, so there is no depth channel to show
        img_color = x[:, :, :3].numpy()
        cv.imshow('color', img_color)
        cv.waitKey(30)
```
### 1.1.1 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| split | str | Dataset split [train\|val\|test] |
| fold | int | Fold number. The dataset is split into 3 folds. If `fold` is `None`, all folds will be loaded. |
| cache_dir | str, default = None | Directory to store the downloaded files. If set to `None`, the default cache directory will be used. |
| force_reload | bool, default = False | If set to `True`, the dataset will be reloaded. |
| target_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. |
| sequence_length | int, default = 30 | Length of the sequences |
| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |
| verbose | bool, default = False | If set to `True`, the progress will be printed. |
## 1.2 Kinetics : torch.utils.data.dataset.Dataset
[TOC](#table-of-contents)
**Description**
Dataset class for the Kinetics dataset.
**Example**
```python
from rsp.ml.dataset import Kinetics

ds = Kinetics(split='train', type=400)

for X, T in ds:
    print(X)
```
### 1.2.1 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| split | str | Dataset split [train\|val] |
| sequence_length | int, default = 60 | Length of the sequences |
| type | int, default = 400 | Type of the Kinetics dataset. Currently, only 400 is supported. |
| frame_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. |
| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |
| cache_dir | str, default = None | Directory to store the downloaded files. If set to `None`, the default cache directory will be used. |
| num_threads | int, default = 0 | Number of threads to use for downloading the files. |
| verbose | bool, default = True | If set to `True`, the progress and additional information will be printed. |
## 1.3 TUCHRI : torch.utils.data.dataset.Dataset
[TOC](#table-of-contents)
**Description**
Dataset class for the Robot Interaction Dataset by the University of Technology Chemnitz (TUCHRI).
### 1.3.1 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| phase | str | Dataset phase [train\|val] |
| load_depth_data | bool, default = True | Load depth data |
| sequence_length | int, default = 30 | Length of the sequences |
| num_classes | int, default = 10 | Number of classes |
| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |
### 1.3.2 get\_uniform\_sampler
[TOC](#table-of-contents)
### 1.3.3 load\_backgrounds
[TOC](#table-of-contents)
**Description**
Loads the background images.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| load_depth_data | bool, default = True | If set to `True`, the depth images will be loaded as well. |
## 1.4 TUCRID : torch.utils.data.dataset.Dataset
[TOC](#table-of-contents)
**Description**
Dataset class for the Robot Interaction Dataset by the University of Technology Chemnitz (TUCRID).
**Example**
```python
from rsp.ml.dataset import TUCRID
from rsp.ml.dataset import ReplaceBackgroundRGBD
import rsp.ml.multi_transforms as multi_transforms
import cv2 as cv
backgrounds = TUCRID.load_backgrounds(load_depth_data=True)

transforms = multi_transforms.Compose([
    ReplaceBackgroundRGBD(backgrounds),
    multi_transforms.Stack()
])

ds = TUCRID('train', transforms=transforms)

for X, T in ds:
    # Move the channel axis last so that OpenCV can display each frame
    for x in X.permute(0, 2, 3, 1):
        img_color = x[:, :, :3].numpy()  # RGB channels
        img_depth = x[:, :, 3].numpy()   # depth channel
        cv.imshow('color', img_color)
        cv.imshow('depth', img_depth)
        cv.waitKey(30)
```
### 1.4.1 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| phase | str | Dataset phase [train\|val] |
| load_depth_data | bool, default = True | Load depth data |
| sequence_length | int, default = 30 | Length of the sequences |
| num_classes | int, default = 10 | Number of classes |
| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |
### 1.4.2 get\_uniform\_sampler
[TOC](#table-of-contents)
### 1.4.3 load\_backgrounds
[TOC](#table-of-contents)
**Description**
Loads the background images.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| load_depth_data | bool, default = True | If set to `True`, the depth images will be loaded as well. |
## 1.5 UCF101 : torch.utils.data.dataset.Dataset
[TOC](#table-of-contents)
**Description**
Dataset class for UCF101.
**Example**
```python
from rsp.ml.dataset import UCF101
import rsp.ml.multi_transforms as multi_transforms
import cv2 as cv

transforms = multi_transforms.Compose([
    multi_transforms.Color(1.5, p=0.5),
    multi_transforms.Stack()
])

ds = UCF101('train', fold=1, transforms=transforms)

for X, T in ds:
    # Move the channel axis last so that OpenCV can display each frame
    for x in X.permute(0, 2, 3, 1):
        # UCF101 contains RGB videos only, so there is no depth channel to show
        img_color = x[:, :, :3].numpy()
        cv.imshow('color', img_color)
        cv.waitKey(30)
```
### 1.5.1 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| split | str | Dataset split [train\|val\|test] |
| fold | int | Fold number. The dataset is split into 3 folds. If `fold` is `None`, all folds will be loaded. |
| cache_dir | str, default = None | Directory to store the downloaded files. If set to `None`, the default cache directory will be used. |
| force_reload | bool, default = False | If set to `True`, the dataset will be reloaded. |
| target_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. |
| sequence_length | int, default = 30 | Length of the sequences |
| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |
| verbose | bool, default = False | If set to `True`, the progress will be printed. |
## 1.6 UTKinectAction3D : torch.utils.data.dataset.Dataset
[TOC](#table-of-contents)
**Description**
Dataset class for the UTKinectAction3D dataset.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| split | str | Dataset split [train\|val] |
| cache_dir | str, default = None | Directory to store the downloaded files. If set to `None`, the default cache directory will be used. |
| force_reload | bool, default = False | If set to `True`, the dataset will be reloaded. |
| target_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. |
| sequence_length | int, default = 30 | Length of the sequences |
| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |
| verbose | bool, default = False | If set to `True`, the progress will be printed. |
**Example**
```python
from rsp.ml.dataset import UTKinectAction3D
import rsp.ml.multi_transforms as multi_transforms
import cv2 as cv

transforms = multi_transforms.Compose([
    multi_transforms.Color(1.5, p=0.5),
    multi_transforms.Stack()
])

ds = UTKinectAction3D('train', transforms=transforms)

for X, T in ds:
    # Move the channel axis last so that OpenCV can display each frame
    for x in X.permute(0, 2, 3, 1):
        img_color = x[:, :, :3].numpy()  # RGB channels
        img_depth = x[:, :, 3].numpy()   # depth channel
        cv.imshow('color', img_color)
        cv.imshow('depth', img_depth)
        cv.waitKey(30)
```
### 1.6.1 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance. The parameters are listed in the class description.
# 2 metrics
[TOC](#table-of-contents)
The module `rsp.ml.metrics` provides some functionality to quantify the quality of predictions.
## 2.1 AUROC
[TOC](#table-of-contents)
**Description**
Calculates the area under the receiver operating characteristic curve (AUROC).
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| num_thresholds | int, default = 100 | Number of thresholds to compute. |
**Returns**
Area under the receiver operating characteristic curve : float
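A common way to obtain the AUROC is to integrate the ROC curve with the trapezoidal rule. The following is a minimal sketch under that assumption; it takes matching lists of false positive rates and true positive rates (the shape of what `ROC` returns) and is not necessarily how `rsp.ml.metrics.AUROC` computes the area internally. The function name `trapezoid_auc` is hypothetical.

```python
from typing import List


def trapezoid_auc(fprs: List[float], tprs: List[float]) -> float:
    """Area under a ROC curve given matching FPR/TPR lists (trapezoidal rule)."""
    # Sort the points by FPR (then TPR) so the integration runs from left to right
    points = sorted(zip(fprs, tprs))
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        # Area of the trapezoid between two neighbouring curve points
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


# A perfect classifier reaches TPR = 1 at FPR = 0
print(trapezoid_auc([0.0, 0.0, 1.0], [0.0, 1.0, 1.0]))  # 1.0
# Random guessing follows the diagonal
print(trapezoid_auc([0.0, 0.5, 1.0], [0.0, 0.5, 1.0]))  # 0.5
```

The more thresholds the curve is sampled at, the closer the trapezoidal sum gets to the true area.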
## 2.2 F1\_Score
[TOC](#table-of-contents)
**Description**
F1 Score. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
F1 Score : float
**Equations**
$precision = \frac{TP}{TP + FP}$
$recall = \frac{TP}{TP + FN}$
$F_1 = \frac{2 \cdot precision \cdot recall}{precision + recall} = \frac{2 \cdot TP}{2 \cdot TP + FP + FN}$
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
f1score = m.F1_Score(Y, T)
print(f1score)  # ≈ 0.833 by the equations above (TP = 5, FP = 1, FN = 1)
```
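The two forms of the F1 equation above are equivalent. The plain-Python sketch below checks this for the counts that the `TP`, `FP` and `FN` examples in this section report for the same `Y` and `T` (TP = 5, FP = 1, FN = 1). The helper names are hypothetical and not part of `rsp.ml.metrics`.

```python
import math


def f1_from_precision_recall(tp: int, fp: int, fn: int) -> float:
    """F1 as the harmonic mean of precision and recall."""
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def f1_closed_form(tp: int, fp: int, fn: int) -> float:
    """F1 = 2*TP / (2*TP + FP + FN)."""
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator > 0 else 0.0


# Counts for the example above
tp, fp, fn = 5, 1, 1
print(f1_closed_form(tp, fp, fn))  # 0.8333333333333334
print(math.isclose(f1_from_precision_recall(tp, fp, fn), f1_closed_form(tp, fp, fn)))  # True
```

The closed form is convenient because it avoids computing precision and recall separately and stays defined when TP + FP = 0.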
## 2.3 FN
[TOC](#table-of-contents)
**Description**
False negatives. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
False negatives : int
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
fn = m.FN(Y, T)
print(fn)  # -> 1
```
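A false negative is an entry that is positive in `T` but whose score in `Y` stays below the threshold. The sketch below counts them elementwise in plain Python, with nested lists standing in for tensors. It assumes a threshold of 0.5, since the default value is not stated here; `count_false_negatives` is a hypothetical helper, not the library function.

```python
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def count_false_negatives(Y, T, threshold=0.5):
    """Count entries that are positive in T but scored below the threshold in Y."""
    return sum(
        1
        for y_row, t_row in zip(Y, T)
        for y, t in zip(y_row, t_row)
        if t == 1 and y < threshold
    )


print(count_false_negatives(Y, T))  # 1 (sample 1: class 0 is scored 0.03)
```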
## 2.4 FP
[TOC](#table-of-contents)
**Description**
False positives. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
False positives : int
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
fp = m.FP(Y, T)
print(fp)  # -> 1
```
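A false positive is an entry that is negative in `T` but whose score in `Y` reaches the threshold. To show where the single false positive of the example comes from, the sketch below returns the `(sample, class)` positions instead of only the count. It assumes a threshold of 0.5 and uses plain lists in place of tensors; `false_positive_positions` is a hypothetical helper.

```python
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def false_positive_positions(Y, T, threshold=0.5):
    """(sample, class) positions that are negative in T but reach the threshold in Y."""
    return [
        (i, j)
        for i, (y_row, t_row) in enumerate(zip(Y, T))
        for j, (y, t) in enumerate(zip(y_row, t_row))
        if t == 0 and y >= threshold
    ]


positions = false_positive_positions(Y, T)
print(positions)       # [(1, 1)]
print(len(positions))  # 1
```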
## 2.5 FPR
[TOC](#table-of-contents)
**Description**
False positive rate. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
False positive rate : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
fpr = m.FPR(Y, T)
print(fpr)  # -> 0.08333333333333333
```
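The false positive rate is the share of negative entries that are wrongly predicted as positive, $FPR = \frac{FP}{FP + TN}$. In the example, 12 of the 18 entries are negative and one of them is a false positive, which gives 1/12. The plain-Python sketch below reproduces this, assuming a threshold of 0.5; `false_positive_rate` is a hypothetical helper.

```python
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def false_positive_rate(Y, T, threshold=0.5):
    """FP / (FP + TN), computed over all negative entries of T."""
    fp = tn = 0
    for y_row, t_row in zip(Y, T):
        for y, t in zip(y_row, t_row):
            if t == 0:
                if y >= threshold:
                    fp += 1  # negative entry predicted as positive
                else:
                    tn += 1  # negative entry correctly rejected
    return fp / (fp + tn) if fp + tn > 0 else 0.0


print(false_positive_rate(Y, T))  # 0.08333333333333333
```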
## 2.6 ROC
[TOC](#table-of-contents)
**Description**
Calculates the receiver operating characteristic (ROC): computes the false positive rates and true positive rates for `num_thresholds` thresholds evenly spaced between 0 and 1.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| num_thresholds | int, default = 100 | Number of thresholds to compute. |
**Returns**
(False positive rates, true positive rates) for `num_thresholds` different thresholds : (List[float], List[float])
**Example**
```python
import rsp.ml.metrics as m
import torch
import torch.nn.functional as F
num_elements = 100000
num_classes = 7

# Random one-hot targets, shape (num_elements, num_classes)
true_classes = torch.randint(0, num_classes, (num_elements,))
T = F.one_hot(true_classes, num_classes=num_classes)

# Noisy scores centered on the targets, turned into probabilities
dist = torch.normal(T.float(), 1.5)
Y = F.softmax(dist, dim=1)

FPRs, TPRs = m.ROC(Y, T)
```
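Conceptually, the ROC sweeps a threshold over [0, 1] and records one (FPR, TPR) pair per threshold. The sketch below does this in plain Python on the small example used by the other metrics in this section. It assumes the thresholds include both endpoints (`i / (num_thresholds - 1)`); whether `rsp.ml.metrics.ROC` spaces them exactly this way is an assumption, and `roc_curve` is a hypothetical helper.

```python
from typing import List, Tuple

Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def roc_curve(Y, T, num_thresholds: int = 100) -> Tuple[List[float], List[float]]:
    """Return (FPRs, TPRs) for thresholds evenly spaced in [0, 1]."""
    if num_thresholds < 2:
        raise ValueError("num_thresholds must be at least 2")
    fprs, tprs = [], []
    for i in range(num_thresholds):
        threshold = i / (num_thresholds - 1)
        tp = fp = fn = tn = 0
        for y_row, t_row in zip(Y, T):
            for y, t in zip(y_row, t_row):
                predicted = y >= threshold
                if t == 1 and predicted:
                    tp += 1
                elif t == 1:
                    fn += 1
                elif predicted:
                    fp += 1
                else:
                    tn += 1
        fprs.append(fp / (fp + tn) if fp + tn > 0 else 0.0)
        tprs.append(tp / (tp + fn) if tp + fn > 0 else 0.0)
    return fprs, tprs


fprs, tprs = roc_curve(Y, T, num_thresholds=11)
print(fprs[0], tprs[0])    # 1.0 1.0 (threshold 0.0: every entry is positive)
print(fprs[-1], tprs[-1])  # 0.0 0.0 (threshold 1.0: no score reaches it)
```

Raising the threshold can only remove positive predictions, so both rates decrease monotonically along the sweep.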
## 2.7 TN
[TOC](#table-of-contents)
**Description**
True negatives. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
True negatives : int
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
tn = m.TN(Y, T)
print(tn)  # -> 11
```
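Every entry of a `(batch_size, num_classes)` input falls into exactly one of TP, FP, FN and TN, so the true negatives can also be checked as $TN = N \cdot C - TP - FP - FN = 18 - 5 - 1 - 1 = 11$ for the example. The sketch below counts all four cells in plain Python and verifies this identity. It assumes a threshold of 0.5; `confusion_counts` is a hypothetical helper.

```python
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def confusion_counts(Y, T, threshold=0.5):
    """Count TP, FP, FN and TN over all (sample, class) entries."""
    counts = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
    for y_row, t_row in zip(Y, T):
        for y, t in zip(y_row, t_row):
            predicted = y >= threshold
            if t == 1:
                counts["TP" if predicted else "FN"] += 1
            else:
                counts["FP" if predicted else "TN"] += 1
    return counts


counts = confusion_counts(Y, T)
total = len(Y) * len(Y[0])  # N * C = 6 * 3 = 18 entries
print(counts["TN"])  # 11
print(counts["TN"] == total - counts["TP"] - counts["FP"] - counts["FN"])  # True
```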
## 2.8 TP
[TOC](#table-of-contents)
**Description**
True positives. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
True positives : int
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
tp = m.TP(Y, T)
print(tp)  # -> 5
```
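The total of 5 true positives can be broken down per class, which helps to see which classes the model gets right. The plain-Python sketch below does this for the example, assuming a threshold of 0.5; `true_positives_per_class` is a hypothetical helper, not part of `rsp.ml.metrics`.

```python
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def true_positives_per_class(Y, T, threshold=0.5):
    """Number of entries per class that are positive in T and reach the threshold in Y."""
    num_classes = len(T[0])
    per_class = [0] * num_classes
    for y_row, t_row in zip(Y, T):
        for c in range(num_classes):
            if t_row[c] == 1 and y_row[c] >= threshold:
                per_class[c] += 1
    return per_class


per_class = true_positives_per_class(Y, T)
print(per_class)       # [1, 2, 2]
print(sum(per_class))  # 5
```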
## 2.9 TPR
[TOC](#table-of-contents)
**Description**
True positive rate. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
True positive rate : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
tpr = m.TPR(Y, T)
print(tpr)  # -> 0.8333333333333334
```
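The true positive rate is TP / (TP + FN). The plain-Python sketch below (a hypothetical helper, assuming a default threshold of 0.5 and element-wise counting) shows where 0.8333... comes from: 5 true positives and 1 false negative give 5 / 6. Returning 0.0 when there are no positive targets is an assumption of this sketch.

```python
# Hypothetical reference helper; not part of rsp.ml.metrics.
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def true_positive_rate(Y, T, threshold=0.5):
    """TP / (TP + FN), counted over every entry whose target is 1."""
    tp = fn = 0
    for y_row, t_row in zip(Y, T):
        for y, t in zip(y_row, t_row):
            if t == 1:
                if y >= threshold:
                    tp += 1  # positive target, predicted positive
                else:
                    fn += 1  # positive target, predicted negative
    # Assumption: define the rate as 0.0 when there are no positive targets.
    return tp / (tp + fn) if tp + fn else 0.0


tpr = true_positive_rate(Y, T)
print(tpr)  # -> 0.8333333333333334
```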
## 2.10 confusion\_matrix
[TOC](#table-of-contents)
**Description**
Calculates the confusion matrix. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
**Returns**
Confusion matrix : torch.Tensor
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
conf_mat = m.confusion_matrix(Y, T)
print(conf_mat)
# -> tensor([[1, 1, 0],
#            [0, 2, 0],
#            [0, 0, 2]])
```
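The layout of the printed matrix can be reproduced with the plain-Python sketch below. It assumes that rows are indexed by the true class, columns by the predicted class, and that the prediction is the highest-scoring class in each row; this is inferred from the example output and from the default axis labels of `plot_confusion_matrix` ('True label' on the y-axis, 'Predicted label' on the x-axis), not taken from the library source. `confusion_matrix_sketch` is a hypothetical name.

```python
# Hypothetical reference helper; not part of rsp.ml.metrics.
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def confusion_matrix_sketch(Y, T):
    """Rows = true class (argmax of T), columns = predicted class (argmax of Y)."""
    num_classes = len(T[0])
    cm = [[0] * num_classes for _ in range(num_classes)]
    for y_row, t_row in zip(Y, T):
        true_class = t_row.index(max(t_row))
        predicted_class = y_row.index(max(y_row))
        cm[true_class][predicted_class] += 1
    return cm


cm = confusion_matrix_sketch(Y, T)
for row in cm:
    print(row)
# [1, 1, 0]
# [0, 2, 0]
# [0, 0, 2]
```

The off-diagonal 1 in the first row is the sample whose true class 0 was predicted as class 1.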
## 2.11 plot\_ROC
[TOC](#table-of-contents)
**Description**
Plot the receiver operating characteristic.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| num_thresholds | int, default = 100 | Number of thresholds to compute. |
| title | str, optional, default = 'Confusion Matrix' | Title of the plot |
| class_curves | bool, default = False | Plot ROC curve for each class |
| labels | str, optional, default = None | Class labels. If None, classes are labeled automatically as C000, ..., CXXX. |
| plt_show | bool, optional, default = False | Set to True to show the plot |
| save_file_name | str, optional, default = None | If not None, the plot is saved under the specified save_file_name. |
**Returns**
Image of the ROC plot : np.array
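
To make `num_thresholds` concrete, the sketch below computes the kind of (FPR, TPR) points a ROC curve is drawn from. It is a hypothetical helper (`roc_points`), and two details are assumptions rather than documented behavior: thresholds are evenly spaced from 0.0 to 1.0, and rates are micro-averaged over all class entries, which matches the element-wise counts in the TP and TN examples. How `plot_ROC` actually spaces thresholds or averages classes is not stated here.

```python
# Hypothetical helper; not part of rsp.ml.metrics.
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def roc_points(Y, T, num_thresholds=100):
    """Return (FPRs, TPRs) for evenly spaced thresholds from 0.0 to 1.0."""
    if num_thresholds < 2:
        raise ValueError("num_thresholds must be at least 2")
    fprs, tprs = [], []
    for i in range(num_thresholds):
        threshold = i / (num_thresholds - 1)
        tp = fp = tn = fn = 0
        for y_row, t_row in zip(Y, T):
            for y, t in zip(y_row, t_row):
                predicted_positive = y >= threshold
                if predicted_positive and t == 1:
                    tp += 1
                elif predicted_positive:
                    fp += 1
                elif t == 1:
                    fn += 1
                else:
                    tn += 1
        tprs.append(tp / (tp + fn) if tp + fn else 0.0)
        fprs.append(fp / (fp + tn) if fp + tn else 0.0)
    return fprs, tprs


fprs, tprs = roc_points(Y, T, num_thresholds=11)
print(fprs[5], tprs[5])  # threshold 0.5: FPR = 1/12, TPR = 5/6
```

A threshold of 0.0 marks every entry positive (point (1, 1)) and a threshold of 1.0 marks every entry negative here (point (0, 0)), so the curve spans both corners.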

## 2.12 plot\_confusion\_matrix
[TOC](#table-of-contents)
**Description**
Plot the confusion matrix.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| confusion_matrix | torch.Tensor | Confusion matrix |
| labels | str, optional, default = None | Class labels. If None, classes are labeled automatically as C000, ..., CXXX. |
| cmap | str, optional, default = 'Blues' | Seaborn cmap, see https://r02b.github.io/seaborn_palettes/ |
| xlabel | str, optional, default = 'Predicted label' | X-Axis label |
| ylabel | str, optional, default = 'True label' | Y-Axis label |
| title | str, optional, default = 'Confusion Matrix' | Title of the plot |
| plt_show | bool, optional, default = False | Set to True to show the plot |
| save_file_name | str, optional, default = None | If not None, the plot is saved under the specified save_file_name. |
**Returns**
Image of the confusion matrix : np.array

## 2.13 precision
[TOC](#table-of-contents)
**Description**
Precision. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
Precision : float
**Equations**
$precision = \frac{TP}{TP + FP}$
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
precision = m.precision(Y, T)
print(precision)  # -> 0.8333333333333334
```
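Precision is TP / (TP + FP). The plain-Python sketch below is a hypothetical helper that assumes a default threshold of 0.5 and element-wise counting. It shows that the single false positive (class 1 predicted for the second row) together with 5 true positives gives 5 / 6. Returning 0.0 when nothing is predicted positive is an assumption of this sketch.

```python
# Hypothetical reference helper; not part of rsp.ml.metrics.
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def precision_sketch(Y, T, threshold=0.5):
    """TP / (TP + FP), counted over every entry predicted positive."""
    tp = fp = 0
    for y_row, t_row in zip(Y, T):
        for y, t in zip(y_row, t_row):
            if y >= threshold:
                if t == 1:
                    tp += 1  # predicted positive, target positive
                else:
                    fp += 1  # predicted positive, target negative
    # Assumption: define precision as 0.0 when nothing is predicted positive.
    return tp / (tp + fp) if tp + fp else 0.0


precision = precision_sketch(Y, T)
print(precision)  # -> 0.8333333333333334
```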
## 2.14 recall
[TOC](#table-of-contents)
**Description**
Recall. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |
**Returns**
Recall : float
**Equations**
$recall = \frac{TP}{TP + FN}$
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
recall = m.recall(Y, T)
print(recall)  # -> 0.8333333333333334
```
## 2.15 top\_10\_accuracy
[TOC](#table-of-contents)
**Description**
Top 10 accuracy. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
**Returns**
Top 10 accuracy, i.e. top k accuracy with k = 10 : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
top_10_accuracy = m.top_10_accuracy(Y, T)
print(top_10_accuracy)  # -> 1.0
```
## 2.16 top\_1\_accuracy
[TOC](#table-of-contents)
**Description**
Top 1 accuracy. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
**Returns**
Top 1 accuracy, i.e. top k accuracy with k = 1 : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
top_1_accuracy = m.top_1_accuracy(Y, T)
print(top_1_accuracy)  # -> 0.8333333333333334
```
## 2.17 top\_2\_accuracy
[TOC](#table-of-contents)
**Description**
Top 2 accuracy. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
**Returns**
Top 2 accuracy, i.e. top k accuracy with k = 2 : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
top_2_accuracy = m.top_2_accuracy(Y, T)
print(top_2_accuracy)  # -> 1.0
```
## 2.18 top\_3\_accuracy
[TOC](#table-of-contents)
**Description**
Top 3 accuracy. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
**Returns**
Top 3 accuracy, i.e. top k accuracy with k = 3 : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
top_3_accuracy = m.top_3_accuracy(Y, T)
print(top_3_accuracy)  # -> 1.0
```
## 2.19 top\_5\_accuracy
[TOC](#table-of-contents)
**Description**
Top 5 accuracy. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
**Returns**
Top 5 accuracy, i.e. top k accuracy with k = 5 : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
top_5_accuracy = m.top_5_accuracy(Y, T)
print(top_5_accuracy)  # -> 1.0
```
## 2.20 top\_k\_accuracy
[TOC](#table-of-contents)
**Description**
Top k accuracy. Expected input shape: (batch_size, num_classes)
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| Y | torch.Tensor | Prediction |
| T | torch.Tensor | True values |
| k | int | Number of highest-scoring classes among which the true class must appear |
**Returns**
Top k accuracy : float
**Example**
```python
import rsp.ml.metrics as m
import torch
Y = torch.tensor([
[0.1, 0.1, 0.8],
[0.03, 0.95, 0.02],
[0.05, 0.9, 0.05],
[0.01, 0.87, 0.12],
[0.04, 0.03, 0.93],
[0.94, 0.02, 0.06]
])
T = torch.tensor([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0]
])
top_k_accuracy = m.top_k_accuracy(Y, T, k = 3)
print(top_k_accuracy)  # -> 1.0
```
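The plain-Python sketch below illustrates what top k accuracy measures: the fraction of samples whose true class is among the k highest-scoring classes. It is a hypothetical helper, not the library's implementation. Slicing the ranked list naturally caps k at the number of classes, which is one way to get the 1.0 documented for `top_10_accuracy` and `top_5_accuracy` on this three-class example; whether the library handles k > num_classes the same way is an assumption.

```python
# Hypothetical reference helper; not part of rsp.ml.metrics.
Y = [
    [0.1, 0.1, 0.8],
    [0.03, 0.95, 0.02],
    [0.05, 0.9, 0.05],
    [0.01, 0.87, 0.12],
    [0.04, 0.03, 0.93],
    [0.94, 0.02, 0.06],
]
T = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]


def top_k_accuracy_sketch(Y, T, k):
    """Fraction of rows whose true class is among the k highest scores."""
    hits = 0
    for y_row, t_row in zip(Y, T):
        true_class = t_row.index(max(t_row))  # position of the 1 in the one-hot row
        # Class indices sorted by descending score; slicing caps k at num_classes.
        ranked = sorted(range(len(y_row)), key=lambda c: y_row[c], reverse=True)
        if true_class in ranked[:k]:
            hits += 1
    return hits / len(Y)


for k in (1, 2, 3, 10):
    print(k, top_k_accuracy_sketch(Y, T, k))
# 1 0.8333333333333334
# 2 1.0
# 3 1.0
# 10 1.0
```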
# 3 model
[TOC](#table-of-contents)
The module `rsp.ml.model` provides useful functionality to store and load PyTorch models.
## 3.1 MODELS : enum.Enum
[TOC](#table-of-contents)
**Description**
Enumeration of the available model IDs. Pass a member as the `model` argument of `load_model`.
## 3.2 WEIGHTS : enum.Enum
[TOC](#table-of-contents)
**Description**
Enumeration of the available weight IDs. Pass a member as the `weights` argument of `load_model`.
## 3.3 list\_model\_weights
[TOC](#table-of-contents)
**Description**
Lists all available weight files.
**Returns**
List of (MODEL:str, WEIGHT:str) : List[Tuple[str, str]]
**Example**
```python
import rsp.ml.model as model
model_weight_files = model.list_model_weights()
```
## 3.4 load\_model
[TOC](#table-of-contents)
**Description**
Loads a pretrained PyTorch model from HuggingFace.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| model | MODELS | ID of the model |
| weights | WEIGHTS | ID of the weights |
**Returns**
Pretrained PyTorch model : torch.nn.Module
**Example**
```python
import rsp.ml.model as model
action_recognition_model = model.load_model(model.MODELS.TUCARC3D, model.WEIGHTS.TUCAR)
```
## 3.5 publish\_model
[TOC](#table-of-contents)
# 4 module
[TOC](#table-of-contents)
## 4.1 MultiHeadSelfAttention : torch.nn.modules.module.Module
[TOC](#table-of-contents)
**Description**
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in
a tree structure. You can assign the submodules as regular attributes::
```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call :meth:`to`, etc.
.. note::
As per the example above, an ``__init__()`` call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or
evaluation mode.
:vartype training: bool
##### 4.1.1 \_wrapped\_call\_impl
[TOC](#table-of-contents)
### 4.1.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initialize internal Module state, shared by both nn.Module and ScriptModule.
##### 4.1.3 \_apply
[TOC](#table-of-contents)
##### 4.1.4 \_call\_impl
[TOC](#table-of-contents)
##### 4.1.5 \_get\_backward\_hooks
[TOC](#table-of-contents)
**Description**
Return the backward hooks for use in the call function.
It returns two lists, one with the full backward hooks and one with the non-full
backward hooks.
##### 4.1.6 \_get\_backward\_pre\_hooks
[TOC](#table-of-contents)
##### 4.1.7 \_get\_name
[TOC](#table-of-contents)
##### 4.1.8 \_load\_from\_state\_dict
[TOC](#table-of-contents)
**Description**
Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants.
This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
Additionally, :attr:`local_metadata` can also contain the key
`assign_to_params_buffers` that indicates whether keys should be
assigned their corresponding tensor in the state_dict.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
##### 4.1.9 \_maybe\_warn\_non\_full\_backward\_hook
[TOC](#table-of-contents)
##### 4.1.10 \_named\_members
[TOC](#table-of-contents)
**Description**
Help yield various names + members of modules.
##### 4.1.11 \_register\_load\_state\_dict\_pre\_hook
[TOC](#table-of-contents)
**Description**
See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details.
A subtle difference is that if ``with_module`` is set to ``False``, then the
hook will not take the ``module`` as the first argument whereas
:meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the
``module`` as the first argument.
Arguments:
hook (Callable): Callable hook that will be invoked before
loading the state dict.
with_module (bool, optional): Whether or not to pass the module
instance to the hook as the first parameter.
##### 4.1.12 \_register\_state\_dict\_hook
[TOC](#table-of-contents)
**Description**
Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method.
It should have the following signature::
hook(module, state_dict, prefix, local_metadata) -> None or state_dict
The registered hooks can modify the ``state_dict`` inplace or return a new one.
If a new ``state_dict`` is returned, it will only be respected if it is the root
module that :meth:`~nn.Module.state_dict` is called from.
##### 4.1.13 \_replicate\_for\_data\_parallel
[TOC](#table-of-contents)
##### 4.1.14 \_save\_to\_state\_dict
[TOC](#table-of-contents)
**Description**
Save module state to the `destination` dictionary.
The `destination` dictionary will contain the state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
##### 4.1.15 \_slow\_forward
[TOC](#table-of-contents)
##### 4.1.16 \_wrapped\_call\_impl
[TOC](#table-of-contents)
##### 4.1.17 add\_module
[TOC](#table-of-contents)
**Description**
Add a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (str): name of the child module. The child module can be
accessed from this module using the given name
module (Module): child module to be added to the module.
##### 4.1.18 apply
[TOC](#table-of-contents)
**Description**
Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
Typical use includes initializing the parameters of a model
(see also :ref:`nn-init-doc`).
Args:
fn (:class:`Module` -> None): function to be applied to each submodule
Returns:
Module: self
Example::
```python
>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
```
##### 4.1.19 bfloat16
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``bfloat16`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.1.20 buffers
[TOC](#table-of-contents)
**Description**
Return an iterator over module buffers.
Args:
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
torch.Tensor: module buffer
Example::
```python
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
```
##### 4.1.21 children
[TOC](#table-of-contents)
**Description**
Return an iterator over immediate children modules.
Yields:
Module: a child module
##### 4.1.22 compile
[TOC](#table-of-contents)
**Description**
Compile this Module's forward using :func:`torch.compile`.
This Module's `__call__` method is compiled and all arguments are passed as-is
to :func:`torch.compile`.
See :func:`torch.compile` for details on the arguments for this function.
##### 4.1.23 cpu
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the CPU.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.1.24 cuda
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
.. note::
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
##### 4.1.25 double
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``double`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.1.26 eval
[TOC](#table-of-contents)
**Description**
Set the module in evaluation mode.
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.
See :ref:`locally-disable-grad-doc` for a comparison between
`.eval()` and several similar mechanisms that may be confused with it.
Returns:
Module: self
##### 4.1.27 extra\_repr
[TOC](#table-of-contents)
**Description**
Return the extra representation of the module.
To print customized extra information, you should re-implement
this method in your own modules. Both single-line and multi-line
strings are acceptable.
##### 4.1.28 float
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``float`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
### 4.1.29 forward
[TOC](#table-of-contents)
**Description**
Define the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
##### 4.1.30 get\_buffer
[TOC](#table-of-contents)
**Description**
Return the buffer given by ``target`` if it exists, otherwise throw an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the buffer
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
torch.Tensor: The buffer referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not a
buffer
##### 4.1.31 get\_extra\_state
[TOC](#table-of-contents)
**Description**
Return any extra state to include in the module's state_dict.
Implement this and a corresponding :func:`set_extra_state` for your module
if you need to store extra state. This function is called when building the
module's `state_dict()`.
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
Returns:
object: Any extra state to store in the module's state_dict
##### 4.1.32 get\_parameter
[TOC](#table-of-contents)
**Description**
Return the parameter given by ``target`` if it exists, otherwise throw an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the Parameter
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
torch.nn.Parameter: The Parameter referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Parameter``
##### 4.1.33 get\_submodule
[TOC](#table-of-contents)
**Description**
Return the submodule given by ``target`` if it exists, otherwise throw an error.
For example, let's say you have an ``nn.Module`` ``A`` that
looks like this:
```text
A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)
```
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
submodule ``net_b``, which itself has two submodules ``net_c``
and ``linear``. ``net_c`` then has a submodule ``conv``.)
To check whether or not we have the ``linear`` submodule, we
would call ``get_submodule("net_b.linear")``. To check whether
we have the ``conv`` submodule, we would call
``get_submodule("net_b.net_c.conv")``.
The runtime of ``get_submodule`` is bounded by the degree
of module nesting in ``target``. A query against
``named_modules`` achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, ``get_submodule`` should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
##### 4.1.34 half
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``half`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.1.35 ipu
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the IPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
.. note::
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
##### 4.1.36 load\_state\_dict
[TOC](#table-of-contents)
**Description**
Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
.. warning::
If :attr:`assign` is ``True`` the optimizer must be created after
the call to :attr:`load_state_dict` unless
:func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
assign (bool, optional): When set to ``False``, the properties of the tensors
in the current module are preserved whereas setting it to ``True`` preserves
properties of the Tensors in the state dict. The only
exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
for which the value from the module is preserved.
Default: ``False``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing any keys that are expected
by this module but missing from the provided ``state_dict``.
* **unexpected_keys** is a list of str containing the keys that are not
expected by this module but present in the provided ``state_dict``.
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
##### 4.1.37 modules
[TOC](#table-of-contents)
**Description**
Return an iterator over all modules in the network.
Yields:
Module: a module in the network
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
```python
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)
0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
```
##### 4.1.38 mtia
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
.. note::
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
##### 4.1.39 named\_buffers
[TOC](#table-of-contents)
**Description**
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool, optional): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Defaults to True.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields:
(str, torch.Tensor): Tuple containing the name and buffer
Example::
```python
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
```
##### 4.1.40 named\_children
[TOC](#table-of-contents)
**Description**
Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
Yields:
(str, Module): Tuple containing a name and child module
Example::
```python
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
```
##### 4.1.41 named\_modules
[TOC](#table-of-contents)
**Description**
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
```python
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)
0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
```
##### 4.1.42 named\_parameters
[TOC](#table-of-contents)
**Description**
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
Args:
prefix (str): prefix to prepend to all parameter names.
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
remove_duplicate (bool, optional): whether to remove the duplicated
parameters in the result. Defaults to True.
Yields:
(str, Parameter): Tuple containing the name and parameter
Example::
```python
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
```
##### 4.1.43 parameters
[TOC](#table-of-contents)
**Description**
Return an iterator over module parameters.
This is typically passed to an optimizer.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
Parameter: module parameter
Example::
```python
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
```
##### 4.1.44 register\_backward\_hook
[TOC](#table-of-contents)
**Description**
Register a backward hook on the module.
This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and
the behavior of this function will change in future versions.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
##### 4.1.45 register\_buffer
[TOC](#table-of-contents)
**Description**
Add a buffer to the module.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the module's state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting :attr:`persistent` to ``False``. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module's
:attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If ``None``, then operations
that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
the buffer is **not** included in the module's :attr:`state_dict`.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
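
The sketch below contrasts a persistent and a non-persistent buffer. `RunningMean` and its update rule are hypothetical, chosen only to show that both buffers are registered while only the persistent one appears in `state_dict()`.

```python
import torch
import torch.nn as nn

class RunningMean(nn.Module):
    """Hypothetical module that tracks a running mean of its inputs."""

    def __init__(self, num_features: int, momentum: float = 0.1) -> None:
        super().__init__()
        self.momentum = momentum
        # Persistent buffer: saved in state_dict() and moved by .to()/.cuda().
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Non-persistent buffer: moved with the module but left out of state_dict().
        self.register_buffer("last_mean", torch.zeros(num_features), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            batch_mean = x.mean(dim=0)
            self.last_mean.copy_(batch_mean)
            # Update in place so the registered buffer object is reused.
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
        return x - self.running_mean

m = RunningMean(3)
print(list(m.state_dict().keys()))  # ['running_mean']
```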
##### 4.1.46 register\_forward\_hook
[TOC](#table-of-contents)
**Description**
Register a forward hook on the module.
The hook will be called every time after :func:`forward` has computed an output.
If ``with_kwargs`` is ``False`` or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won't be
passed to the hooks and only to the ``forward``. The hook can modify the
output. It can modify the input in-place, but this has no effect on
forward because the hook runs after :func:`forward` has been called. The hook
should have the following signature::
hook(module, args, output) -> None or modified output
If ``with_kwargs`` is ``True``, the forward hook will be passed the
``kwargs`` given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature::
hook(module, args, kwargs, output) -> None or modified output
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If ``True``, the provided ``hook`` will be fired
before all existing ``forward`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``forward`` hooks on
this :class:`torch.nn.modules.Module`. Note that global
``forward`` hooks registered with
:func:`register_module_forward_hook` will fire before all hooks
registered by this method.
Default: ``False``
with_kwargs (bool): If ``True``, the ``hook`` will be passed the
kwargs given to the forward function.
Default: ``False``
always_call (bool): If ``True`` the ``hook`` will be run regardless of
whether an exception is raised while calling the Module.
Default: ``False``
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
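
A common use of a forward hook is capturing intermediate activations. This is a sketch with an illustrative three-layer model; the `activations` dict and the `"hidden"` key are arbitrary names.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
activations = {}

def save_output(module, args, output):
    # Returning None keeps the output unchanged; detach so the stored
    # tensor does not keep the autograd graph alive.
    activations["hidden"] = output.detach()

handle = model[1].register_forward_hook(save_output)
y = model(torch.randn(5, 4))
print(activations["hidden"].shape)  # torch.Size([5, 3])

handle.remove()  # the hook no longer fires after removal
```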
##### 4.1.47 register\_forward\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a forward pre-hook on the module.
The hook will be called every time before :func:`forward` is invoked.
If ``with_kwargs`` is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won't be
passed to the hooks and only to the ``forward``. The hook can modify the
input. The user can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature::
hook(module, args) -> None or modified input
If ``with_kwargs`` is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature::
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided ``hook`` will be fired before
all existing ``forward_pre`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``forward_pre`` hooks
on this :class:`torch.nn.modules.Module`. Note that global
``forward_pre`` hooks registered with
:func:`register_module_forward_pre_hook` will fire before all
hooks registered by this method.
Default: ``False``
with_kwargs (bool): If true, the ``hook`` will be passed the kwargs
given to the forward function.
Default: ``False``
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
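
A minimal sketch of a forward pre-hook that rewrites the input, returning a single tensor that PyTorch wraps back into a tuple. The doubling transform and identity weights are illustrative only.

```python
import torch
import torch.nn as nn

def scale_input(module, args):
    # args is the tuple of positional inputs; a single returned value
    # is wrapped into a tuple before forward() is called.
    (x,) = args
    return x * 2

lin = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    lin.weight.copy_(torch.eye(2))  # identity weights make the effect visible

handle = lin.register_forward_pre_hook(scale_input)
y = lin(torch.tensor([[1.0, 3.0]]))  # y holds [[2., 6.]]
handle.remove()
```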
##### 4.1.48 register\_full\_backward\_hook
[TOC](#table-of-contents)
**Description**
Register a backward hook on the module.
The hook will be called every time the gradients with respect to a module
are computed, i.e. the hook will execute if and only if the gradients with
respect to module outputs are computed. The hook should have the following
signature::
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of :attr:`grad_input` in
subsequent computations. :attr:`grad_input` will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module's forward function.
.. warning ::
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided ``hook`` will be fired before
all existing ``backward`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``backward`` hooks on
this :class:`torch.nn.modules.Module`. Note that global
``backward`` hooks registered with
:func:`register_module_full_backward_hook` will fire before
all hooks registered by this method.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
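
A sketch of a full backward hook used for logging: it reads `grad_output` and returns `None`, so gradients are left unchanged. The layer sizes and the `grad_norms` list are illustrative.

```python
import torch
import torch.nn as nn

grad_norms = []

def log_grad_output(module, grad_input, grad_output):
    # grad_output[0] is d(loss)/d(output of this module).
    grad_norms.append(grad_output[0].norm().item())

lin = nn.Linear(3, 1)
handle = lin.register_full_backward_hook(log_grad_output)

x = torch.randn(4, 3, requires_grad=True)
lin(x).sum().backward()  # d(sum)/d(output) is a 4x1 tensor of ones
handle.remove()
print(grad_norms)  # [2.0]
```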
##### 4.1.49 register\_full\_backward\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a backward pre-hook on the module.
The hook will be called every time the gradients for the module are computed.
The hook should have the following signature::
hook(module, grad_output) -> tuple[Tensor] or None
The :attr:`grad_output` is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of :attr:`grad_output` in
subsequent computations. Entries in :attr:`grad_output` will be ``None`` for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module's forward function.
.. warning ::
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided ``hook`` will be fired before
all existing ``backward_pre`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``backward_pre`` hooks
on this :class:`torch.nn.modules.Module`. Note that global
``backward_pre`` hooks registered with
:func:`register_module_full_backward_pre_hook` will fire before
all hooks registered by this method.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
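
A sketch of a backward pre-hook that replaces `grad_output` before the module's own gradients are computed. It assumes a PyTorch 2.x release that provides `register_full_backward_pre_hook`; the halving is purely illustrative.

```python
import torch
import torch.nn as nn

def halve_grad_output(module, grad_output):
    # Return a replacement tuple of the same length as grad_output.
    return tuple(None if g is None else g * 0.5 for g in grad_output)

lin = nn.Linear(2, 1, bias=False)
handle = lin.register_full_backward_pre_hook(halve_grad_output)

x = torch.ones(1, 2, requires_grad=True)
lin(x).sum().backward()
handle.remove()
# Without the hook, weight.grad would equal x = [[1., 1.]]; here it is halved.
print(lin.weight.grad)
```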
##### 4.1.50 register\_load\_state\_dict\_post\_hook
[TOC](#table-of-contents)
**Description**
Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called.
It should have the following signature::
hook(module, incompatible_keys) -> None
The ``module`` argument is the current module that this hook is registered
on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
is a ``list`` of ``str`` containing the missing keys and
``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling :func:`load_state_dict` with
``strict=True`` are affected by modifications the hook makes to
``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
set of keys will result in an error being thrown when ``strict=True``, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
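
As described above, clearing entries from `missing_keys` in a post-hook avoids the `strict=True` error. This sketch tolerates a checkpoint without a bias; the hook name is hypothetical.

```python
import torch
import torch.nn as nn

def allow_missing_bias(module, incompatible_keys):
    # missing_keys is the same list load_state_dict checks for strict=True,
    # so editing it in place changes the outcome.
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")

lin = nn.Linear(2, 2)
lin.register_load_state_dict_post_hook(allow_missing_bias)

# A checkpoint without a bias entry; strict=True would normally raise.
result = lin.load_state_dict({"weight": torch.zeros(2, 2)})
print(result.missing_keys)  # []
```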
##### 4.1.51 register\_load\_state\_dict\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called.
It should have the following signature::
hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
Arguments:
hook (Callable): Callable hook that will be invoked before
loading the state dict.
##### 4.1.52 register\_module
[TOC](#table-of-contents)
**Description**
Alias for :func:`add_module`.
##### 4.1.53 register\_parameter
[TOC](#table-of-contents)
**Description**
Add a parameter to the module.
The parameter can be accessed as an attribute using the given name.
Args:
name (str): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter or None): parameter to be added to the module. If
``None``, then operations that run on parameters, such as :attr:`cuda`,
are ignored. If ``None``, the parameter is **not** included in the
module's :attr:`state_dict`.
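
A sketch showing that registering `None` reserves an attribute name without creating a parameter. `ScaleShift` is a hypothetical module used only for illustration.

```python
import torch
import torch.nn as nn

class ScaleShift(nn.Module):
    """Hypothetical module with an optional learnable shift."""

    def __init__(self, use_shift: bool = False) -> None:
        super().__init__()
        self.register_parameter("scale", nn.Parameter(torch.ones(1)))
        # Registering None reserves the name: self.shift exists but is
        # skipped by parameters(), state_dict(), .to(), etc.
        shift = nn.Parameter(torch.zeros(1)) if use_shift else None
        self.register_parameter("shift", shift)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * self.scale
        return x if self.shift is None else x + self.shift

m = ScaleShift()
print([name for name, _ in m.named_parameters()])  # ['scale']
```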
##### 4.1.54 register\_state\_dict\_post\_hook
[TOC](#table-of-contents)
**Description**
Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method.
It should have the following signature::
hook(module, state_dict, prefix, local_metadata) -> None
The registered hooks can modify the ``state_dict`` inplace.
##### 4.1.55 register\_state\_dict\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.
It should have the following signature::
hook(module, prefix, keep_vars) -> None
The registered hooks can be used to perform pre-processing before the ``state_dict``
call is made.
##### 4.1.56 requires\_grad\_
[TOC](#table-of-contents)
**Description**
Change whether autograd should record operations on parameters in this module.
This method sets the parameters' :attr:`requires_grad` attributes
in-place.
This method is helpful for freezing part of the module for finetuning
or training parts of a model individually (e.g., GAN training).
See :ref:`locally-disable-grad-doc` for a comparison between
`.requires_grad_()` and several similar mechanisms that may be confused with it.
Args:
requires_grad (bool): whether autograd should record operations on
parameters in this module. Default: ``True``.
Returns:
Module: self
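
A sketch of freezing part of a model for fine-tuning. The layer layout is illustrative; in practice `model[0]` would typically be a pretrained feature extractor.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# Freeze the first layer; requires_grad_ returns the module, so it can be chained.
model[0].requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']

# Only hand the still-trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)
```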
##### 4.1.57 set\_extra\_state
[TOC](#table-of-contents)
**Description**
Set extra state contained in the loaded `state_dict`.
This function is called from :func:`load_state_dict` to handle any extra state
found within the `state_dict`. Implement this function and a corresponding
:func:`get_extra_state` for your module if you need to store extra state within its
`state_dict`.
Args:
state (dict): Extra state from the `state_dict`
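
A sketch of a paired `get_extra_state` / `set_extra_state` that round-trips a non-tensor lookup table through `state_dict()`. `Vocab` and its `token_to_id` dict are hypothetical.

```python
import torch.nn as nn

class Vocab(nn.Module):
    """Hypothetical module that keeps a non-tensor lookup table."""

    def __init__(self) -> None:
        super().__init__()
        self.token_to_id = {"<pad>": 0}

    def get_extra_state(self):
        # Stored under the "_extra_state" key; should be picklable.
        return {"token_to_id": dict(self.token_to_id)}

    def set_extra_state(self, state) -> None:
        self.token_to_id = dict(state["token_to_id"])

src = Vocab()
src.token_to_id["hello"] = 1

dst = Vocab()
dst.load_state_dict(src.state_dict())
print(dst.token_to_id)  # {'<pad>': 0, 'hello': 1}
```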
##### 4.1.58 set\_submodule
[TOC](#table-of-contents)
**Description**
Set the submodule given by ``target`` if it exists, otherwise throw an error.
For example, let's say you have an ``nn.Module`` ``A`` that
looks like this:
.. code-block:: text

    A(
        (net_b): Module(
            (net_c): Module(
                (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
            )
            (linear): Linear(in_features=100, out_features=200, bias=True)
        )
    )

(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
submodule ``net_b``, which itself has two submodules ``net_c``
and ``linear``. ``net_c`` then has a submodule ``conv``.)
To override the ``Conv2d`` with a new submodule ``Linear``, you
would call
``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
Raises:
ValueError: If the target string is empty
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
##### 4.1.59 share\_memory
[TOC](#table-of-contents)
**Description**
See :meth:`torch.Tensor.share_memory_`.
##### 4.1.60 state\_dict
[TOC](#table-of-contents)
**Description**
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
.. note::
The returned object is a shallow copy. It contains references
to the module's parameters and buffers.
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
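
A sketch of the shallow-copy behaviour and the keyword arguments described above, using `nn.Linear` as an example module.

```python
import torch.nn as nn

lin = nn.Linear(2, 2)
sd = lin.state_dict()
print(list(sd.keys()))  # ['weight', 'bias']

# Shallow copy: entries are detached tensors sharing storage with the
# live parameters, so later in-place updates also show up in `sd`.
assert sd["weight"].data_ptr() == lin.weight.data_ptr()
assert not sd["weight"].requires_grad

# For a snapshot that will not change during training, clone the values.
snapshot = {k: v.clone() for k, v in sd.items()}

# prefix and keep_vars are passed as keyword arguments.
print(list(lin.state_dict(prefix="encoder.").keys()))  # ['encoder.weight', 'encoder.bias']
assert lin.state_dict(keep_vars=True)["weight"] is lin.weight
```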
##### 4.1.61 to
[TOC](#table-of-contents)
**Description**
Move and/or cast the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
:noindex:
.. function:: to(dtype, non_blocking=False)
:noindex:
.. function:: to(tensor, non_blocking=False)
:noindex:
.. function:: to(memory_format=torch.channels_last)
:noindex:
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point or complex :attr:`dtype` values. In addition, this method will
only cast the floating point or complex parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved to
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
.. note::
This method modifies the module in-place.
Args:
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
memory_format (:class:`torch.memory_format`): the desired memory
format for 4D parameters and buffers in this module (keyword
only argument)
Returns:
Module: self
Examples::
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16)
>>> linear = nn.Linear(2, 2, bias=False).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j, 0.2382+0.j],
[ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
[0.6122+0.j, 0.1150+0.j],
[0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
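
A sketch showing that a dtype cast leaves integral buffers untouched, followed by a common device-agnostic pattern. `StepCounter` is a hypothetical module.

```python
import torch
import torch.nn as nn

class StepCounter(nn.Module):
    """Hypothetical module with a float layer and an integer buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(2, 2)
        self.register_buffer("steps", torch.zeros((), dtype=torch.long))

m = StepCounter().to(torch.float64)
print(m.lin.weight.dtype)  # torch.float64
print(m.steps.dtype)       # torch.int64 (integral buffers are not cast)

# Device-agnostic pattern: pick a device once, then move module and inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = m.to(device)
x = torch.randn(3, 2, dtype=torch.float64, device=device)
y = m.lin(x)
```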
##### 4.1.62 to\_empty
[TOC](#table-of-contents)
**Description**
Move the parameters and buffers to the specified device without copying storage.
Args:
device (:class:`torch.device`): The desired device of the parameters
and buffers in this module.
recurse (bool): Whether parameters and buffers of submodules should
be recursively moved to the specified device.
Returns:
Module: self
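
A sketch of the usual `to_empty` workflow: build on the `meta` device (shapes only, no storage), allocate on a real device, then initialize. Because `to_empty` leaves memory uninitialized, the `reset_parameters()` call is required before use.

```python
import torch
import torch.nn as nn

# Build the module on the "meta" device: shapes and dtypes only, no memory.
lin = nn.Linear(256, 256, device="meta")
print(lin.weight.is_meta)  # True

# Allocate real (uninitialized) storage on the CPU, then initialize it.
lin = lin.to_empty(device="cpu")
lin.reset_parameters()
```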
##### 4.1.63 train
[TOC](#table-of-contents)
**Description**
Set the module in training mode.
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
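
A sketch of how the training flag changes `Dropout`: in evaluation mode it acts as the identity, so repeated calls give identical outputs. The model is illustrative.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)

model.train(False)  # evaluation mode: Dropout becomes the identity
with torch.no_grad():
    a, b = model(x), model(x)
print(torch.equal(a, b))  # True

model.train()  # back to training mode; the flag propagates to every submodule
print(all(sub.training for sub in model.modules()))  # True
```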
##### 4.1.64 type
[TOC](#table-of-contents)
**Description**
Casts all parameters and buffers to :attr:`dst_type`.
.. note::
This method modifies the module in-place.
Args:
dst_type (type or string): the desired type
Returns:
Module: self
##### 4.1.65 xpu
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the XPU.
This also makes associated parameters and buffers different objects, so
it should be called before constructing the optimizer if the module will
live on XPU while being optimized.
.. note::
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
##### 4.1.66 zero\_grad
[TOC](#table-of-contents)
**Description**
Reset gradients of all model parameters.
See similar function under :class:`torch.optim.Optimizer` for more context.
Args:
set_to_none (bool): instead of setting to zero, set the grads to None.
See :meth:`torch.optim.Optimizer.zero_grad` for details.
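
A sketch of why gradients need resetting: `backward()` accumulates into `.grad`. Passing `set_to_none=True` explicitly makes the result independent of the PyTorch version's default.

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 1)
x = torch.ones(1, 2)

lin(x).sum().backward()
first = lin.weight.grad.clone()

# Gradients accumulate across backward() calls until they are reset.
lin(x).sum().backward()
print(torch.allclose(lin.weight.grad, 2 * first))  # True

lin.zero_grad(set_to_none=True)
print(lin.weight.grad)  # None
```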
## 4.2 SelfAttention : torch.nn.modules.module.Module
[TOC](#table-of-contents)
**Description**
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their
parameters converted when you call :meth:`to`, etc.
.. note::
As per the example above, an ``__init__()`` call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or
evaluation mode.
:vartype training: bool
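
Using the `Model` class from the example above (not `SelfAttention` itself, whose own API is not documented here), this sketch shows that attribute assignment registers the submodules and their parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
print([name for name, _ in model.named_children()])  # ['conv1', 'conv2']
print(sum(p.numel() for p in model.parameters()))     # 10540
print(model(torch.randn(1, 1, 28, 28)).shape)         # torch.Size([1, 20, 20, 20])
```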
##### 4.2.1 \_wrapped\_call\_impl
[TOC](#table-of-contents)
### 4.2.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initialize internal Module state, shared by both nn.Module and ScriptModule.
##### 4.2.3 \_apply
[TOC](#table-of-contents)
##### 4.2.4 \_call\_impl
[TOC](#table-of-contents)
##### 4.2.5 \_get\_backward\_hooks
[TOC](#table-of-contents)
**Description**
Return the backward hooks for use in the call function.
It returns two lists, one with the full backward hooks and one with the non-full
backward hooks.
##### 4.2.6 \_get\_backward\_pre\_hooks
[TOC](#table-of-contents)
##### 4.2.7 \_get\_name
[TOC](#table-of-contents)
##### 4.2.8 \_load\_from\_state\_dict
[TOC](#table-of-contents)
**Description**
Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants.
This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
Additionally, :attr:`local_metadata` can also contain the key
`assign_to_params_buffers` that indicates whether keys should be
assigned their corresponding tensor in the state_dict.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
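
A sketch of version-aware loading via `local_metadata.get("version", None)`. `Gain` and its rename from `"g"` to `"gain"` are hypothetical; a plain dict carries no metadata, so it is treated as an old checkpoint.

```python
import torch
import torch.nn as nn

class Gain(nn.Module):
    """Hypothetical module whose parameter was renamed from "g" to "gain" in version 2."""

    _version = 2  # written into the state_dict metadata by state_dict()

    def __init__(self) -> None:
        super().__init__()
        self.gain = nn.Parameter(torch.ones(1))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            old_key = prefix + "g"
            if old_key in state_dict:
                state_dict[prefix + "gain"] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

m = Gain()
m.load_state_dict({"g": torch.tensor([3.0])})  # plain dict: no metadata, treated as old
print(m.gain.item())  # 3.0
```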
##### 4.2.9 \_maybe\_warn\_non\_full\_backward\_hook
[TOC](#table-of-contents)
##### 4.2.10 \_named\_members
[TOC](#table-of-contents)
**Description**
Help yield various names + members of modules.
##### 4.2.11 \_register\_load\_state\_dict\_pre\_hook
[TOC](#table-of-contents)
**Description**
See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details.
A subtle difference is that if ``with_module`` is set to ``False``, then the
hook will not take the ``module`` as the first argument whereas
:meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the
``module`` as the first argument.
Arguments:
hook (Callable): Callable hook that will be invoked before
loading the state dict.
with_module (bool, optional): Whether or not to pass the module
instance to the hook as the first parameter.
##### 4.2.12 \_register\_state\_dict\_hook
[TOC](#table-of-contents)
**Description**
Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method.
It should have the following signature::
hook(module, state_dict, prefix, local_metadata) -> None or state_dict
The registered hooks can modify the ``state_dict`` inplace or return a new one.
If a new ``state_dict`` is returned, it will only be respected if it is the root
module that :meth:`~nn.Module.state_dict` is called from.
##### 4.2.13 \_replicate\_for\_data\_parallel
[TOC](#table-of-contents)
##### 4.2.14 \_save\_to\_state\_dict
[TOC](#table-of-contents)
**Description**
Save module state to the `destination` dictionary.
The `destination` dictionary will contain the state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
##### 4.2.15 \_slow\_forward
[TOC](#table-of-contents)
##### 4.2.16 \_wrapped\_call\_impl
[TOC](#table-of-contents)
##### 4.2.17 add\_module
[TOC](#table-of-contents)
**Description**
Add a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (str): name of the child module. The child module can be
accessed from this module using the given name
module (Module): child module to be added to the module.
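
`add_module` is mainly useful when child names are computed at runtime. A sketch with a hypothetical `MLP`:

```python
from typing import Sequence

import torch
import torch.nn as nn

class MLP(nn.Module):
    """Hypothetical MLP whose layer names are generated at runtime."""

    def __init__(self, sizes: Sequence[int]) -> None:
        super().__init__()
        for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            # Same effect as setattr(self, f"layer{i}", ...), with an explicit name.
            self.add_module(f"layer{i}", nn.Linear(n_in, n_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.children():  # children() yields in registration order
            x = torch.relu(layer(x))
        return x

mlp = MLP([4, 8, 2])
print([name for name, _ in mlp.named_children()])  # ['layer0', 'layer1']
```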
##### 4.2.18 apply
[TOC](#table-of-contents)
**Description**
Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
Typical use includes initializing the parameters of a model
(see also :ref:`nn-init-doc`).
Args:
fn (:class:`Module` -> None): function to be applied to each submodule
Returns:
Module: self
Example::
>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
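
A more typical initializer uses `torch.nn.init`, whose in-place functions already run without gradient tracking. The Xavier/zeros choice here is just one common option.

```python
import torch.nn as nn

def init_weights(m: nn.Module) -> None:
    # apply() calls this on every submodule (children first), then on net itself.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1)).apply(init_weights)
```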
##### 4.2.19 bfloat16
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``bfloat16`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.2.20 buffers
[TOC](#table-of-contents)
**Description**
Return an iterator over module buffers.
Args:
recurse (bool): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
Yields:
torch.Tensor: module buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
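
A concrete case: `BatchNorm1d` keeps its running statistics as buffers, which `named_buffers()` lists alongside their dtypes and shapes.

```python
import torch.nn as nn

bn = nn.BatchNorm1d(3)
for name, buf in bn.named_buffers():
    print(name, buf.dtype, tuple(buf.shape))
# running_mean torch.float32 (3,)
# running_var torch.float32 (3,)
# num_batches_tracked torch.int64 ()
```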
##### 4.2.21 children
[TOC](#table-of-contents)
**Description**
Return an iterator over immediate children modules.
Yields:
Module: a child module
##### 4.2.22 compile
[TOC](#table-of-contents)
**Description**
Compile this Module's forward using :func:`torch.compile`.
This Module's `__call__` method is compiled and all arguments are passed as-is
to :func:`torch.compile`.
See :func:`torch.compile` for details on the arguments for this function.
##### 4.2.23 cpu
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the CPU.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.2.24 cuda
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on GPU while being optimized.
.. note::
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
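
A sketch of the ordering recommended above: move the module first, then build the optimizer. On a machine without CUDA the move is simply skipped.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
if torch.cuda.is_available():
    model.cuda()  # modifies the module in place (and also returns it)

# Create the optimizer only after the move, so it references the
# parameters in their final location.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```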
##### 4.2.25 double
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``double`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.2.26 eval
[TOC](#table-of-contents)
**Description**
Set the module in evaluation mode.
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e. whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.
See :ref:`locally-disable-grad-doc` for a comparison between
`.eval()` and several similar mechanisms that may be confused with it.
Returns:
Module: self
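
`eval()` changes layer behaviour but does not disable autograd. This sketch contrasts it with `torch.no_grad()`.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 3), nn.Dropout(0.2)).eval()

# eval() only switches layer behaviour; autograd still records operations.
y = model(torch.randn(1, 3))
print(y.requires_grad)  # True

# For inference, combine eval() with no_grad() (or torch.inference_mode()).
with torch.no_grad():
    y_inf = model(torch.randn(1, 3))
print(y_inf.requires_grad)  # False
```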
##### 4.2.27 extra\_repr
[TOC](#table-of-contents)
**Description**
Return the extra representation of the module.
To print customized extra information, you should re-implement
this method in your own modules. Both single-line and multi-line
strings are acceptable.
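
A sketch of a custom `extra_repr`; `Scale` is a hypothetical module. The returned string is placed inside the parentheses of the module's `repr`.

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    """Hypothetical module that multiplies its input by a constant."""

    def __init__(self, factor: float) -> None:
        super().__init__()
        self.factor = factor

    def extra_repr(self) -> str:
        return f"factor={self.factor}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor

print(Scale(2.0))  # Scale(factor=2.0)
```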
##### 4.2.28 float
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``float`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
### 4.2.29 forward
[TOC](#table-of-contents)
**Description**
Define the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
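
The note above can be checked directly. In this sketch a forward hook fires when the module is called, but not when `forward()` is invoked directly.

```python
import torch
import torch.nn as nn

calls = []
lin = nn.Linear(2, 2)
lin.register_forward_hook(lambda module, args, output: calls.append("hook"))

x = torch.randn(1, 2)
lin(x)          # __call__ runs forward() and the registered hooks
lin.forward(x)  # calling forward() directly silently skips the hooks
print(calls)    # ['hook']
```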
##### 4.2.30 get\_buffer
[TOC](#table-of-contents)
**Description**
Return the buffer given by ``target`` if it exists, otherwise throw an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the buffer
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
torch.Tensor: The buffer referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not a
buffer
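
A sketch of `get_buffer` on a `Sequential`, whose children are named `"0"`, `"1"`, and so on. Asking for a parameter by the same path raises `AttributeError`.

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

running_mean = model.get_buffer("1.running_mean")
print(running_mean.shape)  # torch.Size([3])

raised = False
try:
    model.get_buffer("1.weight")  # exists, but it is a Parameter, not a buffer
except AttributeError:
    raised = True
```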
##### 4.2.31 get\_extra\_state
[TOC](#table-of-contents)
**Description**
Return any extra state to include in the module's state_dict.
Implement this and a corresponding :func:`set_extra_state` for your module
if you need to store extra state. This function is called when building the
module's `state_dict()`.
Note that extra state should be picklable to ensure working serialization
of the state_dict. We only provide backwards compatibility guarantees
for serializing Tensors; other objects may break backwards compatibility if
their serialized pickled form changes.
Returns:
object: Any extra state to store in the module's state_dict
##### 4.2.32 get\_parameter
[TOC](#table-of-contents)
**Description**
Return the parameter given by ``target`` if it exists, otherwise throw an error.
See the docstring for ``get_submodule`` for a more detailed
explanation of this method's functionality as well as how to
correctly specify ``target``.
Args:
target: The fully-qualified string name of the Parameter
to look for. (See ``get_submodule`` for how to specify a
fully-qualified string.)
Returns:
torch.nn.Parameter: The Parameter referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Parameter``
##### 4.2.33 get\_submodule
[TOC](#table-of-contents)
**Description**
Return the submodule given by ``target`` if it exists, otherwise throw an error.
For example, let's say you have an ``nn.Module`` ``A`` that
looks like this:
.. code-block:: text

    A(
        (net_b): Module(
            (net_c): Module(
                (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
            )
            (linear): Linear(in_features=100, out_features=200, bias=True)
        )
    )

(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
submodule ``net_b``, which itself has two submodules ``net_c``
and ``linear``. ``net_c`` then has a submodule ``conv``.)
To check whether or not we have the ``linear`` submodule, we
would call ``get_submodule("net_b.linear")``. To check whether
we have the ``conv`` submodule, we would call
``get_submodule("net_b.net_c.conv")``.
The runtime of ``get_submodule`` is bounded by the degree
of module nesting in ``target``. A query against
``named_modules`` achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, ``get_submodule`` should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
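
The sketch below rebuilds the example tree `A` from the diagram with plain `nn.Module` containers and looks up one submodule and one parameter by dotted path. The container setup is illustrative. Only `get_submodule` and `get_parameter` are the documented `torch.nn.Module` methods.

```python
import torch.nn as nn

# Rebuild the tree from the diagram: A -> net_b -> {net_c -> conv, linear}.
net_c = nn.Module()
net_c.conv = nn.Conv2d(16, 33, kernel_size=3, stride=2)

net_b = nn.Module()
net_b.net_c = net_c
net_b.linear = nn.Linear(100, 200)

A = nn.Module()
A.net_b = net_b

conv = A.get_submodule("net_b.net_c.conv")       # walks net_b -> net_c -> conv
weight = A.get_parameter("net_b.linear.weight")  # walks to the Linear, then its weight

print(conv)          # Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
print(weight.shape)  # torch.Size([200, 100])

# An invalid path raises AttributeError instead of returning None.
try:
    A.get_submodule("net_b.missing")
    lookup_failed = False
except AttributeError:
    lookup_failed = True
```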
##### 4.2.34 half
[TOC](#table-of-contents)
**Description**
Casts all floating point parameters and buffers to ``half`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
##### 4.2.35 ipu
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the IPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on IPU while being optimized.
.. note::
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
##### 4.2.36 load\_state\_dict
[TOC](#table-of-contents)
**Description**
Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
.. warning::
If :attr:`assign` is ``True`` the optimizer must be created after
the call to :meth:`load_state_dict` unless
:func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
assign (bool, optional): When set to ``False``, the properties of the tensors
in the current module are preserved whereas setting it to ``True`` preserves
properties of the Tensors in the state dict. The only
exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
for which the value from the module is preserved.
Default: ``False``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing any keys that are expected
by this module but missing from the provided ``state_dict``.
* **unexpected_keys** is a list of str containing the keys that are not
expected by this module but present in the provided ``state_dict``.
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
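
This is a minimal sketch of the `strict` flag using two freshly initialized `nn.Linear(4, 2)` layers. The layer sizes are arbitrary. A full copy succeeds with the default `strict=True`. A dict that carries only `weight` loads with `strict=False`, and `bias` is reported in `missing_keys`.

```python
import torch
import torch.nn as nn

src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)

# Full copy: the keys match exactly, so the default strict=True succeeds.
full = dst.load_state_dict(src.state_dict())
print(full.missing_keys, full.unexpected_keys)  # [] []

# Partial copy: only "weight" is supplied; strict=False tolerates the gap.
partial = {"weight": torch.zeros(2, 4)}
result = dst.load_state_dict(partial, strict=False)
print(result.missing_keys)  # ['bias']

# The same partial dict is rejected when strict=True.
try:
    dst.load_state_dict(partial)
    strict_failed = False
except RuntimeError:
    strict_failed = True
```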
##### 4.2.37 modules
[TOC](#table-of-contents)
**Description**
Return an iterator over all modules in the network.
Yields:
Module: a module in the network
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
... print(idx, '->', m)
0 -> Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
##### 4.2.38 mtia
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the MTIA.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on MTIA while being optimized.
.. note::
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
##### 4.2.39 named\_buffers
[TOC](#table-of-contents)
**Description**
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool, optional): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Defaults to True.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields:
(str, torch.Tensor): Tuple containing the name and buffer
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
...     if name in ['running_var']:
...         print(buf.size())
##### 4.2.40 named\_children
[TOC](#table-of-contents)
**Description**
Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
Yields:
(str, Module): Tuple containing a name and child module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
...     if name in ['conv4', 'conv5']:
...         print(module)
##### 4.2.41 named\_modules
[TOC](#table-of-contents)
**Description**
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
... print(idx, '->', m)
0 -> ('', Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
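
The sketch below contrasts the default duplicate handling with `remove_duplicate=False`. It reuses the shared-`Linear` setup from the example above. With duplicates kept, the same instance is reported once per name under which it is registered.

```python
import torch.nn as nn

l = nn.Linear(2, 2)
net = nn.Sequential(l, l)  # the same Linear instance registered twice

# Default: the shared instance is reported once, under its first name.
names = [name for name, _ in net.named_modules()]
print(names)  # ['', '0']

# remove_duplicate=False reports every registration of the shared instance.
all_names = [name for name, _ in net.named_modules(remove_duplicate=False)]
print(all_names)  # ['', '0', '1']
```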
##### 4.2.42 named\_parameters
[TOC](#table-of-contents)
**Description**
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
Args:
prefix (str): prefix to prepend to all parameter names.
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
remove_duplicate (bool, optional): whether to remove the duplicated
parameters in the result. Defaults to True.
Yields:
(str, Parameter): Tuple containing the name and parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
...     if name in ['bias']:
...         print(param.size())
##### 4.2.43 parameters
[TOC](#table-of-contents)
**Description**
Return an iterator over module parameters.
This is typically passed to an optimizer.
Args:
recurse (bool): if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
Yields:
Parameter: module parameter
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
...     print(type(param), param.size())
<class 'torch.nn.parameter.Parameter'> torch.Size([20])
<class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])
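
This is a short, self-contained sketch of the typical uses of `parameters()` and `named_parameters()`. The small two-layer model is an assumption made for illustration. The sketch counts weights, selects parameters by name, and hands the iterator to an optimizer.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

# parameters() yields Parameter objects; here they are used to count weights.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 10*20 + 20 + 20*1 + 1 = 241

# named_parameters() adds dotted names, e.g. to select only bias terms.
bias_names = [name for name, _ in model.named_parameters() if name.endswith("bias")]
print(bias_names)  # ['0.bias', '2.bias']

# The usual hand-off: pass the parameter iterator to an optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```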
##### 4.2.44 register\_backward\_hook
[TOC](#table-of-contents)
**Description**
Register a backward hook on the module.
This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and
the behavior of this function will change in future versions.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
##### 4.2.45 register\_buffer
[TOC](#table-of-contents)
**Description**
Add a buffer to the module.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the module's state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting :attr:`persistent` to ``False``. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module's
:attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (str): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor or None): buffer to be registered. If ``None``, then operations
that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
the buffer is **not** included in the module's :attr:`state_dict`.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
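
The sketch below shows persistent and non-persistent buffers in a hypothetical `RunningMean` module. The class and its update rule are illustrative and not part of this package. Only the persistent buffer appears in `state_dict`. Neither buffer is returned by `parameters()`.

```python
import torch
import torch.nn as nn

class RunningMean(nn.Module):
    """Hypothetical module that tracks a running mean of its inputs."""

    def __init__(self, num_features: int, momentum: float = 0.1):
        super().__init__()
        self.momentum = momentum
        # Persistent buffer: saved in state_dict, but not a parameter.
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Non-persistent buffer: moves with .to()/.cuda(), excluded from state_dict.
        self.register_buffer("scratch", torch.zeros(num_features), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * x.mean(dim=0))
        return x - self.running_mean

m = RunningMean(3)
print(list(m.state_dict().keys()))  # ['running_mean']
print(len(list(m.parameters())))    # 0
m(torch.ones(4, 3))                 # running_mean becomes 0.1 everywhere
```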
##### 4.2.46 register\_forward\_hook
[TOC](#table-of-contents)
**Description**
Register a forward hook on the module.
The hook will be called every time after :func:`forward` has computed an output.
If ``with_kwargs`` is ``False`` or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won't be
passed to the hooks and only to the ``forward``. The hook can modify the
output. It can modify the input in place, but that has no effect on
forward because the hook runs after :func:`forward` has been called. The hook
should have the following signature::
hook(module, args, output) -> None or modified output
If ``with_kwargs`` is ``True``, the forward hook will be passed the
``kwargs`` given to the forward function and be expected to return the
output possibly modified. The hook should have the following signature::
hook(module, args, kwargs, output) -> None or modified output
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If ``True``, the provided ``hook`` will be fired
before all existing ``forward`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``forward`` hooks on
this :class:`torch.nn.modules.Module`. Note that global
``forward`` hooks registered with
:func:`register_module_forward_hook` will fire before all hooks
registered by this method.
Default: ``False``
with_kwargs (bool): If ``True``, the ``hook`` will be passed the
kwargs given to the forward function.
Default: ``False``
always_call (bool): If ``True`` the ``hook`` will be run regardless of
whether an exception is raised while calling the Module.
Default: ``False``
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
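
A common use of forward hooks is capturing intermediate activations. In the sketch below, the three-layer `nn.Sequential` model and the `activations` dict are illustrative. The hook records the ReLU output, and the returned handle then removes it.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
activations = {}

def save_output(module, args, output):
    # Runs after forward(); returning None leaves the output unchanged.
    activations["relu"] = output.detach()

handle = model[1].register_forward_hook(save_output)
y = model(torch.randn(5, 4))
relu_shape = tuple(activations["relu"].shape)
print(relu_shape)  # (5, 8)

handle.remove()    # detach the hook; later calls no longer record anything
activations.clear()
model(torch.randn(5, 4))
print(activations)  # {}
```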
##### 4.2.47 register\_forward\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a forward pre-hook on the module.
The hook will be called every time before :func:`forward` is invoked.
If ``with_kwargs`` is false or not specified, the input contains only
the positional arguments given to the module. Keyword arguments won't be
passed to the hooks and only to the ``forward``. The hook can modify the
input. User can either return a tuple or a single modified value in the
hook. We will wrap the value into a tuple if a single value is returned
(unless that value is already a tuple). The hook should have the
following signature::
hook(module, args) -> None or modified input
If ``with_kwargs`` is true, the forward pre-hook will be passed the
kwargs given to the forward function. And if the hook modifies the
input, both the args and kwargs should be returned. The hook should have
the following signature::
hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If true, the provided ``hook`` will be fired before
all existing ``forward_pre`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``forward_pre`` hooks
on this :class:`torch.nn.modules.Module`. Note that global
``forward_pre`` hooks registered with
:func:`register_module_forward_pre_hook` will fire before all
hooks registered by this method.
Default: ``False``
with_kwargs (bool): If true, the ``hook`` will be passed the kwargs
given to the forward function.
Default: ``False``
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
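
The sketch below uses a forward pre-hook to rewrite the positional input before `forward` runs. The weights are fixed to 1 so that the layer simply sums its inputs. The `double_input` hook is hypothetical. It returns a single tensor, which PyTorch wraps into a tuple as described above.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    layer.weight.fill_(1.0)  # output = sum of the inputs, for easy checking

def double_input(module, args):
    # args is the tuple of positional inputs; a single returned value
    # is wrapped into a tuple before forward() is called.
    return args[0] * 2

handle = layer.register_forward_pre_hook(double_input)
out = layer(torch.ones(1, 3))        # sum of doubled inputs: 6.0
handle.remove()
out_after = layer(torch.ones(1, 3))  # hook removed: 3.0
```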
##### 4.2.48 register\_full\_backward\_hook
[TOC](#table-of-contents)
**Description**
Register a backward hook on the module.
The hook will be called every time the gradients with respect to a module
are computed, i.e. the hook will execute if and only if the gradients with
respect to module outputs are computed. The hook should have the following
signature::
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients
with respect to the inputs and outputs respectively. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the input that will be used in place of :attr:`grad_input` in
subsequent computations. :attr:`grad_input` will only correspond to the inputs given
as positional arguments and all kwarg arguments are ignored. Entries
in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module's forward function.
.. warning ::
Modifying inputs or outputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided ``hook`` will be fired before
all existing ``backward`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``backward`` hooks on
this :class:`torch.nn.modules.Module`. Note that global
``backward`` hooks registered with
:func:`register_module_full_backward_hook` will fire before
all hooks registered by this method.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
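
The sketch below uses a full backward hook to log gradient magnitudes without changing them. The `record_grad` function is hypothetical. Because the loss is a plain sum, the gradient with respect to the `(4, 2)` output is all ones, so its norm is `sqrt(8)`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
grad_norms = []

def record_grad(module, grad_input, grad_output):
    # grad_output[0] is dLoss/dOutput; returning None keeps grad_input unchanged.
    grad_norms.append(grad_output[0].norm().item())

handle = layer.register_full_backward_hook(record_grad)
x = torch.randn(4, 3, requires_grad=True)
layer(x).sum().backward()  # dLoss/dOutput is a (4, 2) tensor of ones
handle.remove()
print(grad_norms)  # one entry, approximately 2.8284 (sqrt(8))
```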
##### 4.2.49 register\_full\_backward\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a backward pre-hook on the module.
The hook will be called every time the gradients for the module are computed.
The hook should have the following signature::
hook(module, grad_output) -> tuple[Tensor] or None
The :attr:`grad_output` is a tuple. The hook should
not modify its arguments, but it can optionally return a new gradient with
respect to the output that will be used in place of :attr:`grad_output` in
subsequent computations. Entries in :attr:`grad_output` will be ``None`` for
all non-Tensor arguments.
For technical reasons, when this hook is applied to a Module, its forward function will
receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
of each Tensor returned by the Module's forward function.
.. warning ::
Modifying inputs inplace is not allowed when using backward hooks and
will raise an error.
Args:
hook (Callable): The user-defined hook to be registered.
prepend (bool): If true, the provided ``hook`` will be fired before
all existing ``backward_pre`` hooks on this
:class:`torch.nn.modules.Module`. Otherwise, the provided
``hook`` will be fired after all existing ``backward_pre`` hooks
on this :class:`torch.nn.modules.Module`. Note that global
``backward_pre`` hooks registered with
:func:`register_module_full_backward_pre_hook` will fire before
all hooks registered by this method.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
##### 4.2.50 register\_load\_state\_dict\_post\_hook
[TOC](#table-of-contents)
**Description**
Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called.
It should have the following signature::
hook(module, incompatible_keys) -> None
The ``module`` argument is the current module that this hook is registered
on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
is a ``list`` of ``str`` containing the missing keys and
``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
The given incompatible_keys can be modified inplace if needed.
Note that the checks performed when calling :func:`load_state_dict` with
``strict=True`` are affected by modifications the hook makes to
``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
set of keys will result in an error being thrown when ``strict=True``, and
clearing out both missing and unexpected keys will avoid an error.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
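
The sketch below relies on the behaviour described above, where a post-hook edits `missing_keys` in place. The hypothetical `allow_missing_bias` hook removes `bias` from the missing keys. As a result, a `strict=True` load of a weight-only dict succeeds while the hook is registered.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def allow_missing_bias(module, incompatible_keys):
    # Mutating the list in place changes what the strict check sees afterwards.
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")

handle = layer.register_load_state_dict_post_hook(allow_missing_bias)
# Without the hook, strict=True would fail because "bias" is missing.
result = layer.load_state_dict({"weight": torch.zeros(2, 4)}, strict=True)
handle.remove()
```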
##### 4.2.51 register\_load\_state\_dict\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called.
It should have the following signature::
hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
Arguments:
hook (Callable): Callable hook that will be invoked before
loading the state dict.
##### 4.2.52 register\_module
[TOC](#table-of-contents)
**Description**
Alias for :func:`add_module`.
##### 4.2.53 register\_parameter
[TOC](#table-of-contents)
**Description**
Add a parameter to the module.
The parameter can be accessed as an attribute using the given name.
Args:
name (str): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter or None): parameter to be added to the module. If
``None``, then operations that run on parameters, such as :attr:`cuda`,
are ignored. If ``None``, the parameter is **not** included in the
module's :attr:`state_dict`.
##### 4.2.54 register\_state\_dict\_post\_hook
[TOC](#table-of-contents)
**Description**
Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method.
It should have the following signature::
hook(module, state_dict, prefix, local_metadata) -> None
The registered hooks can modify the ``state_dict`` inplace.
##### 4.2.55 register\_state\_dict\_pre\_hook
[TOC](#table-of-contents)
**Description**
Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.
It should have the following signature::
hook(module, prefix, keep_vars) -> None
The registered hooks can be used to perform pre-processing before the ``state_dict``
call is made.
##### 4.2.56 requires\_grad\_
[TOC](#table-of-contents)
**Description**
Change whether autograd should record operations on parameters in this module.
This method sets the parameters' :attr:`requires_grad` attributes
in-place.
This method is helpful for freezing part of the module for finetuning
or training parts of a model individually (e.g., GAN training).
See :ref:`locally-disable-grad-doc` for a comparison between
`.requires_grad_()` and several similar mechanisms that may be confused with it.
Args:
requires_grad (bool): whether autograd should record operations on
parameters in this module. Default: ``True``.
Returns:
Module: self
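
The sketch below shows the freezing use case mentioned above. In this illustrative model, the first layer stands in for a pretrained feature extractor. The sketch freezes that layer and lists which parameters are still trainable.

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Freeze the first layer in place (requires_grad_ returns the module itself).
model[0].requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']
```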
##### 4.2.57 set\_extra\_state
[TOC](#table-of-contents)
**Description**
Set extra state contained in the loaded `state_dict`.
This function is called from :func:`load_state_dict` to handle any extra state
found within the `state_dict`. Implement this function and a corresponding
:func:`get_extra_state` for your module if you need to store extra state within its
`state_dict`.
Args:
state (dict): Extra state from the `state_dict`
##### 4.2.58 set\_submodule
[TOC](#table-of-contents)
**Description**
Set the submodule given by ``target`` if it exists, otherwise throw an error.
For example, let's say you have an ``nn.Module`` ``A`` that
looks like this:
.. code-block:: text
A(
(net_b): Module(
(net_c): Module(
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
)
(linear): Linear(in_features=100, out_features=200, bias=True)
)
)
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
submodule ``net_b``, which itself has two submodules ``net_c``
and ``linear``. ``net_c`` then has a submodule ``conv``.)
To override the ``Conv2d`` with a new submodule ``Linear``, you
would call
``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
module: The module to set the submodule to.
Raises:
ValueError: If the target string is empty
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
##### 4.2.59 share\_memory
[TOC](#table-of-contents)
**Description**
See :meth:`torch.Tensor.share_memory_`.
##### 4.2.60 state\_dict
[TOC](#table-of-contents)
**Description**
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
.. note::
The returned object is a shallow copy. It contains references
to the module's parameters and buffers.
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
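
The following runnable sketch uses a single `nn.Linear(2, 2)`. It illustrates the documented `keep_vars` and `prefix` arguments and the shallow-copy note. Detached tensors share storage with the live parameters.

```python
import torch.nn as nn

linear = nn.Linear(2, 2)

sd = linear.state_dict()
print(list(sd.keys()))             # ['weight', 'bias']
print(sd["weight"].requires_grad)  # False: detached from autograd by default

# keep_vars=True stores the Parameter objects themselves.
sd_vars = linear.state_dict(keep_vars=True)
print(sd_vars["weight"] is linear.weight)  # True

# Shallow copy: the detached tensor shares storage with the parameter.
print(sd["weight"].data_ptr() == linear.weight.data_ptr())  # True

# prefix is prepended to every key (pass it as a keyword argument).
print(list(linear.state_dict(prefix="fc.").keys()))  # ['fc.weight', 'fc.bias']
```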
##### 4.2.61 to
[TOC](#table-of-contents)
**Description**
Move and/or cast the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
:noindex:
.. function:: to(dtype, non_blocking=False)
:noindex:
.. function:: to(tensor, non_blocking=False)
:noindex:
.. function:: to(memory_format=torch.channels_last)
:noindex:
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point or complex :attr:`dtype` values. In addition, this method will
only cast the floating point or complex parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved to
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
.. note::
This method modifies the module in-place.
Args:
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
the parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
memory_format (:class:`torch.memory_format`): the desired memory
format for 4D parameters and buffers in this module (keyword
only argument)
Returns:
Module: self
Examples::
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16)
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j, 0.2382+0.j],
[ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
[0.6122+0.j, 0.1150+0.j],
[0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
##### 4.2.62 to\_empty
[TOC](#table-of-contents)
**Description**
Move the parameters and buffers to the specified device without copying storage.
Args:
device (:class:`torch.device`): The desired device of the parameters
and buffers in this module.
recurse (bool): Whether parameters and buffers of submodules should
be recursively moved to the specified device.
Returns:
Module: self
##### 4.2.63 train
[TOC](#table-of-contents)
**Description**
Set the module in training mode.
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
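
The sketch below uses `nn.Dropout` as an example of a module whose behaviour depends on the mode. In training mode it zeroes elements and rescales the rest by `1 / (1 - p)`. In evaluation mode (`train(False)`, equivalently `eval()`), it passes the input through unchanged.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()   # training mode: elements are zeroed at random, survivors scaled by 2
y_train = drop(x)
drop.eval()    # same as drop.train(False)
y_eval = drop(x)

print(drop.training)           # False
print(torch.equal(y_eval, x))  # True: dropout is the identity in eval mode
```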
##### 4.2.64 type
[TOC](#table-of-contents)
**Description**
Casts all parameters and buffers to :attr:`dst_type`.
.. note::
This method modifies the module in-place.
Args:
dst_type (type or string): the desired type
Returns:
Module: self
##### 4.2.65 xpu
[TOC](#table-of-contents)
**Description**
Move all model parameters and buffers to the XPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on XPU while being optimized.
.. note::
This method modifies the module in-place.
Arguments:
device (int, optional): if specified, all parameters will be
copied to that device
Returns:
Module: self
##### 4.2.66 zero\_grad
[TOC](#table-of-contents)
**Description**
Reset gradients of all model parameters.
See similar function under :class:`torch.optim.Optimizer` for more context.
Args:
set_to_none (bool): instead of setting to zero, set the grads to None.
See :meth:`torch.optim.Optimizer.zero_grad` for details.
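
The sketch below passes `set_to_none` explicitly, because its default has differed across PyTorch releases. It shows both behaviours: zero-filled gradient tensors versus gradients reset to `None`.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
model(torch.randn(2, 3)).sum().backward()
print(model.weight.grad is None)  # False: backward() filled .grad

model.zero_grad(set_to_none=False)  # keep the tensors, fill them with zeros
zeros_left = torch.count_nonzero(model.weight.grad).item()
print(zeros_left)  # 0

model.zero_grad(set_to_none=True)  # drop the gradient tensors entirely
print(model.weight.grad is None)   # True
```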
# 5 multi\_transforms
[TOC](#table-of-contents)
The module `rsp.ml.multi_transforms` is based on `torchvision.transforms`, which is made for single images. `rsp.ml.multi_transforms` extends this functionality by providing transformations for sequences of images, which can be useful for video augmentation.
## 5.1 BGR2GRAY : MultiTransform
[TOC](#table-of-contents)
**Description**
Converts a sequence of BGR images to grayscale images.
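
The conversion can be illustrated with plain NumPy. This is a sketch, not the `rsp.ml` implementation. It assumes the sequence is a `uint8` array of shape `(T, H, W, 3)` in BGR order. It uses the ITU-R BT.601 luma weights (`Y = 0.299 R + 0.587 G + 0.114 B`), which OpenCV also uses for BGR to gray. The function name is hypothetical.

```python
import numpy as np

def bgr_sequence_to_gray(frames: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: (T, H, W, 3) BGR uint8 -> (T, H, W) gray uint8."""
    weights = np.array([0.114, 0.587, 0.299])   # B, G, R order
    gray = frames.astype(np.float64) @ weights  # contract the channel axis
    return np.clip(np.rint(gray), 0, 255).astype(np.uint8)

frames = np.zeros((2, 4, 4, 3), dtype=np.uint8)
frames[..., 2] = 255  # pure red in BGR order
gray = bgr_sequence_to_gray(frames)
print(gray[0, 0, 0])  # 76 (0.299 * 255 = 76.245, rounded)
```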
### 5.1.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.1.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.2 BGR2RGB : MultiTransform
[TOC](#table-of-contents)
**Description**
Converts a sequence of BGR images to RGB images.
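
BGR to RGB conversion amounts to reversing the channel axis. The NumPy sketch below assumes a `(T, H, W, 3)` array layout and uses a hypothetical function name. The same operation also converts RGB back to BGR, which may be why `RGB2BGR` is documented as a subclass of `BGR2RGB`.

```python
import numpy as np

def swap_bgr_rgb(frames: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: reverse the last (channel) axis of a (T, H, W, 3) sequence."""
    return frames[..., ::-1].copy()

frames = np.zeros((3, 2, 2, 3), dtype=np.uint8)
frames[..., 0], frames[..., 1], frames[..., 2] = 10, 20, 30  # B, G, R
rgb = swap_bgr_rgb(frames)
print(rgb[0, 0, 0])  # [30 20 10]
```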
### 5.2.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.2.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.3 Brightness : MultiTransform
[TOC](#table-of-contents)
**Description**
MultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.
> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.
### 5.3.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.3.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.4 CenterCrop : MultiTransform
[TOC](#table-of-contents)
**Description**
Crops images at the center after upscaling them. The output has the same dimensions as the input.

### 5.4.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.4.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| max_scale | float | Images are scaled randomly between 1. and max_scale before cropping to original size. |
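
The `max_scale` behaviour can be sketched in NumPy, assuming a `(T, H, W, C)` array. This is a hypothetical illustration, not the `rsp.ml` code, and nearest-neighbour upscaling is an assumed choice of interpolation. One scale factor in `[1, max_scale]` is drawn per sequence so that every frame is cropped identically. The central window is then cut back to the original size.

```python
import numpy as np

def center_crop_sequence(frames, max_scale, rng=None):
    """Hypothetical sketch: upscale every frame by one random s in [1, max_scale],
    then cut out the central H x W window so the output shape equals the input shape."""
    rng = np.random.default_rng() if rng is None else rng
    _, H, W = frames.shape[:3]
    s = rng.uniform(1.0, max_scale)  # drawn once for the whole sequence
    new_h, new_w = int(round(H * s)), int(round(W * s))
    # Nearest-neighbour upscaling: map each output row/column back to a source index.
    rows = np.minimum((np.arange(new_h) / s).astype(int), H - 1)
    cols = np.minimum((np.arange(new_w) / s).astype(int), W - 1)
    scaled = frames[:, rows][:, :, cols]
    top, left = (new_h - H) // 2, (new_w - W) // 2
    return scaled[:, top:top + H, left:left + W]

rng = np.random.default_rng(0)
frame = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)
frames = np.stack([frame] * 4)  # identical frames
out = center_crop_sequence(frames, max_scale=1.5, rng=rng)
```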
## 5.5 Color : MultiTransform
[TOC](#table-of-contents)
**Description**
MultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.
> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.
### 5.5.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.5.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.6 Compose : builtins.object
[TOC](#table-of-contents)
**Description**
Composes several MultiTransforms together.
**Example**
```python
import rsp.ml.multi_transforms as t
transforms = t.Compose([
t.BGR2GRAY(),
t.Scale(0.5)
])
```
### 5.6.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
### 5.6.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| children | List[MultiTransform] | List of MultiTransforms to compose. |
## 5.7 GaussianNoise : MultiTransform
[TOC](#table-of-contents)
**Description**
MultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.
> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.
### 5.7.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.7.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.8 MultiTransform : builtins.object
[TOC](#table-of-contents)
**Description**
MultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.
> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.
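
The core idea can be shown with a small, library-independent sketch: the random decision is drawn once per call and then applied to every frame, so a video clip is never half flipped. `SequenceHorizontalFlip` is a hypothetical class written with NumPy; it is not the `rsp.ml` implementation and assumes frames are `(H, W, C)` arrays.

```python
import random
from typing import List, Optional

import numpy as np

# Hypothetical sketch of the MultiTransform idea: draw random parameters once
# per call, then apply them identically to every frame of the sequence.
class SequenceHorizontalFlip:
    def __init__(self, p: float = 0.5, seed: Optional[int] = None):
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        flip = self.rng.random() < self.p  # one decision for the whole sequence
        if not flip:
            return list(frames)
        # Frames are assumed to be (H, W, C); reverse the width axis.
        return [frame[:, ::-1] for frame in frames]

frames = [np.arange(6).reshape(2, 3, 1) + i for i in range(4)]  # 4 frames, (H, W, C)
flipped = SequenceHorizontalFlip(p=1.0)(frames)
```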
### 5.8.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.8.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.9 Normalize : MultiTransform
[TOC](#table-of-contents)
**Description**
Normalizes images with a per-channel mean and standard deviation. Given `mean = (mean[1], ..., mean[n])` and `std = (std[1], ..., std[n])` for `n` channels, this transform normalizes each channel of the input `torch.Tensor` as `output[channel] = (input[channel] - mean[channel]) / std[channel]`.
> Based on `torchvision.transforms.Normalize`.
### 5.9.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.9.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| mean | List[float] | Sequence of means for each channel. |
| std | List[float] | Sequence of standard deviations for each channel. |
| inplace | bool | Set to `True` to perform this operation in-place. |
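
The normalization formula can be checked with a short NumPy sketch. It assumes channel-first `(C, H, W)` frames, as in `torchvision`; the helper `normalize_sequence` is hypothetical and uses NumPy instead of `torch`.

```python
import numpy as np

def normalize_sequence(frames, mean, std):
    """Per-channel normalization of (C, H, W) frames: (x - mean[c]) / std[c]."""
    # Reshape to (C, 1, 1) so the statistics broadcast over height and width.
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return [(np.asarray(f, dtype=np.float32) - mean) / std for f in frames]

frames = [np.full((3, 2, 2), 0.5, dtype=np.float32) for _ in range(4)]
out = normalize_sequence(frames, mean=[0.5, 0.25, 0.0], std=[0.5, 0.25, 0.5])
print(out[0][:, 0, 0])  # [0. 1. 1.]
```

The same `mean` and `std` are applied to every frame, so all images of the sequence end up on the same scale.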
## 5.10 RGB2BGR : BGR2RGB
[TOC](#table-of-contents)
**Description**
Converts a sequence of RGB images to BGR images.
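
Because `RGB2BGR` derives from `BGR2RGB`, both conversions come down to reversing the channel order, and applying the swap twice restores the input. The NumPy sketch below shows this for channel-last `(H, W, 3)` frames. The helper name and the channel-last layout are assumptions for illustration.

```python
import numpy as np

def rgb_to_bgr_sequence(frames):
    """Reverse the channel axis of each (H, W, 3) frame: RGB -> BGR (and back)."""
    return [frame[..., ::-1] for frame in frames]

frame = np.zeros((1, 1, 3), dtype=np.uint8)
frame[0, 0] = (255, 128, 0)                 # R, G, B
bgr = rgb_to_bgr_sequence([frame])[0]       # B, G, R = 0, 128, 255
roundtrip = rgb_to_bgr_sequence([bgr])[0]   # swapping twice gives the original
```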
### 5.10.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.10.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.11 RandomCrop : MultiTransform
[TOC](#table-of-contents)
**Description**
Upscales the images and crops them at a random location, so the output keeps the original dimensions.

### 5.11.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.11.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| max_scale | float | Images are scaled by a random factor between 1.0 and `max_scale` before being cropped back to their original size. |
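
A minimal NumPy sketch of the upscale-then-crop idea, assuming `(H, W, C)` frames, nearest-neighbour upscaling, and one scale factor and crop window shared by the whole sequence (in the spirit of `MultiTransform`). The helper `random_upscale_crop` is hypothetical; the library may use a different interpolation.

```python
import random

import numpy as np

def random_upscale_crop(frames, max_scale, seed=None):
    """Upscale every (H, W, C) frame by one shared factor in [1, max_scale],
    then crop one shared random window back to the original size."""
    rng = random.Random(seed)
    h, w = frames[0].shape[:2]
    scale = rng.uniform(1.0, max_scale)
    new_h, new_w = int(round(h * scale)), int(round(w * scale))
    # Nearest-neighbour index maps from the upscaled grid to source pixels.
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    # One crop offset for the whole sequence.
    top = rng.randint(0, new_h - h)
    left = rng.randint(0, new_w - w)
    out = []
    for frame in frames:
        upscaled = frame[rows][:, cols]
        out.append(upscaled[top:top + h, left:left + w])
    return out

frames = [np.random.rand(8, 10, 3) for _ in range(3)]
cropped = random_upscale_crop(frames, max_scale=2.0, seed=0)
```

With `max_scale = 1.0` no upscaling happens and the frames pass through unchanged.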
## 5.12 RandomHorizontalFlip : MultiTransform
[TOC](#table-of-contents)
**Description**
Randomly flips the images of a sequence horizontally. Inherits from `MultiTransform`, so the same flip decision applies to every image of the sequence.
### 5.12.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.12.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.13 RandomVerticalFlip : MultiTransform
[TOC](#table-of-contents)
**Description**
Randomly flips the images of a sequence vertically. Inherits from `MultiTransform`, so the same flip decision applies to every image of the sequence.
### 5.13.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.13.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.14 RemoveBackgroundAI : MultiTransform
[TOC](#table-of-contents)
**Description**
Removes the background from the images of a sequence. Inherits from `MultiTransform`.
### 5.14.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.14.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.15 ReplaceBackground : MultiTransform
[TOC](#table-of-contents)
**Description**
Transformation for background replacement based on HSV values. Depth background replacement is supported as well. Backgrounds have to be passed as a list of tuples of RGB and depth images.
**Example**
```python
import cv2 as cv
from rsp.ml.dataset import TUCRID
import rsp.ml.multi_transforms as multi_transforms

USE_DEPTH_DATA = False
backgrounds = TUCRID.load_backgrounds(USE_DEPTH_DATA)

transforms_train = multi_transforms.Compose([
    multi_transforms.ReplaceBackground(
        backgrounds=backgrounds,
        hsv_filter=[(69, 87, 139, 255, 52, 255)],
        p=0.8
    ),
    multi_transforms.Stack()
])

tucrid = TUCRID('train', load_depth_data=USE_DEPTH_DATA, transforms=transforms_train)
for X, T in tucrid:
    # Each frame x is (C, H, W); permute to (H, W, C) for OpenCV display.
    for x in X:
        img = x.permute(1, 2, 0).numpy()
        cv.imshow('img', img)
        cv.waitKey(30)
```
### 5.15.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.15.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Transformation for background replacement based on HSV values. Depth background replacement is supported as well. Backgrounds have to be passed as a list of tuples of RGB and depth images.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| backgrounds | List[np.array] | List of background images |
| hsv_filter | List[tuple[int, int, int, int, int, int]] | List of HSV filters |
| p | float, default = 1. | Probability of applying the transformation |
| rotate | float, default = 5 | Maximum rotation angle |
| max_scale | float, default = 2 | Maximum scaling factor |
| max_noise | float, default = 0.002 | Maximum noise level |
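
The HSV-based part of the replacement can be sketched with NumPy masking. This sketch rests on several assumptions that the documentation does not state: each `hsv_filter` entry is read as `(h_min, h_max, s_min, s_max, v_min, v_max)`, pixels inside any range count as background, and an HSV copy of the frame is already available (OpenCV hue scale 0 to 179). The random rotation, scaling, noise and probability `p` of the real transform are omitted, and both helper names are hypothetical.

```python
import numpy as np

def hsv_background_mask(hsv, hsv_filter):
    """True where a pixel lies inside any HSV range (treated as background).

    Assumes each filter is (h_min, h_max, s_min, s_max, v_min, v_max)."""
    h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
    mask = np.zeros(hsv.shape[:2], dtype=bool)
    for h_min, h_max, s_min, s_max, v_min, v_max in hsv_filter:
        mask |= ((h >= h_min) & (h <= h_max)
                 & (s >= s_min) & (s <= s_max)
                 & (v >= v_min) & (v <= v_max))
    return mask

def replace_background(rgb, hsv, background, hsv_filter):
    """Copy background pixels wherever the HSV mask matches."""
    mask = hsv_background_mask(hsv, hsv_filter)
    out = rgb.copy()
    out[mask] = background[mask]
    return out

# Two pixels: the first lies in the filter range, the second does not.
hsv = np.array([[[75, 200, 100]], [[10, 200, 100]]], dtype=np.uint8)
rgb = np.full((2, 1, 3), 50, dtype=np.uint8)
background = np.full((2, 1, 3), 200, dtype=np.uint8)
out = replace_background(rgb, hsv, background, [(69, 87, 139, 255, 52, 255)])
```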
## 5.16 Resize : MultiTransform
[TOC](#table-of-contents)
**Description**
Resizes the images of a sequence. Inherits from `MultiTransform`.
### 5.16.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.16.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.17 Rotate : MultiTransform
[TOC](#table-of-contents)
**Description**
Randomly rotates all images of a sequence by the same angle.
**Equations**
$angle = -max\_angle + 2 \cdot random() \cdot max\_angle$
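
The angle formula maps a random number in $[0, 1)$ to the interval $[-max\_angle, max\_angle)$. A short sketch of the sampling step, assuming `random()` behaves like Python's `random.random()`; the helper `sample_angle` is hypothetical and, following `MultiTransform`, would be called once per sequence rather than once per image.

```python
import random

def sample_angle(max_angle, rng=random):
    """angle = -max_angle + 2 * random() * max_angle."""
    return -max_angle + 2 * rng.random() * max_angle

rng = random.Random(0)
angles = [sample_angle(10.0, rng) for _ in range(1000)]
print(all(-10.0 <= a <= 10.0 for a in angles))  # True
```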

### 5.17.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.17.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| max_angle | float | Maximum rotation in degrees; the sampled angle lies in [-max_angle, max_angle]. |
| auto_scale | bool, default = True | If `True`, images are rescaled after rotation to avoid black margins. |
## 5.18 Satturation : MultiTransform
[TOC](#table-of-contents)
**Description**
Adjusts the saturation of the images of a sequence. Inherits from `MultiTransform`.
### 5.18.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.18.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.19 Scale : MultiTransform
[TOC](#table-of-contents)
**Description**
Scales the images of a sequence by a factor, e.g. `Scale(0.5)` in the `Compose` example. Inherits from `MultiTransform`.
### 5.19.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.19.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.20 Stack : MultiTransform
[TOC](#table-of-contents)
**Description**
Stacks the images of a sequence into a single `torch.Tensor`. Inherits from `MultiTransform`.
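
In the `ReplaceBackground` example, `Stack` is the last step, and the loop then iterates over `X` and permutes each frame from `(C, H, W)` to `(H, W, C)`. Assuming `Stack` joins the frames along a new leading time axis (as `torch.stack(frames)` would), the resulting shapes can be illustrated with the NumPy analogue `np.stack`:

```python
import numpy as np

frames = [np.zeros((3, 4, 5), dtype=np.float32) for _ in range(8)]  # 8 frames, (C, H, W)
clip = np.stack(frames)  # new leading time axis: (T, C, H, W)
print(clip.shape)  # (8, 3, 4, 5)

# Iterating over the stacked clip yields one (C, H, W) frame at a time,
# which can be transposed to (H, W, C) for display.
first = clip[0].transpose(1, 2, 0)
print(first.shape)  # (4, 5, 3)
```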
### 5.20.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.20.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.21 ToCVImage : MultiTransform
[TOC](#table-of-contents)
**Description**
Converts a `torch.Tensor` to an OpenCV image by permuting the dimensions (d0, d1, d2) -> (d1, d2, d0) and converting the tensor to a `numpy` array.
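
With `torch`, this permutation is `tensor.permute(1, 2, 0).numpy()`, the same call used in the `ReplaceBackground` example. A NumPy-only sketch of the dimension change (the helper `to_cv_images` is hypothetical):

```python
import numpy as np

def to_cv_images(frames):
    """(d0, d1, d2) -> (d1, d2, d0) for each frame, e.g. (C, H, W) -> (H, W, C)."""
    return [np.transpose(np.asarray(f), (1, 2, 0)) for f in frames]

chw = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # (C, H, W)
hwc = to_cv_images([chw])[0]
print(hwc.shape)  # (3, 4, 2)
```

Each pixel keeps its values; only the axis order changes, so `hwc[h, w, c] == chw[c, h, w]`.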
### 5.21.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.21.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.22 ToNumpy : MultiTransform
[TOC](#table-of-contents)
**Description**
Converts a `torch.Tensor` to a `numpy` array.
### 5.22.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.22.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.23 ToPILImage : MultiTransform
[TOC](#table-of-contents)
**Description**
Converts a sequence of images to a sequence of `PIL.Image`.
### 5.23.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.23.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
## 5.24 ToTensor : MultiTransform
[TOC](#table-of-contents)
**Description**
Converts a sequence of images to a `torch.Tensor`.
### 5.24.1 \_\_call\_\_
[TOC](#table-of-contents)
**Description**
Call self as a function.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |
### 5.24.2 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Initializes a new instance.
# 6 run
[TOC](#table-of-contents)
The module `rsp.ml.run` provides tools for storing, loading and visualizing data while training PyTorch models.
## 6.1 Run : builtins.object
[TOC](#table-of-contents)
**Description**
Run class to store and manage the data of a training run.
**Example**
```python
from rsp.ml.run import Run
import rsp.ml.metrics as m

metrics = [
    m.top_1_accuracy
]
# Chart limits per metric: keys are metric names, values hold 'ymin' and 'ymax'.
config = {
    m.top_1_accuracy.__name__: {
        'ymin': 0,
        'ymax': 1
    }
}
run = Run(id='run0001', metrics=metrics, config=config, ignore_outliers_in_chart_scaling=True)
for epoch in range(100):
    # Training code goes here; it must produce `predictions` and `targets`.
    acc = m.top_1_accuracy(predictions, targets)
    run.append(m.top_1_accuracy.__name__, 'train', acc)
    run.save()  # writes runs/run0001/data.json
    run.plot()  # writes runs/run0001/plot/{key}.jpg
```
### 6.1.1 \_\_init\_\_
[TOC](#table-of-contents)
**Description**
Run class to store and manage the data of a training run.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| id | str, default = None | Id of the run. If None, a new id is generated |
| moving_average_epochs | int, default = 1 | Number of epochs to average over |
| metrics | list, default = None | List of metrics to compute. Each metric should be a function that takes Y and T as input. |
| device | str, default = None | torch device to run on |
| ignore_outliers_in_chart_scaling | bool, default = False | Ignore outliers when scaling charts |
| config | dict, default = {} | Configuration dictionary. Keys are metric names and values are dictionaries with keys 'ymin' and 'ymax' |
### 6.1.2 append
[TOC](#table-of-contents)
**Description**
Append value to key in phase.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| key | str | Key to append to |
| phase | str | Phase to append to |
| value | float | Value to append |
### 6.1.3 get\_avg
[TOC](#table-of-contents)
**Description**
Get the last average value of `key` in `phase`.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| key | str | Key to get |
| phase | str | Phase to get from |
**Returns**
Last average value of `key` in `phase`, or `np.nan` if `key` is not in the data : value : float
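
The `moving_average_epochs` parameter of `__init__` suggests that `get_avg` averages over the most recent values. The sketch below assumes a plain mean over the last `window` values and a hypothetical `data[phase][key]` list layout; it is not necessarily how `Run` stores or averages its data, and `last_moving_average` is a hypothetical helper.

```python
import math
from typing import Dict, List

def last_moving_average(data: Dict[str, Dict[str, List[float]]],
                        key: str, phase: str, window: int) -> float:
    """Mean of the last `window` values of `key` in `phase`; NaN if the key is missing."""
    values = data.get(phase, {}).get(key)
    if not values:
        return math.nan
    tail = values[-window:]  # fewer than `window` values early in training
    return sum(tail) / len(tail)

data = {'train': {'acc': [1.0, 2.0, 3.0, 5.0]}}
print(last_moving_average(data, 'acc', 'train', window=2))  # 4.0
```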
### 6.1.4 get\_val
[TOC](#table-of-contents)
**Description**
Get the last value of `key` in `phase`.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| key | str | Key to get |
| phase | str | Phase to get from |
**Returns**
Last value of `key` in `phase`, or `np.nan` if `key` is not in the data : value : float
### 6.1.5 len
[TOC](#table-of-contents)
**Description**
Get the length of the longest phase.
### 6.1.6 load\_best\_state\_dict
[TOC](#table-of-contents)
**Description**
Load best state_dict from runs/{id}/{fname}
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| model | torch.nn.Module | Model to load state_dict into |
| fname | str, default = 'state_dict.pt' | Filename to load from |
| verbose | bool, default = False | Print loaded file |
### 6.1.7 load\_state\_dict
[TOC](#table-of-contents)
**Description**
Load state_dict from runs/{id}/{fname}
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| model | torch.nn.Module | Model to load state_dict into |
| fname | str, default = None | Filename to load from |
### 6.1.8 pickle\_dump
[TOC](#table-of-contents)
**Description**
Pickle model to runs/{id}/{fname}
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| model | torch.nn.Module | Model to pickle |
| fname | str, default = 'model.pkl' | Filename to save to |
### 6.1.9 pickle\_load
[TOC](#table-of-contents)
**Description**
Load model from runs/{id}/{fname}
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| fname | str, default = 'model.pkl' | Filename to load from |
### 6.1.10 plot
[TOC](#table-of-contents)
**Description**
Plot all keys to runs/{id}/plot/{key}.jpg
### 6.1.11 recalculate\_moving\_average
[TOC](#table-of-contents)
**Description**
Recalculate moving average
### 6.1.12 save
[TOC](#table-of-contents)
**Description**
Save data to runs/{id}/data.json
### 6.1.13 save\_best\_state\_dict
[TOC](#table-of-contents)
**Description**
Save `state_dict` if `new_acc` is better than the previous best.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| state_dict | dict | State dict to save |
| new_acc | float | New accuracy |
| epoch | int, default = None | Epoch to save |
| fname | str, default = 'state_dict.pt' | Filename to save to |
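
The selection logic can be sketched as a small tracker that only writes a checkpoint when `new_acc` exceeds the best accuracy seen so far. `BestCheckpoint` and its `save_fn` callback are hypothetical; in a PyTorch setup, `save_fn` could wrap `torch.save(state_dict, path)`. Whether `Run` counts ties as improvements is not documented; this sketch requires a strict improvement.

```python
import math
from typing import Any, Callable, Dict, Optional

class BestCheckpoint:
    """Hypothetical sketch: call `save_fn` only when `new_acc` beats the best so far."""

    def __init__(self, save_fn: Callable[[Dict[str, Any], str], None]):
        self.save_fn = save_fn
        self.best_acc = -math.inf
        self.best_epoch: Optional[int] = None

    def update(self, state_dict: Dict[str, Any], new_acc: float,
               epoch: Optional[int] = None, fname: str = 'state_dict.pt') -> bool:
        if new_acc <= self.best_acc:
            return False  # not an improvement, nothing is written
        self.best_acc = new_acc
        self.best_epoch = epoch
        self.save_fn(state_dict, fname)
        return True

saved = []
ckpt = BestCheckpoint(lambda sd, fname: saved.append((sd, fname)))
ckpt.update({'w': 1}, 0.5, epoch=1)
ckpt.update({'w': 2}, 0.4, epoch=2)  # worse, skipped
ckpt.update({'w': 3}, 0.7, epoch=3)
print(len(saved))  # 2
```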
### 6.1.14 save\_state\_dict
[TOC](#table-of-contents)
**Description**
Save state_dict to runs/{id}/{fname}
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| state_dict | dict | State dict to save |
| fname | str, default = 'state_dict.pt' | Filename to save to |
### 6.1.15 train\_epoch
[TOC](#table-of-contents)
**Description**
Train one epoch.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| dataloader | DataLoader | DataLoader to train on |
| model | torch.nn.Module | Model to train |
| optimizer | torch.optim.Optimizer | Optimizer to use |
| criterion | torch.nn.Module | Criterion to use |
| num_batches | int, default = None | Number of batches to train on. If None, train on all batches |
| return_YT | bool, default = False | Append Y and T to results |
**Returns**
Dictionary with results : results : dict
### 6.1.16 validate\_epoch
[TOC](#table-of-contents)
**Description**
Validate one epoch.
**Parameters**
| Name | Type | Description |
|------|------|-------------|
| dataloader | DataLoader | DataLoader to validate on |
| model | torch.nn.Module | Model to validate |
| optimizer | torch.optim.Optimizer | Optimizer to use |
| criterion | torch.nn.Module | Criterion to use |
| num_batches | int, default = None | Number of batches to validate on. If None, validate on all batches |
| return_YT | bool, default = False | Append Y and T to results |
**Returns**
Dictionary with results : results : dict
\\_\\_init\\_\\_](#5192-\\_\\_init\\_\\_)\n - [5.20 Stack : MultiTransform](#520-stack--multitransform)\n - [5.20.1 \\_\\_call\\_\\_](#5201-\\_\\_call\\_\\_)\n - [5.20.2 \\_\\_init\\_\\_](#5202-\\_\\_init\\_\\_)\n - [5.21 ToCVImage : MultiTransform](#521-tocvimage--multitransform)\n - [5.21.1 \\_\\_call\\_\\_](#5211-\\_\\_call\\_\\_)\n - [5.21.2 \\_\\_init\\_\\_](#5212-\\_\\_init\\_\\_)\n - [5.22 ToNumpy : MultiTransform](#522-tonumpy--multitransform)\n - [5.22.1 \\_\\_call\\_\\_](#5221-\\_\\_call\\_\\_)\n - [5.22.2 \\_\\_init\\_\\_](#5222-\\_\\_init\\_\\_)\n - [5.23 ToPILImage : MultiTransform](#523-topilimage--multitransform)\n - [5.23.1 \\_\\_call\\_\\_](#5231-\\_\\_call\\_\\_)\n - [5.23.2 \\_\\_init\\_\\_](#5232-\\_\\_init\\_\\_)\n - [5.24 ToTensor : MultiTransform](#524-totensor--multitransform)\n - [5.24.1 \\_\\_call\\_\\_](#5241-\\_\\_call\\_\\_)\n - [5.24.2 \\_\\_init\\_\\_](#5242-\\_\\_init\\_\\_)\n- [6 run](#6-run)\n - [6.1 Run : builtins.object](#61-run--builtinsobject)\n - [6.1.1 \\_\\_init\\_\\_](#611-\\_\\_init\\_\\_)\n - [6.1.2 append](#612-append)\n - [6.1.3 get\\_avg](#613-get\\_avg)\n - [6.1.4 get\\_val](#614-get\\_val)\n - [6.1.5 len](#615-len)\n - [6.1.6 load\\_best\\_state\\_dict](#616-load\\_best\\_state\\_dict)\n - [6.1.7 load\\_state\\_dict](#617-load\\_state\\_dict)\n - [6.1.8 pickle\\_dump](#618-pickle\\_dump)\n - [6.1.9 pickle\\_load](#619-pickle\\_load)\n - [6.1.10 plot](#6110-plot)\n - [6.1.11 recalculate\\_moving\\_average](#6111-recalculate\\_moving\\_average)\n - [6.1.12 save](#6112-save)\n - [6.1.13 save\\_best\\_state\\_dict](#6113-save\\_best\\_state\\_dict)\n - [6.1.14 save\\_state\\_dict](#6114-save\\_state\\_dict)\n - [6.1.15 train\\_epoch](#6115-train\\_epoch)\n - [6.1.16 validate\\_epoch](#6116-validate\\_epoch)\n\n\n\n\n# 1 dataset\n\n[TOC](#table-of-contents)\n\n\n\n## 1.1 HMDB51 : torch.utils.data.dataset.Dataset\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDataset class for HMDB51.\n\n**Example**\n\n```python\nfrom 
rsp.ml.dataset import HMDB51\nimport rsp.ml.multi_transforms as multi_transforms\nimport cv2 as cv\n\ntransforms = multi_transforms.Compose([\n    multi_transforms.Color(1.5, p=0.5),\n    multi_transforms.Stack()\n])\nds = HMDB51('train', fold=1, transforms=transforms)\n\n# HMDB51 provides RGB frames only, so there is no depth channel to display\nfor X, T in ds:\n    for x in X.permute(0, 2, 3, 1):\n        img_color = x[:, :, :3].numpy()\n\n        cv.imshow('color', img_color)\n        cv.waitKey(30)\n```\n### 1.1.1 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| split | str | Dataset split [train\\|val\\|test] |\n| fold | int | Fold number. The dataset is split into 3 folds. If `fold` is `None`, all folds will be loaded. |\n| cache_dir | str, default = None | Directory to store the downloaded files. If set to `None`, the default cache directory will be used. |\n| force_reload | bool, default = False | If set to `True`, the dataset will be reloaded. |\n| target_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. |\n| sequence_length | int, default = 30 | Length of the sequences |\n| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |\n| verbose | bool, default = False | If set to `True`, the progress will be printed. 
|\n## 1.2 Kinetics : torch.utils.data.dataset.Dataset\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDataset class for the Kinetics dataset.\n\n**Example**\n\n```python\nfrom rsp.ml.dataset import Kinetics\n\nds = Kinetics(split='train', type=400)\n\nfor X, T in ds:\n    print(X)\n```\n### 1.2.1 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| split | str | Dataset split [train\\|val] |\n| sequence_length | int, default = 60 | Length of the sequences |\n| type | int, default = 400 | Type of the Kinetics dataset. Currently, only 400 is supported. |\n| frame_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. |\n| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |\n| cache_dir | str, default = None | Directory to store the downloaded files. If set to `None`, the default cache directory will be used. |\n| num_threads | int, default = 0 | Number of threads to use for downloading the files. |\n| verbose | bool, default = True | If set to `True`, the progress and additional information will be printed. 
|\n## 1.3 TUCHRI : torch.utils.data.dataset.Dataset\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDataset class for the Robot Interaction Dataset by Chemnitz University of Technology (TUCHRI).\n\n\n### 1.3.1 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| phase | str | Dataset phase [train\\|val] |\n| load_depth_data | bool, default = True | If set to `True`, the depth data will be loaded as well. |\n| sequence_length | int, default = 30 | Length of the sequences |\n| num_classes | int, default = 10 | Number of classes |\n| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |\n### 1.3.2 get\\_uniform\\_sampler\n\n[TOC](#table-of-contents)\n\n### 1.3.3 load\\_backgrounds\n\n[TOC](#table-of-contents)\n\n**Description**\n\nLoads the background images.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| load_depth_data | bool, default = True | If set to `True`, the depth images will be loaded as well. 
|\n## 1.4 TUCRID : torch.utils.data.dataset.Dataset\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDataset class for the Robot Interaction Dataset by Chemnitz University of Technology (TUCRID).\n\n**Example**\n\n```python\nfrom rsp.ml.dataset import TUCRID\nfrom rsp.ml.dataset import ReplaceBackgroundRGBD\nimport rsp.ml.multi_transforms as multi_transforms\nimport cv2 as cv\n\nbackgrounds = TUCRID.load_backgrounds_color()\ntransforms = multi_transforms.Compose([\n    ReplaceBackgroundRGBD(backgrounds),\n    multi_transforms.Stack()\n])\n\nds = TUCRID('train', transforms=transforms)\n\nfor X, T in ds:\n    for x in X.permute(0, 2, 3, 1):\n        img_color = x[:, :, :3].numpy()\n        img_depth = x[:, :, 3].numpy()\n\n        cv.imshow('color', img_color)\n        cv.imshow('depth', img_depth)\n\n        cv.waitKey(30)\n```\n### 1.4.1 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| phase | str | Dataset phase [train\\|val] |\n| load_depth_data | bool, default = True | If set to `True`, the depth data will be loaded as well. |\n| sequence_length | int, default = 30 | Length of the sequences |\n| num_classes | int, default = 10 | Number of classes |\n| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |\n### 1.4.2 get\\_uniform\\_sampler\n\n[TOC](#table-of-contents)\n\n### 1.4.3 load\\_backgrounds\n\n[TOC](#table-of-contents)\n\n**Description**\n\nLoads the background images.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| load_depth_data | bool, default = True | If set to `True`, the depth images will be loaded as well. 
|\n## 1.5 UCF101 : torch.utils.data.dataset.Dataset\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDataset class for UCF101.\n\n**Example**\n\n```python\nfrom rsp.ml.dataset import UCF101\nimport rsp.ml.multi_transforms as multi_transforms\nimport cv2 as cv\n\ntransforms = multi_transforms.Compose([\n    multi_transforms.Color(1.5, p=0.5),\n    multi_transforms.Stack()\n])\nds = UCF101('train', fold=1, transforms=transforms)\n\n# UCF101 provides RGB frames only, so there is no depth channel to display\nfor X, T in ds:\n    for x in X.permute(0, 2, 3, 1):\n        img_color = x[:, :, :3].numpy()\n\n        cv.imshow('color', img_color)\n        cv.waitKey(30)\n```\n### 1.5.1 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| split | str | Dataset split [train\\|val\\|test] |\n| fold | int | Fold number. The dataset is split into 3 folds. If `fold` is `None`, all folds will be loaded. |\n| cache_dir | str, default = None | Directory to store the downloaded files. 
If set to `None`, the default cache directory will be used. |\n| force_reload | bool, default = False | If set to `True`, the dataset will be reloaded. |\n| target_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. |\n| sequence_length | int, default = 30 | Length of the sequences |\n| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |\n| verbose | bool, default = False | If set to `True`, the progress will be printed. |\n## 1.6 UTKinectAction3D : torch.utils.data.dataset.Dataset\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDataset class for the UTKinectAction3D dataset.\n\n**Example**\n\n```python\nfrom rsp.ml.dataset import UTKinectAction3D\nimport rsp.ml.multi_transforms as multi_transforms\nimport cv2 as cv\n\ntransforms = multi_transforms.Compose([\n    multi_transforms.Color(1.5, p=0.5),\n    multi_transforms.Stack()\n])\nds = UTKinectAction3D('train', transforms=transforms)\n\nfor X, T in ds:\n    for x in X.permute(0, 2, 3, 1):\n        img_color = x[:, :, :3].numpy()\n        img_depth = x[:, :, 3].numpy()\n\n        cv.imshow('color', img_color)\n        cv.imshow('depth', img_depth)\n\n        cv.waitKey(30)\n```\n### 1.6.1 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| split | str | Dataset split [train\\|val] |\n| cache_dir | str, default = None | Directory to store the downloaded files. If set to `None`, the default cache directory will be used. |\n| force_reload | bool, default = False | If set to `True`, the dataset will be reloaded. |\n| target_size | (int, int), default = (400, 400) | Size of the frames. The frames will be resized to this size. 
|\n| sequence_length | int, default = 30 | Length of the sequences |\n| transforms | rsp.ml.multi_transforms.Compose, default = rsp.ml.multi_transforms.Compose([]) | Transformations that will be applied to each input sequence. See the documentation of `rsp.ml.multi_transforms` for more details. |\n| verbose | bool, default = False | If set to `True`, the progress will be printed. |\n\n# 2 metrics\n\n[TOC](#table-of-contents)\n\nThe module `rsp.ml.metrics` provides functionality to quantify the quality of predictions.\n\n## 2.1 AUROC\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCalculates the area under the receiver operating characteristic (ROC) curve.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| num_thresholds | int, default = 100 | Number of thresholds to compute. |\n\n**Returns**\n\nArea under the receiver operating characteristic curve : float\n\n## 2.2 F1\\_Score\n\n[TOC](#table-of-contents)\n\n**Description**\n\nF1 Score. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. 
|\n\n**Returns**\n\nF1 Score : float\n\n**Equations**\n\n$precision = \\frac{TP}{TP + FP}$\n\n$recall = \\frac{TP}{TP + FN}$\n\n$F_1 = \\frac{2 \\cdot precision \\cdot recall}{precision + recall} = \\frac{2 \\cdot TP}{2 \\cdot TP + FP + FN}$\n\n\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\nf1score = m.F1_Score(Y, T)\n\n# TP = 5, FP = 1, FN = 1  ->  2 * 5 / (2 * 5 + 1 + 1) = 10 / 12\nprint(f1score)  # -> 0.8333...\n```\n\n## 2.3 FN\n\n[TOC](#table-of-contents)\n\n**Description**\n\nFalse negatives. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |\n\n**Returns**\n\nFalse negatives : int\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\nfn = m.FN(Y, T)\nprint(fn)  # -> 1\n```\n\n## 2.4 FP\n\n[TOC](#table-of-contents)\n\n**Description**\n\nFalse positives. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. 
|\n\n**Returns**\n\nFalse positives : int\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\nfp = m.FP(Y, T)\nprint(fp)  # -> 1\n```\n\n## 2.5 FPR\n\n[TOC](#table-of-contents)\n\n**Description**\n\nFalse positive rate. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |\n\n**Returns**\n\nFalse positive rate : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\nfpr = m.FPR(Y, T)\nprint(fpr)  # -> 0.08333333333333333\n```\n\n## 2.6 ROC\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCalculates the receiver operating characteristic (ROC): computes the false positive rates and true positive rates for `num_thresholds` thresholds spaced between 0 and 1.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| num_thresholds | int, default = 100 | Number of thresholds to compute. 
|\n\n**Returns**\n\n(False Positive Rates, True Positive Rates) for `num_thresholds` different thresholds : (List[float], List[float])\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\nimport torch.nn.functional as F\n\nnum_elements = 100000\nnum_classes = 7\n\n# one-hot encoded random ground truth, shape (num_elements, num_classes)\ntrue_classes = torch.randint(0, num_classes, (num_elements,))\nT = F.one_hot(true_classes, num_classes=num_classes)\n\n# noisy predictions centered on the true classes\ndist = torch.normal(T.float(), 1.5)\nY = F.softmax(dist, dim=1)\nFPRs, TPRs = m.ROC(Y, T)\n```\n\n## 2.7 TN\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTrue negatives. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |\n\n**Returns**\n\nTrue negatives : int\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntn = m.TN(Y, T)\nprint(tn)  # -> 11\n```\n\n## 2.8 TP\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTrue positives. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. 
|\n\n**Returns**\n\nTrue positives : int\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntp = m.TP(Y, T)\nprint(tp)  # -> 5\n```\n\n## 2.9 TPR\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTrue positive rate. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |\n\n**Returns**\n\nTrue positive rate : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntpr = m.TPR(Y, T)\nprint(tpr)  # -> 0.8333333333333334\n```\n\n## 2.10 confusion\\_matrix\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCalculates the confusion matrix. 
Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n\n**Returns**\n\nConfusion matrix (rows: true labels, columns: predicted labels) : torch.Tensor\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\nconf_mat = m.confusion_matrix(Y, T)\nprint(conf_mat)\n# -> tensor([[1, 1, 0],\n#            [0, 2, 0],\n#            [0, 0, 2]])\n```\n\n## 2.11 plot\\_ROC\n\n[TOC](#table-of-contents)\n\n**Description**\n\nPlot the receiver operating characteristic.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| num_thresholds | int, default = 100 | Number of thresholds to compute. |\n| title | str, optional, default = 'Confusion Matrix' | Title of the plot |\n| class_curves | bool, default = False | Plot the ROC curve for each class. |\n| labels | str, optional, default = None | Class labels. If `None`, the classes are labeled automatically as C000, ..., CXXX. |\n| plt_show | bool, optional, default = False | Set to `True` to show the plot. |\n| save_file_name | str, optional, default = None | If not `None`, the plot is saved under the specified save_file_name. 
|\n\n**Returns**\n\nImage of the ROC plot : np.array\n\n\n## 2.12 plot\\_confusion\\_matrix\n\n[TOC](#table-of-contents)\n\n**Description**\n\nPlot the confusion matrix.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| confusion_matrix | torch.Tensor | Confusion matrix |\n| labels | str, optional, default = None | Class labels. If `None`, the classes are labeled automatically as C000, ..., CXXX. |\n| cmap | str, optional, default = 'Blues' | Seaborn colormap, see https://r02b.github.io/seaborn_palettes/ |\n| xlabel | str, optional, default = 'Predicted label' | X-axis label |\n| ylabel | str, optional, default = 'True label' | Y-axis label |\n| title | str, optional, default = 'Confusion Matrix' | Title of the plot |\n| plt_show | bool, optional, default = False | Set to `True` to show the plot. |\n| save_file_name | str, optional, default = None | If not `None`, the plot is saved under the specified save_file_name. |\n\n**Returns**\n\nImage of the confusion matrix : np.array\n\n\n## 2.13 precision\n\n[TOC](#table-of-contents)\n\n**Description**\n\nPrecision. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |\n\n**Returns**\n\nPrecision : float\n\n**Equations**\n\n$precision = \\frac{TP}{TP + FP}$\n\n\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\nprecision = m.precision(Y, T)\nprint(precision)  # -> 0.8333333333333334\n```\n\n## 2.14 recall\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRecall. 
Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| threshold | float | All values that are greater than or equal to the threshold are considered a positive class. |\n\n**Returns**\n\nRecall : float\n\n**Equations**\n\n$recall = \\frac{TP}{TP + FN}$\n\n\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\nrecall = m.recall(Y, T)\nprint(recall)  # -> 0.8333333333333334\n```\n\n## 2.15 top\\_10\\_accuracy\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTop 10 accuracy. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n\n**Returns**\n\nTop 10 accuracy (top k accuracy with k = 10) : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntop_10_accuracy = m.top_10_accuracy(Y, T)\n\nprint(top_10_accuracy)  # -> 1.0\n```\n\n## 2.16 top\\_1\\_accuracy\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTop 1 accuracy. 
Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n\n**Returns**\n\nTop 1 accuracy (top k accuracy with k = 1) : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntop_1_accuracy = m.top_1_accuracy(Y, T)\n\nprint(top_1_accuracy)  # -> 0.8333333333333334\n```\n\n## 2.17 top\\_2\\_accuracy\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTop 2 accuracy. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n\n**Returns**\n\nTop 2 accuracy (top k accuracy with k = 2) : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntop_2_accuracy = m.top_2_accuracy(Y, T)\n\nprint(top_2_accuracy)  # -> 1.0\n```\n\n## 2.18 top\\_3\\_accuracy\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTop 3 accuracy. 
Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n\n**Returns**\n\nTop 3 accuracy (top k accuracy with k = 3) : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntop_3_accuracy = m.top_3_accuracy(Y, T)\n\nprint(top_3_accuracy)  # -> 1.0\n```\n\n## 2.19 top\\_5\\_accuracy\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTop 5 accuracy. Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n\n**Returns**\n\nTop 5 accuracy (top k accuracy with k = 5) : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntop_5_accuracy = m.top_5_accuracy(Y, T)\n\nprint(top_5_accuracy)  # -> 1.0\n```\n\n## 2.20 top\\_k\\_accuracy\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTop k accuracy. 
Expected input shape: (batch_size, num_classes)\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| Y | torch.Tensor | Prediction |\n| T | torch.Tensor | True values |\n| k | int | Number of highest-scoring classes considered. A sample counts as correct if its true class is among them. |\n\n**Returns**\n\nTop k accuracy : float\n\n**Example**\n\n```python\nimport rsp.ml.metrics as m\nimport torch\n\nY = torch.tensor([\n    [0.1, 0.1, 0.8],\n    [0.03, 0.95, 0.02],\n    [0.05, 0.9, 0.05],\n    [0.01, 0.87, 0.12],\n    [0.04, 0.03, 0.93],\n    [0.94, 0.02, 0.06]\n])\nT = torch.tensor([\n    [0, 0, 1],\n    [1, 0, 0],\n    [0, 1, 0],\n    [0, 1, 0],\n    [0, 0, 1],\n    [1, 0, 0]\n])\n\ntop_k_accuracy = m.top_k_accuracy(Y, T, k = 3)\n\nprint(top_k_accuracy)  # -> 1.0\n```\n\n# 3 model\n\n[TOC](#table-of-contents)\n\nThe module `rsp.ml.model` provides functionality to store and load PyTorch models.\n\n## 3.1 MODELS : enum.Enum\n\n[TOC](#table-of-contents)\n\n**Description**\n\nEnumeration of the available model IDs, used as the `model` argument of `load_model`.\n\n\n## 3.2 WEIGHTS : enum.Enum\n\n[TOC](#table-of-contents)\n\n**Description**\n\nEnumeration of the available weight IDs, used as the `weights` argument of `load_model`.\n\n\n## 3.3 list\\_model\\_weights\n\n[TOC](#table-of-contents)\n\n**Description**\n\nLists all available weight files.\n\n\n**Returns**\n\nList of (MODEL, WEIGHT) name pairs : List[Tuple[str, str]]\n\n**Example**\n\n```python\nimport rsp.ml.model as model\n\nmodel_weight_files = model.list_model_weights()\n```\n\n## 3.4 load\\_model\n\n[TOC](#table-of-contents)\n\n**Description**\n\nLoads a pretrained PyTorch model from Hugging Face.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| model | MODELS | ID of the model |\n| weights | WEIGHTS | ID of the weights |\n\n**Returns**\n\nPretrained PyTorch model : torch.nn.Module\n\n**Example**\n\n```python\nimport rsp.ml.model as model\n\naction_recognition_model = 
model.load_model(model.MODELS.TUCARC3D, model.WEIGHTS.TUCAR)\n```\n\n## 3.5 publish\\_model\n\n[TOC](#table-of-contents)\n\n# 4 module\n\n[TOC](#table-of-contents)\n\n\n\n## 4.1 MultiHeadSelfAttention : torch.nn.modules.module.Module\n\n[TOC](#table-of-contents)\n\n**Description**\n\nBase class for all neural network modules.\n\nYour models should also subclass this class.\n\nModules can also contain other Modules, allowing them to be nested in\na tree structure. 
You can assign the submodules as regular attributes::\n\n import torch.nn as nn\n import torch.nn.functional as F\n\n class Model(nn.Module):\n def __init__(self) -> None:\n super().__init__()\n self.conv1 = nn.Conv2d(1, 20, 5)\n self.conv2 = nn.Conv2d(20, 20, 5)\n\n def forward(self, x):\n x = F.relu(self.conv1(x))\n return F.relu(self.conv2(x))\n\nSubmodules assigned in this way will be registered, and will also have their\nparameters converted when you call :meth:`to`, etc.\n\n.. note::\n As per the example above, an ``__init__()`` call to the parent class\n must be made before assignment on the child.\n\n:ivar training: Boolean represents whether this module is in training or\n evaluation mode.\n:vartype training: bool\n\n\n##### 4.1.1 \\_wrapped\\_call\\_impl\n\n[TOC](#table-of-contents)\n\n### 4.1.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitialize internal Module state, shared by both nn.Module and ScriptModule.\n\n##### 4.1.3 \\_apply\n\n[TOC](#table-of-contents)\n\n##### 4.1.4 \\_call\\_impl\n\n[TOC](#table-of-contents)\n\n##### 4.1.5 \\_get\\_backward\\_hooks\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the backward hooks for use in the call function.\n\nIt returns two lists, one with the full backward hooks and one with the non-full\n\nbackward hooks.\n\n##### 4.1.6 \\_get\\_backward\\_pre\\_hooks\n\n[TOC](#table-of-contents)\n\n##### 4.1.7 \\_get\\_name\n\n[TOC](#table-of-contents)\n\n##### 4.1.8 \\_load\\_from\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCopy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants.\n\nThis is called on every submodule\n\nin :meth:`~torch.nn.Module.load_state_dict`. 
Metadata saved for this\n\nmodule in input :attr:`state_dict` is provided as :attr:`local_metadata`.\n\nFor state dicts without metadata, :attr:`local_metadata` is empty.\n\nSubclasses can achieve class-specific backward compatible loading using\n\nthe version number at `local_metadata.get(\"version\", None)`.\n\nAdditionally, :attr:`local_metadata` can also contain the key\n\n`assign_to_params_buffers` that indicates whether keys should be\n\nassigned their corresponding tensor in the state_dict.\n\n.. note::\n\n :attr:`state_dict` is not the same object as the input\n\n :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So\n\n it can be modified.\n\nArgs:\n\n state_dict (dict): a dict containing parameters and\n\n persistent buffers.\n\n prefix (str): the prefix for parameters and buffers used in this\n\n module\n\n local_metadata (dict): a dict containing the metadata for this module.\n\n See\n\n strict (bool): whether to strictly enforce that the keys in\n\n :attr:`state_dict` with :attr:`prefix` match the names of\n\n parameters and buffers in this module\n\n missing_keys (list of str): if ``strict=True``, add missing keys to\n\n this list\n\n unexpected_keys (list of str): if ``strict=True``, add unexpected\n\n keys to this list\n\n error_msgs (list of str): error messages should be added to this\n\n list, and will be reported together in\n\n :meth:`~torch.nn.Module.load_state_dict`\n\n##### 4.1.9 \\_maybe\\_warn\\_non\\_full\\_backward\\_hook\n\n[TOC](#table-of-contents)\n\n##### 4.1.10 \\_named\\_members\n\n[TOC](#table-of-contents)\n\n**Description**\n\nHelp yield various names + members of modules.\n\n##### 4.1.11 \\_register\\_load\\_state\\_dict\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSee :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details.\n\nA subtle difference is that if ``with_module`` is set to ``False``, then the\n\nhook will not take the ``module`` as the first argument 
whereas\n\n:meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the\n\n``module`` as the first argument.\n\nArguments:\n\n hook (Callable): Callable hook that will be invoked before\n\n loading the state dict.\n\n with_module (bool, optional): Whether or not to pass the module\n\n instance to the hook as the first parameter.\n\n##### 4.1.12 \\_register\\_state\\_dict\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a post-hook for the :meth:`~torch.nn.Module.state_dict` method.\n\nIt should have the following signature::\n\n hook(module, state_dict, prefix, local_metadata) -> None or state_dict\n\nThe registered hooks can modify the ``state_dict`` inplace or return a new one.\n\nIf a new ``state_dict`` is returned, it will only be respected if it is the root\n\nmodule that :meth:`~nn.Module.state_dict` is called from.\n\n##### 4.1.13 \\_replicate\\_for\\_data\\_parallel\n\n[TOC](#table-of-contents)\n\n##### 4.1.14 \\_save\\_to\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSave module state to the `destination` dictionary.\n\nThe `destination` dictionary will contain the state\n\nof the module, but not its descendants. This is called on every\n\nsubmodule in :meth:`~torch.nn.Module.state_dict`.\n\nIn rare cases, subclasses can achieve class-specific behavior by\n\noverriding this method with custom logic.\n\nArgs:\n\n destination (dict): a dict where state will be stored\n\n prefix (str): the prefix for parameters and buffers used in this\n\n module\n\n##### 4.1.15 \\_slow\\_forward\n\n[TOC](#table-of-contents)\n\n##### 4.1.16 \\_wrapped\\_call\\_impl\n\n[TOC](#table-of-contents)\n\n##### 4.1.17 add\\_module\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAdd a child module to the current module.\n\nThe module can be accessed as an attribute using the given name.\n\nArgs:\n\n name (str): name of the child module. 
The child module can be\n\n accessed from this module using the given name\n\n module (Module): child module to be added to the module.\n\n##### 4.1.18 apply\n\n[TOC](#table-of-contents)\n\n**Description**\n\nApply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.\n\nTypical use includes initializing the parameters of a model\n\n(see also :ref:`nn-init-doc`).\n\nArgs:\n\n fn (:class:`Module` -> None): function to be applied to each submodule\n\nReturns:\n\n Module: self\n\nExample::\n\n >>> @torch.no_grad()\n\n >>> def init_weights(m):\n\n >>> print(m)\n\n >>> if type(m) == nn.Linear:\n\n >>> m.weight.fill_(1.0)\n\n >>> print(m.weight)\n\n >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n\n >>> net.apply(init_weights)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n Parameter containing:\n\n tensor([[1., 1.],\n\n [1., 1.]], requires_grad=True)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n Parameter containing:\n\n tensor([[1., 1.],\n\n [1., 1.]], requires_grad=True)\n\n Sequential(\n\n (0): Linear(in_features=2, out_features=2, bias=True)\n\n (1): Linear(in_features=2, out_features=2, bias=True)\n\n )\n\n##### 4.1.19 bfloat16\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``bfloat16`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.1.20 buffers\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module buffers.\n\nArgs:\n\n recurse (bool): if True, then yields buffers of this module\n\n and all submodules. 
Otherwise, yields only buffers that\n\n are direct members of this module.\n\nYields:\n\n torch.Tensor: module buffer\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for buf in model.buffers():\n\n >>> print(type(buf), buf.size())\n\n <class 'torch.Tensor'> (20L,)\n\n <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n\n##### 4.1.21 children\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over immediate children modules.\n\nYields:\n\n Module: a child module\n\n##### 4.1.22 compile\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCompile this Module's forward using :func:`torch.compile`.\n\nThis Module's `__call__` method is compiled and all arguments are passed as-is\n\nto :func:`torch.compile`.\n\nSee :func:`torch.compile` for details on the arguments for this function.\n\n##### 4.1.23 cpu\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the CPU.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.1.24 cuda\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the GPU.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing the optimizer if the module will\n\nlive on GPU while being optimized.\n\n.. note::\n\n This method modifies the module in-place.\n\nArgs:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.1.25 double\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``double`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.1.26 eval\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet the module in evaluation mode.\n\nThis has an effect only on certain modules. 
See the documentation of\n\nparticular modules for details of their behaviors in training/evaluation\n\nmode, i.e. whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n\netc.\n\nThis is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.\n\nSee :ref:`locally-disable-grad-doc` for a comparison between\n\n`.eval()` and several similar mechanisms that may be confused with it.\n\nReturns:\n\n Module: self\n\n##### 4.1.27 extra\\_repr\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the extra representation of the module.\n\nTo print customized extra information, you should re-implement\n\nthis method in your own modules. Both single-line and multi-line\n\nstrings are acceptable.\n\n##### 4.1.28 float\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``float`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n### 4.1.29 forward\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDefine the computation performed at every call.\n\nShould be overridden by all subclasses.\n\n.. note::\n\n Although the recipe for forward pass needs to be defined within\n\n this function, one should call the :class:`Module` instance afterwards\n\n instead of this since the former takes care of running the\n\n registered hooks while the latter silently ignores them.\n\n##### 4.1.30 get\\_buffer\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the buffer given by ``target`` if it exists, otherwise throw an error.\n\nSee the docstring for ``get_submodule`` for a more detailed\n\nexplanation of this method's functionality as well as how to\n\ncorrectly specify ``target``.\n\nArgs:\n\n target: The fully-qualified string name of the buffer\n\n to look for. 
(See ``get_submodule`` for how to specify a\n\n fully-qualified string.)\n\nReturns:\n\n torch.Tensor: The buffer referenced by ``target``\n\nRaises:\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not a\n\n buffer\n\n##### 4.1.31 get\\_extra\\_state\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn any extra state to include in the module's state_dict.\n\nImplement this and a corresponding :func:`set_extra_state` for your module\n\nif you need to store extra state. This function is called when building the\n\nmodule's `state_dict()`.\n\nNote that extra state should be picklable to ensure working serialization\n\nof the state_dict. We only provide backwards compatibility guarantees\n\nfor serializing Tensors; other objects may break backwards compatibility if\n\ntheir serialized pickled form changes.\n\nReturns:\n\n object: Any extra state to store in the module's state_dict\n\n##### 4.1.32 get\\_parameter\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the parameter given by ``target`` if it exists, otherwise throw an error.\n\nSee the docstring for ``get_submodule`` for a more detailed\n\nexplanation of this method's functionality as well as how to\n\ncorrectly specify ``target``.\n\nArgs:\n\n target: The fully-qualified string name of the Parameter\n\n to look for. (See ``get_submodule`` for how to specify a\n\n fully-qualified string.)\n\nReturns:\n\n torch.nn.Parameter: The Parameter referenced by ``target``\n\nRaises:\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not an\n\n ``nn.Parameter``\n\n##### 4.1.33 get\\_submodule\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the submodule given by ``target`` if it exists, otherwise throw an error.\n\nFor example, let's say you have an ``nn.Module`` ``A`` that\n\nlooks like this:\n\n.. 
code-block:: text\n\n A(\n\n (net_b): Module(\n\n (net_c): Module(\n\n (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n\n )\n\n (linear): Linear(in_features=100, out_features=200, bias=True)\n\n )\n\n )\n\n(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested\n\nsubmodule ``net_b``, which itself has two submodules ``net_c``\n\nand ``linear``. ``net_c`` then has a submodule ``conv``.)\n\nTo check whether or not we have the ``linear`` submodule, we\n\nwould call ``get_submodule(\"net_b.linear\")``. To check whether\n\nwe have the ``conv`` submodule, we would call\n\n``get_submodule(\"net_b.net_c.conv\")``.\n\nThe runtime of ``get_submodule`` is bounded by the degree\n\nof module nesting in ``target``. A query against\n\n``named_modules`` achieves the same result, but it is O(N) in\n\nthe number of transitive modules. So, for a simple check to see\n\nif some submodule exists, ``get_submodule`` should always be\n\nused.\n\nArgs:\n\n target: The fully-qualified string name of the submodule\n\n to look for. (See above example for how to specify a\n\n fully-qualified string.)\n\nReturns:\n\n torch.nn.Module: The submodule referenced by ``target``\n\nRaises:\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not an\n\n ``nn.Module``\n\n##### 4.1.34 half\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``half`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.1.35 ipu\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the IPU.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing the optimizer if the module will\n\nlive on IPU while being optimized.\n\n.. 
note::\n\n This method modifies the module in-place.\n\nArguments:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.1.36 load\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCopy parameters and buffers from :attr:`state_dict` into this module and its descendants.\n\nIf :attr:`strict` is ``True``, then\n\nthe keys of :attr:`state_dict` must exactly match the keys returned\n\nby this module's :meth:`~torch.nn.Module.state_dict` function.\n\n.. warning::\n\n If :attr:`assign` is ``True`` the optimizer must be created after\n\n the call to :attr:`load_state_dict` unless\n\n :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.\n\nArgs:\n\n state_dict (dict): a dict containing parameters and\n\n persistent buffers.\n\n strict (bool, optional): whether to strictly enforce that the keys\n\n in :attr:`state_dict` match the keys returned by this module's\n\n :meth:`~torch.nn.Module.state_dict` function. Default: ``True``\n\n assign (bool, optional): When set to ``False``, the properties of the tensors\n\n in the current module are preserved whereas setting it to ``True`` preserves\n\n properties of the Tensors in the state dict. 
The only\n\n exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s\n\n for which the value from the module is preserved.\n\n Default: ``False``\n\nReturns:\n\n ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n\n * **missing_keys** is a list of str containing any keys that are expected\n\n by this module but missing from the provided ``state_dict``.\n\n * **unexpected_keys** is a list of str containing the keys that are not\n\n expected by this module but present in the provided ``state_dict``.\n\nNote:\n\n If a parameter or buffer is registered as ``None`` and its corresponding key\n\n exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n\n ``RuntimeError``.\n\n##### 4.1.37 modules\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over all modules in the network.\n\nYields:\n\n Module: a module in the network\n\nNote:\n\n Duplicate modules are returned only once. In the following\n\n example, ``l`` will be returned only once.\n\nExample::\n\n >>> l = nn.Linear(2, 2)\n\n >>> net = nn.Sequential(l, l)\n\n >>> for idx, m in enumerate(net.modules()):\n\n ... print(idx, '->', m)\n\n 0 -> Sequential(\n\n (0): Linear(in_features=2, out_features=2, bias=True)\n\n (1): Linear(in_features=2, out_features=2, bias=True)\n\n )\n\n 1 -> Linear(in_features=2, out_features=2, bias=True)\n\n##### 4.1.38 mtia\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the MTIA.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing the optimizer if the module will\n\nlive on MTIA while being optimized.\n\n.. 
note::\n\n This method modifies the module in-place.\n\nArguments:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.1.39 named\\_buffers\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.\n\nArgs:\n\n prefix (str): prefix to prepend to all buffer names.\n\n recurse (bool, optional): if True, then yields buffers of this module\n\n and all submodules. Otherwise, yields only buffers that\n\n are direct members of this module. Defaults to True.\n\n remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.\n\nYields:\n\n (str, torch.Tensor): Tuple containing the name and buffer\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for name, buf in self.named_buffers():\n\n >>> if name in ['running_var']:\n\n >>> print(buf.size())\n\n##### 4.1.40 named\\_children\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over immediate children modules, yielding both the name of the module as well as the module itself.\n\nYields:\n\n (str, Module): Tuple containing a name and child module\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for name, module in model.named_children():\n\n >>> if name in ['conv4', 'conv5']:\n\n >>> print(module)\n\n##### 4.1.41 named\\_modules\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over all modules in the network, yielding both the name of the module as well as the module itself.\n\nArgs:\n\n memo: a memo to store the set of modules already added to the result\n\n prefix: a prefix that will be added to the name of the module\n\n remove_duplicate: whether to remove the duplicated module instances in the result\n\n or not\n\nYields:\n\n (str, Module): Tuple of name and module\n\nNote:\n\n Duplicate modules are returned only once. 
In the following\n\n example, ``l`` will be returned only once.\n\nExample::\n\n >>> l = nn.Linear(2, 2)\n\n >>> net = nn.Sequential(l, l)\n\n >>> for idx, m in enumerate(net.named_modules()):\n\n ... print(idx, '->', m)\n\n 0 -> ('', Sequential(\n\n (0): Linear(in_features=2, out_features=2, bias=True)\n\n (1): Linear(in_features=2, out_features=2, bias=True)\n\n ))\n\n 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))\n\n##### 4.1.42 named\\_parameters\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.\n\nArgs:\n\n prefix (str): prefix to prepend to all parameter names.\n\n recurse (bool): if True, then yields parameters of this module\n\n and all submodules. Otherwise, yields only parameters that\n\n are direct members of this module.\n\n remove_duplicate (bool, optional): whether to remove the duplicated\n\n parameters in the result. Defaults to True.\n\nYields:\n\n (str, Parameter): Tuple containing the name and parameter\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for name, param in self.named_parameters():\n\n >>> if name in ['bias']:\n\n >>> print(param.size())\n\n##### 4.1.43 parameters\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module parameters.\n\nThis is typically passed to an optimizer.\n\nArgs:\n\n recurse (bool): if True, then yields parameters of this module\n\n and all submodules. 
Otherwise, yields only parameters that\n\n are direct members of this module.\n\nYields:\n\n Parameter: module parameter\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for param in model.parameters():\n\n >>> print(type(param), param.size())\n\n <class 'torch.Tensor'> (20L,)\n\n <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n\n##### 4.1.44 register\\_backward\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a backward hook on the module.\n\nThis function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and\n\nthe behavior of this function will change in future versions.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.1.45 register\\_buffer\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAdd a buffer to the module.\n\nThis is typically used to register a buffer that should not be\n\nconsidered a model parameter. For example, BatchNorm's ``running_mean``\n\nis not a parameter, but is part of the module's state. Buffers, by\n\ndefault, are persistent and will be saved alongside parameters. This\n\nbehavior can be changed by setting :attr:`persistent` to ``False``. The\n\nonly difference between a persistent buffer and a non-persistent buffer\n\nis that the latter will not be a part of this module's\n\n:attr:`state_dict`.\n\nBuffers can be accessed as attributes using given names.\n\nArgs:\n\n name (str): name of the buffer. The buffer can be accessed\n\n from this module using the given name\n\n tensor (Tensor or None): buffer to be registered. If ``None``, then operations\n\n that run on buffers, such as :attr:`cuda`, are ignored. 
If ``None``,\n\n the buffer is **not** included in the module's :attr:`state_dict`.\n\n persistent (bool): whether the buffer is part of this module's\n\n :attr:`state_dict`.\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> self.register_buffer('running_mean', torch.zeros(num_features))\n\n##### 4.1.46 register\\_forward\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a forward hook on the module.\n\nThe hook will be called every time after :func:`forward` has computed an output.\n\nIf ``with_kwargs`` is ``False`` or not specified, the input contains only\n\nthe positional arguments given to the module. Keyword arguments won't be\n\npassed to the hooks and only to the ``forward``. The hook can modify the\n\noutput. It can modify the input inplace but it will not have effect on\n\nforward since this is called after :func:`forward` is called. The hook\n\nshould have the following signature::\n\n hook(module, args, output) -> None or modified output\n\nIf ``with_kwargs`` is ``True``, the forward hook will be passed the\n\n``kwargs`` given to the forward function and be expected to return the\n\noutput possibly modified. The hook should have the following signature::\n\n hook(module, args, kwargs, output) -> None or modified output\n\nArgs:\n\n hook (Callable): The user defined hook to be registered.\n\n prepend (bool): If ``True``, the provided ``hook`` will be fired\n\n before all existing ``forward`` hooks on this\n\n :class:`torch.nn.modules.Module`. Otherwise, the provided\n\n ``hook`` will be fired after all existing ``forward`` hooks on\n\n this :class:`torch.nn.modules.Module`. 
Note that global\n\n ``forward`` hooks registered with\n\n :func:`register_module_forward_hook` will fire before all hooks\n\n registered by this method.\n\n Default: ``False``\n\n with_kwargs (bool): If ``True``, the ``hook`` will be passed the\n\n kwargs given to the forward function.\n\n Default: ``False``\n\n always_call (bool): If ``True`` the ``hook`` will be run regardless of\n\n whether an exception is raised while calling the Module.\n\n Default: ``False``\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.1.47 register\\_forward\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a forward pre-hook on the module.\n\nThe hook will be called every time before :func:`forward` is invoked.\n\nIf ``with_kwargs`` is false or not specified, the input contains only\n\nthe positional arguments given to the module. Keyword arguments won't be\n\npassed to the hooks and only to the ``forward``. The hook can modify the\n\ninput. User can either return a tuple or a single modified value in the\n\nhook. We will wrap the value into a tuple if a single value is returned\n\n(unless that value is already a tuple). The hook should have the\n\nfollowing signature::\n\n hook(module, args) -> None or modified input\n\nIf ``with_kwargs`` is true, the forward pre-hook will be passed the\n\nkwargs given to the forward function. And if the hook modifies the\n\ninput, both the args and kwargs should be returned. The hook should have\n\nthe following signature::\n\n hook(module, args, kwargs) -> None or a tuple of modified input and kwargs\n\nArgs:\n\n hook (Callable): The user defined hook to be registered.\n\n prepend (bool): If true, the provided ``hook`` will be fired before\n\n all existing ``forward_pre`` hooks on this\n\n :class:`torch.nn.modules.Module`. 
Otherwise, the provided\n\n ``hook`` will be fired after all existing ``forward_pre`` hooks\n\n on this :class:`torch.nn.modules.Module`. Note that global\n\n ``forward_pre`` hooks registered with\n\n :func:`register_module_forward_pre_hook` will fire before all\n\n hooks registered by this method.\n\n Default: ``False``\n\n with_kwargs (bool): If true, the ``hook`` will be passed the kwargs\n\n given to the forward function.\n\n Default: ``False``\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.1.48 register\\_full\\_backward\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a backward hook on the module.\n\nThe hook will be called every time the gradients with respect to a module\n\nare computed, i.e. the hook will execute if and only if the gradients with\n\nrespect to module outputs are computed. The hook should have the following\n\nsignature::\n\n hook(module, grad_input, grad_output) -> tuple(Tensor) or None\n\nThe :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients\n\nwith respect to the inputs and outputs respectively. The hook should\n\nnot modify its arguments, but it can optionally return a new gradient with\n\nrespect to the input that will be used in place of :attr:`grad_input` in\n\nsubsequent computations. :attr:`grad_input` will only correspond to the inputs given\n\nas positional arguments and all kwarg arguments are ignored. Entries\n\nin :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor\n\narguments.\n\nFor technical reasons, when this hook is applied to a Module, its forward function will\n\nreceive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n\nof each Tensor returned by the Module's forward function.\n\n.. 
warning ::\n\n Modifying inputs or outputs inplace is not allowed when using backward hooks and\n\n will raise an error.\n\nArgs:\n\n hook (Callable): The user-defined hook to be registered.\n\n prepend (bool): If true, the provided ``hook`` will be fired before\n\n all existing ``backward`` hooks on this\n\n :class:`torch.nn.modules.Module`. Otherwise, the provided\n\n ``hook`` will be fired after all existing ``backward`` hooks on\n\n this :class:`torch.nn.modules.Module`. Note that global\n\n ``backward`` hooks registered with\n\n :func:`register_module_full_backward_hook` will fire before\n\n all hooks registered by this method.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.1.49 register\\_full\\_backward\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a backward pre-hook on the module.\n\nThe hook will be called every time the gradients for the module are computed.\n\nThe hook should have the following signature::\n\n hook(module, grad_output) -> tuple[Tensor] or None\n\nThe :attr:`grad_output` is a tuple. The hook should\n\nnot modify its arguments, but it can optionally return a new gradient with\n\nrespect to the output that will be used in place of :attr:`grad_output` in\n\nsubsequent computations. Entries in :attr:`grad_output` will be ``None`` for\n\nall non-Tensor arguments.\n\nFor technical reasons, when this hook is applied to a Module, its forward function will\n\nreceive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n\nof each Tensor returned by the Module's forward function.\n\n.. 
warning ::\n\n Modifying inputs inplace is not allowed when using backward hooks and\n\n will raise an error.\n\nArgs:\n\n hook (Callable): The user-defined hook to be registered.\n\n prepend (bool): If true, the provided ``hook`` will be fired before\n\n all existing ``backward_pre`` hooks on this\n\n :class:`torch.nn.modules.Module`. Otherwise, the provided\n\n ``hook`` will be fired after all existing ``backward_pre`` hooks\n\n on this :class:`torch.nn.modules.Module`. Note that global\n\n ``backward_pre`` hooks registered with\n\n :func:`register_module_full_backward_pre_hook` will fire before\n\n all hooks registered by this method.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.1.50 register\\_load\\_state\\_dict\\_post\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called.\n\nIt should have the following signature::\n\n hook(module, incompatible_keys) -> None\n\nThe ``module`` argument is the current module that this hook is registered\n\non, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting\n\nof attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``\n\nis a ``list`` of ``str`` containing the missing keys and\n\n``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.\n\nThe given incompatible_keys can be modified inplace if needed.\n\nNote that the checks performed when calling :func:`load_state_dict` with\n\n``strict=True`` are affected by modifications the hook makes to\n\n``missing_keys`` or ``unexpected_keys``, as expected. 
Additions to either\n\nset of keys will result in an error being thrown when ``strict=True``, and\n\nclearing out both missing and unexpected keys will avoid an error.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.1.51 register\\_load\\_state\\_dict\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called.\n\nIt should have the following signature::\n\n hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950\n\nArguments:\n\n hook (Callable): Callable hook that will be invoked before\n\n loading the state dict.\n\n##### 4.1.52 register\\_module\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAlias for :func:`add_module`.\n\n##### 4.1.53 register\\_parameter\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAdd a parameter to the module.\n\nThe parameter can be accessed as an attribute using given name.\n\nArgs:\n\n name (str): name of the parameter. The parameter can be accessed\n\n from this module using the given name\n\n param (Parameter or None): parameter to be added to the module. If\n\n ``None``, then operations that run on parameters, such as :attr:`cuda`,\n\n are ignored. 
If ``None``, the parameter is **not** included in the\n\n module's :attr:`state_dict`.\n\n##### 4.1.54 register\\_state\\_dict\\_post\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a post-hook for the :meth:`~torch.nn.Module.state_dict` method.\n\nIt should have the following signature::\n\n hook(module, state_dict, prefix, local_metadata) -> None\n\nThe registered hooks can modify the ``state_dict`` inplace.\n\n##### 4.1.55 register\\_state\\_dict\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.\n\nIt should have the following signature::\n\n hook(module, prefix, keep_vars) -> None\n\nThe registered hooks can be used to perform pre-processing before the ``state_dict``\n\ncall is made.\n\n##### 4.1.56 requires\\_grad\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nChange whether autograd should record operations on parameters in this module.\n\nThis method sets the parameters' :attr:`requires_grad` attributes\n\nin-place.\n\nThis method is helpful for freezing part of the module for finetuning\n\nor training parts of a model individually (e.g., GAN training).\n\nSee :ref:`locally-disable-grad-doc` for a comparison between\n\n`.requires_grad_()` and several similar mechanisms that may be confused with it.\n\nArgs:\n\n requires_grad (bool): whether autograd should record operations on\n\n parameters in this module. Default: ``True``.\n\nReturns:\n\n Module: self\n\n##### 4.1.57 set\\_extra\\_state\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet extra state contained in the loaded `state_dict`.\n\nThis function is called from :func:`load_state_dict` to handle any extra state\n\nfound within the `state_dict`. 
Implement this function and a corresponding\n\n:func:`get_extra_state` for your module if you need to store extra state within its\n\n`state_dict`.\n\nArgs:\n\n state (dict): Extra state from the `state_dict`\n\n##### 4.1.58 set\\_submodule\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet the submodule given by ``target`` if it exists, otherwise throw an error.\n\nFor example, let's say you have an ``nn.Module`` ``A`` that\n\nlooks like this:\n\n.. code-block:: text\n\n A(\n\n (net_b): Module(\n\n (net_c): Module(\n\n (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n\n )\n\n (linear): Linear(in_features=100, out_features=200, bias=True)\n\n )\n\n )\n\n(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested\n\nsubmodule ``net_b``, which itself has two submodules ``net_c``\n\nand ``linear``. ``net_c`` then has a submodule ``conv``.)\n\nTo override the ``Conv2d`` with a new submodule ``Linear``, you\n\nwould call\n\n``set_submodule(\"net_b.net_c.conv\", nn.Linear(33, 16))``.\n\nArgs:\n\n target: The fully-qualified string name of the submodule\n\n to look for. (See above example for how to specify a\n\n fully-qualified string.)\n\n module: The module to set the submodule to.\n\nRaises:\n\n ValueError: If the target string is empty\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not an\n\n ``nn.Module``\n\n##### 4.1.59 share\\_memory\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSee :meth:`torch.Tensor.share_memory_`.\n\n##### 4.1.60 state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn a dictionary containing references to the whole state of the module.\n\nBoth parameters and persistent buffers (e.g. running averages) are\n\nincluded. Keys are corresponding parameter and buffer names.\n\nParameters and buffers set to ``None`` are not included.\n\n.. note::\n\n The returned object is a shallow copy. It contains references\n\n to the module's parameters and buffers.\n\n.. 
warning::\n\n Currently ``state_dict()`` also accepts positional arguments for\n\n ``destination``, ``prefix`` and ``keep_vars`` in order. However,\n\n this is being deprecated and keyword arguments will be enforced in\n\n future releases.\n\n.. warning::\n\n Please avoid the use of argument ``destination`` as it is not\n\n designed for end-users.\n\nArgs:\n\n destination (dict, optional): If provided, the state of the module will\n\n be updated into the dict and the same object is returned.\n\n Otherwise, an ``OrderedDict`` will be created and returned.\n\n Default: ``None``.\n\n prefix (str, optional): a prefix added to parameter and buffer\n\n names to compose the keys in state_dict. Default: ``''``.\n\n keep_vars (bool, optional): by default the :class:`~torch.Tensor` s\n\n returned in the state dict are detached from autograd. If it's\n\n set to ``True``, detaching will not be performed.\n\n Default: ``False``.\n\nReturns:\n\n dict:\n\n a dictionary containing a whole state of the module\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> module.state_dict().keys()\n\n ['bias', 'weight']\n\n##### 4.1.61 to\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove and/or cast the parameters and buffers.\n\nThis can be called as\n\n.. function:: to(device=None, dtype=None, non_blocking=False)\n\n :noindex:\n\n.. function:: to(dtype, non_blocking=False)\n\n :noindex:\n\n.. function:: to(tensor, non_blocking=False)\n\n :noindex:\n\n.. function:: to(memory_format=torch.channels_last)\n\n :noindex:\n\nIts signature is similar to :meth:`torch.Tensor.to`, but only accepts\n\nfloating point or complex :attr:`dtype`\\ s. In addition, this method will\n\nonly cast the floating point or complex parameters and buffers to :attr:`dtype`\n\n(if given). The integral parameters and buffers will be moved\n\nto :attr:`device`, if that is given, but with dtypes unchanged. 
When\n\n:attr:`non_blocking` is set, it tries to convert/move asynchronously\n\nwith respect to the host if possible, e.g., moving CPU Tensors with\n\npinned memory to CUDA devices.\n\nSee below for examples.\n\n.. note::\n\n This method modifies the module in-place.\n\nArgs:\n\n device (:class:`torch.device`): the desired device of the parameters\n\n and buffers in this module\n\n dtype (:class:`torch.dtype`): the desired floating point or complex dtype of\n\n the parameters and buffers in this module\n\n tensor (torch.Tensor): Tensor whose dtype and device are the desired\n\n dtype and device for all parameters and buffers in this module\n\n memory_format (:class:`torch.memory_format`): the desired memory\n\n format for 4D parameters and buffers in this module (keyword\n\n only argument)\n\nReturns:\n\n Module: self\n\nExamples::\n\n >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n\n >>> linear = nn.Linear(2, 2)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1913, -0.3420],\n\n [-0.5113, -0.2325]])\n\n >>> linear.to(torch.double)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1913, -0.3420],\n\n [-0.5113, -0.2325]], dtype=torch.float64)\n\n >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)\n\n >>> gpu1 = torch.device(\"cuda:1\")\n\n >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1914, -0.3420],\n\n [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')\n\n >>> cpu = torch.device(\"cpu\")\n\n >>> linear.to(cpu)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1914, -0.3420],\n\n [-0.5112, -0.2324]], dtype=torch.float16)\n\n >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.3741+0.j, 0.2382+0.j],\n\n [ 
0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)\n\n >>> linear(torch.ones(3, 2, dtype=torch.cdouble))\n\n tensor([[0.6122+0.j, 0.1150+0.j],\n\n [0.6122+0.j, 0.1150+0.j],\n\n [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)\n\n##### 4.1.62 to\\_empty\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove the parameters and buffers to the specified device without copying storage.\n\nArgs:\n\n device (:class:`torch.device`): The desired device of the parameters\n\n and buffers in this module.\n\n recurse (bool): Whether parameters and buffers of submodules should\n\n be recursively moved to the specified device.\n\nReturns:\n\n Module: self\n\n##### 4.1.63 train\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet the module in training mode.\n\nThis has an effect only on certain modules. See the documentation of\n\nparticular modules for details of their behaviors in training/evaluation\n\nmode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n\netc.\n\nArgs:\n\n mode (bool): whether to set training mode (``True``) or evaluation\n\n mode (``False``). Default: ``True``.\n\nReturns:\n\n Module: self\n\n##### 4.1.64 type\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all parameters and buffers to :attr:`dst_type`.\n\n.. note::\n\n This method modifies the module in-place.\n\nArgs:\n\n dst_type (type or string): the desired type\n\nReturns:\n\n Module: self\n\n##### 4.1.65 xpu\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the XPU.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing the optimizer if the module will\n\nlive on XPU while being optimized.\n\n.. 
note::\n\n This method modifies the module in-place.\n\nArguments:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.1.66 zero\\_grad\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReset gradients of all model parameters.\n\nSee similar function under :class:`torch.optim.Optimizer` for more context.\n\nArgs:\n\n set_to_none (bool): instead of setting to zero, set the grads to None.\n\n See :meth:`torch.optim.Optimizer.zero_grad` for details.\n\n## 4.2 SelfAttention : torch.nn.modules.module.Module\n\n[TOC](#table-of-contents)\n\n**Description**\n\nBase class for all neural network modules.\n\nYour models should also subclass this class.\n\nModules can also contain other Modules, allowing them to be nested in\na tree structure. You can assign the submodules as regular attributes::\n\n import torch.nn as nn\n import torch.nn.functional as F\n\n class Model(nn.Module):\n def __init__(self) -> None:\n super().__init__()\n self.conv1 = nn.Conv2d(1, 20, 5)\n self.conv2 = nn.Conv2d(20, 20, 5)\n\n def forward(self, x):\n x = F.relu(self.conv1(x))\n return F.relu(self.conv2(x))\n\nSubmodules assigned in this way will be registered, and will also have their\nparameters converted when you call :meth:`to`, etc.\n\n.. 
note::\n As per the example above, an ``__init__()`` call to the parent class\n must be made before assignment on the child.\n\n:ivar training: Boolean represents whether this module is in training or\n evaluation mode.\n:vartype training: bool\n\n\n##### 4.2.1 \\_wrapped\\_call\\_impl\n\n[TOC](#table-of-contents)\n\n### 4.2.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitialize internal Module state, shared by both nn.Module and ScriptModule.\n\n##### 4.2.3 \\_apply\n\n[TOC](#table-of-contents)\n\n##### 4.2.4 \\_call\\_impl\n\n[TOC](#table-of-contents)\n\n##### 4.2.5 \\_get\\_backward\\_hooks\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the backward hooks for use in the call function.\n\nIt returns two lists, one with the full backward hooks and one with the non-full\n\nbackward hooks.\n\n##### 4.2.6 \\_get\\_backward\\_pre\\_hooks\n\n[TOC](#table-of-contents)\n\n##### 4.2.7 \\_get\\_name\n\n[TOC](#table-of-contents)\n\n##### 4.2.8 \\_load\\_from\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCopy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants.\n\nThis is called on every submodule\n\nin :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this\n\nmodule in input :attr:`state_dict` is provided as :attr:`local_metadata`.\n\nFor state dicts without metadata, :attr:`local_metadata` is empty.\n\nSubclasses can achieve class-specific backward compatible loading using\n\nthe version number at `local_metadata.get(\"version\", None)`.\n\nAdditionally, :attr:`local_metadata` can also contain the key\n\n`assign_to_params_buffers` that indicates whether keys should be\n\nassigned their corresponding tensor in the state_dict.\n\n.. note::\n\n :attr:`state_dict` is not the same object as the input\n\n :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. 
So\n\n it can be modified.\n\nArgs:\n\n state_dict (dict): a dict containing parameters and\n\n persistent buffers.\n\n prefix (str): the prefix for parameters and buffers used in this\n\n module\n\n local_metadata (dict): a dict containing the metadata for this module.\n\n See\n\n strict (bool): whether to strictly enforce that the keys in\n\n :attr:`state_dict` with :attr:`prefix` match the names of\n\n parameters and buffers in this module\n\n missing_keys (list of str): if ``strict=True``, add missing keys to\n\n this list\n\n unexpected_keys (list of str): if ``strict=True``, add unexpected\n\n keys to this list\n\n error_msgs (list of str): error messages should be added to this\n\n list, and will be reported together in\n\n :meth:`~torch.nn.Module.load_state_dict`\n\n##### 4.2.9 \\_maybe\\_warn\\_non\\_full\\_backward\\_hook\n\n[TOC](#table-of-contents)\n\n##### 4.2.10 \\_named\\_members\n\n[TOC](#table-of-contents)\n\n**Description**\n\nHelp yield various names + members of modules.\n\n##### 4.2.11 \\_register\\_load\\_state\\_dict\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSee :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details.\n\nA subtle difference is that if ``with_module`` is set to ``False``, then the\n\nhook will not take the ``module`` as the first argument whereas\n\n:meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the\n\n``module`` as the first argument.\n\nArguments:\n\n hook (Callable): Callable hook that will be invoked before\n\n loading the state dict.\n\n with_module (bool, optional): Whether or not to pass the module\n\n instance to the hook as the first parameter.\n\n##### 4.2.12 \\_register\\_state\\_dict\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a post-hook for the :meth:`~torch.nn.Module.state_dict` method.\n\nIt should have the following signature::\n\n hook(module, state_dict, prefix, local_metadata) -> None or state_dict\n\nThe registered hooks 
can modify the ``state_dict`` inplace or return a new one.\n\nIf a new ``state_dict`` is returned, it will only be respected if it is the root\n\nmodule that :meth:`~nn.Module.state_dict` is called from.\n\n##### 4.2.13 \\_replicate\\_for\\_data\\_parallel\n\n[TOC](#table-of-contents)\n\n##### 4.2.14 \\_save\\_to\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSave module state to the `destination` dictionary.\n\nThe `destination` dictionary will contain the state\n\nof the module, but not its descendants. This is called on every\n\nsubmodule in :meth:`~torch.nn.Module.state_dict`.\n\nIn rare cases, subclasses can achieve class-specific behavior by\n\noverriding this method with custom logic.\n\nArgs:\n\n destination (dict): a dict where state will be stored\n\n prefix (str): the prefix for parameters and buffers used in this\n\n module\n\n##### 4.2.15 \\_slow\\_forward\n\n[TOC](#table-of-contents)\n\n##### 4.2.16 \\_wrapped\\_call\\_impl\n\n[TOC](#table-of-contents)\n\n##### 4.2.17 add\\_module\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAdd a child module to the current module.\n\nThe module can be accessed as an attribute using the given name.\n\nArgs:\n\n name (str): name of the child module. 
The child module can be\n\n accessed from this module using the given name\n\n module (Module): child module to be added to the module.\n\n##### 4.2.18 apply\n\n[TOC](#table-of-contents)\n\n**Description**\n\nApply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.\n\nTypical use includes initializing the parameters of a model\n\n(see also :ref:`nn-init-doc`).\n\nArgs:\n\n fn (:class:`Module` -> None): function to be applied to each submodule\n\nReturns:\n\n Module: self\n\nExample::\n\n >>> @torch.no_grad()\n\n >>> def init_weights(m):\n\n >>> print(m)\n\n >>> if type(m) == nn.Linear:\n\n >>> m.weight.fill_(1.0)\n\n >>> print(m.weight)\n\n >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n\n >>> net.apply(init_weights)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n Parameter containing:\n\n tensor([[1., 1.],\n\n [1., 1.]], requires_grad=True)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n Parameter containing:\n\n tensor([[1., 1.],\n\n [1., 1.]], requires_grad=True)\n\n Sequential(\n\n (0): Linear(in_features=2, out_features=2, bias=True)\n\n (1): Linear(in_features=2, out_features=2, bias=True)\n\n )\n\n##### 4.2.19 bfloat16\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``bfloat16`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.2.20 buffers\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module buffers.\n\nArgs:\n\n recurse (bool): if True, then yields buffers of this module\n\n and all submodules. 
Otherwise, yields only buffers that\n\n are direct members of this module.\n\nYields:\n\n torch.Tensor: module buffer\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for buf in model.buffers():\n\n >>> print(type(buf), buf.size())\n\n <class 'torch.Tensor'> (20L,)\n\n <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n\n##### 4.2.21 children\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over immediate children modules.\n\nYields:\n\n Module: a child module\n\n##### 4.2.22 compile\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCompile this Module's forward using :func:`torch.compile`.\n\nThis Module's `__call__` method is compiled and all arguments are passed as-is\n\nto :func:`torch.compile`.\n\nSee :func:`torch.compile` for details on the arguments for this function.\n\n##### 4.2.23 cpu\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the CPU.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.2.24 cuda\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the GPU.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing the optimizer if the module will\n\nlive on GPU while being optimized.\n\n.. note::\n\n This method modifies the module in-place.\n\nArgs:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.2.25 double\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``double`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.2.26 eval\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet the module in evaluation mode.\n\nThis has an effect only on certain modules. 
See the documentation of\n\nparticular modules for details of their behaviors in training/evaluation\n\nmode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n\netc.\n\nThis is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.\n\nSee :ref:`locally-disable-grad-doc` for a comparison between\n\n`.eval()` and several similar mechanisms that may be confused with it.\n\nReturns:\n\n Module: self\n\n##### 4.2.27 extra\\_repr\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the extra representation of the module.\n\nTo print customized extra information, you should re-implement\n\nthis method in your own modules. Both single-line and multi-line\n\nstrings are acceptable.\n\n##### 4.2.28 float\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``float`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n### 4.2.29 forward\n\n[TOC](#table-of-contents)\n\n**Description**\n\nDefine the computation performed at every call.\n\nShould be overridden by all subclasses.\n\n.. note::\n\n Although the recipe for the forward pass needs to be defined within\n\n this function, one should call the :class:`Module` instance afterwards\n\n instead of this since the former takes care of running the\n\n registered hooks while the latter silently ignores them.\n\n##### 4.2.30 get\\_buffer\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the buffer given by ``target`` if it exists, otherwise throw an error.\n\nSee the docstring for ``get_submodule`` for a more detailed\n\nexplanation of this method's functionality as well as how to\n\ncorrectly specify ``target``.\n\nArgs:\n\n target: The fully-qualified string name of the buffer\n\n to look for. 
(See ``get_submodule`` for how to specify a\n\n fully-qualified string.)\n\nReturns:\n\n torch.Tensor: The buffer referenced by ``target``\n\nRaises:\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not a\n\n buffer\n\n##### 4.2.31 get\\_extra\\_state\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn any extra state to include in the module's state_dict.\n\nImplement this and a corresponding :func:`set_extra_state` for your module\n\nif you need to store extra state. This function is called when building the\n\nmodule's `state_dict()`.\n\nNote that extra state should be picklable to ensure working serialization\n\nof the state_dict. We only provide backwards compatibility guarantees\n\nfor serializing Tensors; other objects may break backwards compatibility if\n\ntheir serialized pickled form changes.\n\nReturns:\n\n object: Any extra state to store in the module's state_dict\n\n##### 4.2.32 get\\_parameter\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the parameter given by ``target`` if it exists, otherwise throw an error.\n\nSee the docstring for ``get_submodule`` for a more detailed\n\nexplanation of this method's functionality as well as how to\n\ncorrectly specify ``target``.\n\nArgs:\n\n target: The fully-qualified string name of the Parameter\n\n to look for. (See ``get_submodule`` for how to specify a\n\n fully-qualified string.)\n\nReturns:\n\n torch.nn.Parameter: The Parameter referenced by ``target``\n\nRaises:\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not an\n\n ``nn.Parameter``\n\n##### 4.2.33 get\\_submodule\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn the submodule given by ``target`` if it exists, otherwise throw an error.\n\nFor example, let's say you have an ``nn.Module`` ``A`` that\n\nlooks like this:\n\n.. 
code-block:: text\n\n A(\n\n (net_b): Module(\n\n (net_c): Module(\n\n (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n\n )\n\n (linear): Linear(in_features=100, out_features=200, bias=True)\n\n )\n\n )\n\n(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested\n\nsubmodule ``net_b``, which itself has two submodules ``net_c``\n\nand ``linear``. ``net_c`` then has a submodule ``conv``.)\n\nTo check whether or not we have the ``linear`` submodule, we\n\nwould call ``get_submodule(\"net_b.linear\")``. To check whether\n\nwe have the ``conv`` submodule, we would call\n\n``get_submodule(\"net_b.net_c.conv\")``.\n\nThe runtime of ``get_submodule`` is bounded by the degree\n\nof module nesting in ``target``. A query against\n\n``named_modules`` achieves the same result, but it is O(N) in\n\nthe number of transitive modules. So, for a simple check to see\n\nif some submodule exists, ``get_submodule`` should always be\n\nused.\n\nArgs:\n\n target: The fully-qualified string name of the submodule\n\n to look for. (See above example for how to specify a\n\n fully-qualified string.)\n\nReturns:\n\n torch.nn.Module: The submodule referenced by ``target``\n\nRaises:\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not an\n\n ``nn.Module``\n\n##### 4.2.34 half\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all floating point parameters and buffers to ``half`` datatype.\n\n.. note::\n\n This method modifies the module in-place.\n\nReturns:\n\n Module: self\n\n##### 4.2.35 ipu\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the IPU.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing the optimizer if the module will\n\nlive on IPU while being optimized.\n\n.. 
note::\n\n This method modifies the module in-place.\n\nArguments:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.2.36 load\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCopy parameters and buffers from :attr:`state_dict` into this module and its descendants.\n\nIf :attr:`strict` is ``True``, then\n\nthe keys of :attr:`state_dict` must exactly match the keys returned\n\nby this module's :meth:`~torch.nn.Module.state_dict` function.\n\n.. warning::\n\n If :attr:`assign` is ``True`` the optimizer must be created after\n\n the call to :attr:`load_state_dict` unless\n\n :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.\n\nArgs:\n\n state_dict (dict): a dict containing parameters and\n\n persistent buffers.\n\n strict (bool, optional): whether to strictly enforce that the keys\n\n in :attr:`state_dict` match the keys returned by this module's\n\n :meth:`~torch.nn.Module.state_dict` function. Default: ``True``\n\n assign (bool, optional): When set to ``False``, the properties of the tensors\n\n in the current module are preserved whereas setting it to ``True`` preserves\n\n properties of the Tensors in the state dict. 
The only\n\n exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s\n\n for which the value from the module is preserved.\n\n Default: ``False``\n\nReturns:\n\n ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n\n * **missing_keys** is a list of str containing any keys that are expected\n\n by this module but missing from the provided ``state_dict``.\n\n * **unexpected_keys** is a list of str containing the keys that are not\n\n expected by this module but present in the provided ``state_dict``.\n\nNote:\n\n If a parameter or buffer is registered as ``None`` and its corresponding key\n\n exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n\n ``RuntimeError``.\n\n##### 4.2.37 modules\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over all modules in the network.\n\nYields:\n\n Module: a module in the network\n\nNote:\n\n Duplicate modules are returned only once. In the following\n\n example, ``l`` will be returned only once.\n\nExample::\n\n >>> l = nn.Linear(2, 2)\n\n >>> net = nn.Sequential(l, l)\n\n >>> for idx, m in enumerate(net.modules()):\n\n ... print(idx, '->', m)\n\n 0 -> Sequential(\n\n (0): Linear(in_features=2, out_features=2, bias=True)\n\n (1): Linear(in_features=2, out_features=2, bias=True)\n\n )\n\n 1 -> Linear(in_features=2, out_features=2, bias=True)\n\n##### 4.2.38 mtia\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the MTIA.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing the optimizer if the module will\n\nlive on MTIA while being optimized.\n\n.. 
note::\n\n This method modifies the module in-place.\n\nArguments:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.2.39 named\\_buffers\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.\n\nArgs:\n\n prefix (str): prefix to prepend to all buffer names.\n\n recurse (bool, optional): if True, then yields buffers of this module\n\n and all submodules. Otherwise, yields only buffers that\n\n are direct members of this module. Defaults to True.\n\n remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.\n\nYields:\n\n (str, torch.Tensor): Tuple containing the name and buffer\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for name, buf in self.named_buffers():\n\n >>> if name in ['running_var']:\n\n >>> print(buf.size())\n\n##### 4.2.40 named\\_children\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over immediate children modules, yielding both the name of the module as well as the module itself.\n\nYields:\n\n (str, Module): Tuple containing a name and child module\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for name, module in model.named_children():\n\n >>> if name in ['conv4', 'conv5']:\n\n >>> print(module)\n\n##### 4.2.41 named\\_modules\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over all modules in the network, yielding both the name of the module as well as the module itself.\n\nArgs:\n\n memo: a memo to store the set of modules already added to the result\n\n prefix: a prefix that will be added to the name of the module\n\n remove_duplicate: whether to remove the duplicated module instances in the result\n\n or not\n\nYields:\n\n (str, Module): Tuple of name and module\n\nNote:\n\n Duplicate modules are returned only once. 
In the following\n\n example, ``l`` will be returned only once.\n\nExample::\n\n >>> l = nn.Linear(2, 2)\n\n >>> net = nn.Sequential(l, l)\n\n >>> for idx, m in enumerate(net.named_modules()):\n\n ... print(idx, '->', m)\n\n 0 -> ('', Sequential(\n\n (0): Linear(in_features=2, out_features=2, bias=True)\n\n (1): Linear(in_features=2, out_features=2, bias=True)\n\n ))\n\n 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))\n\n##### 4.2.42 named\\_parameters\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.\n\nArgs:\n\n prefix (str): prefix to prepend to all parameter names.\n\n recurse (bool): if True, then yields parameters of this module\n\n and all submodules. Otherwise, yields only parameters that\n\n are direct members of this module.\n\n remove_duplicate (bool, optional): whether to remove the duplicated\n\n parameters in the result. Defaults to True.\n\nYields:\n\n (str, Parameter): Tuple containing the name and parameter\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for name, param in self.named_parameters():\n\n >>> if name in ['bias']:\n\n >>> print(param.size())\n\n##### 4.2.43 parameters\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn an iterator over module parameters.\n\nThis is typically passed to an optimizer.\n\nArgs:\n\n recurse (bool): if True, then yields parameters of this module\n\n and all submodules. 
Otherwise, yields only parameters that\n\n are direct members of this module.\n\nYields:\n\n Parameter: module parameter\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> for param in model.parameters():\n\n >>> print(type(param), param.size())\n\n <class 'torch.Tensor'> (20L,)\n\n <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n\n##### 4.2.44 register\\_backward\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a backward hook on the module.\n\nThis function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and\n\nthe behavior of this function will change in future versions.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.2.45 register\\_buffer\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAdd a buffer to the module.\n\nThis is typically used to register a buffer that should not be\n\nconsidered a model parameter. For example, BatchNorm's ``running_mean``\n\nis not a parameter, but is part of the module's state. Buffers, by\n\ndefault, are persistent and will be saved alongside parameters. This\n\nbehavior can be changed by setting :attr:`persistent` to ``False``. The\n\nonly difference between a persistent buffer and a non-persistent buffer\n\nis that the latter will not be a part of this module's\n\n:attr:`state_dict`.\n\nBuffers can be accessed as attributes using given names.\n\nArgs:\n\n name (str): name of the buffer. The buffer can be accessed\n\n from this module using the given name\n\n tensor (Tensor or None): buffer to be registered. If ``None``, then operations\n\n that run on buffers, such as :attr:`cuda`, are ignored. 
If ``None``,\n\n the buffer is **not** included in the module's :attr:`state_dict`.\n\n persistent (bool): whether the buffer is part of this module's\n\n :attr:`state_dict`.\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> self.register_buffer('running_mean', torch.zeros(num_features))\n\n##### 4.2.46 register\\_forward\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a forward hook on the module.\n\nThe hook will be called every time after :func:`forward` has computed an output.\n\nIf ``with_kwargs`` is ``False`` or not specified, the input contains only\n\nthe positional arguments given to the module. Keyword arguments won't be\n\npassed to the hooks and only to the ``forward``. The hook can modify the\n\noutput. It can modify the input inplace but it will not have effect on\n\nforward since this is called after :func:`forward` is called. The hook\n\nshould have the following signature::\n\n hook(module, args, output) -> None or modified output\n\nIf ``with_kwargs`` is ``True``, the forward hook will be passed the\n\n``kwargs`` given to the forward function and be expected to return the\n\noutput possibly modified. The hook should have the following signature::\n\n hook(module, args, kwargs, output) -> None or modified output\n\nArgs:\n\n hook (Callable): The user defined hook to be registered.\n\n prepend (bool): If ``True``, the provided ``hook`` will be fired\n\n before all existing ``forward`` hooks on this\n\n :class:`torch.nn.modules.Module`. Otherwise, the provided\n\n ``hook`` will be fired after all existing ``forward`` hooks on\n\n this :class:`torch.nn.modules.Module`. 
Note that global\n\n ``forward`` hooks registered with\n\n :func:`register_module_forward_hook` will fire before all hooks\n\n registered by this method.\n\n Default: ``False``\n\n with_kwargs (bool): If ``True``, the ``hook`` will be passed the\n\n kwargs given to the forward function.\n\n Default: ``False``\n\n always_call (bool): If ``True`` the ``hook`` will be run regardless of\n\n whether an exception is raised while calling the Module.\n\n Default: ``False``\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.2.47 register\\_forward\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a forward pre-hook on the module.\n\nThe hook will be called every time before :func:`forward` is invoked.\n\nIf ``with_kwargs`` is false or not specified, the input contains only\n\nthe positional arguments given to the module. Keyword arguments won't be\n\npassed to the hooks and only to the ``forward``. The hook can modify the\n\ninput. User can either return a tuple or a single modified value in the\n\nhook. We will wrap the value into a tuple if a single value is returned\n\n(unless that value is already a tuple). The hook should have the\n\nfollowing signature::\n\n hook(module, args) -> None or modified input\n\nIf ``with_kwargs`` is true, the forward pre-hook will be passed the\n\nkwargs given to the forward function. And if the hook modifies the\n\ninput, both the args and kwargs should be returned. The hook should have\n\nthe following signature::\n\n hook(module, args, kwargs) -> None or a tuple of modified input and kwargs\n\nArgs:\n\n hook (Callable): The user defined hook to be registered.\n\n prepend (bool): If true, the provided ``hook`` will be fired before\n\n all existing ``forward_pre`` hooks on this\n\n :class:`torch.nn.modules.Module`. 
Otherwise, the provided\n\n ``hook`` will be fired after all existing ``forward_pre`` hooks\n\n on this :class:`torch.nn.modules.Module`. Note that global\n\n ``forward_pre`` hooks registered with\n\n :func:`register_module_forward_pre_hook` will fire before all\n\n hooks registered by this method.\n\n Default: ``False``\n\n with_kwargs (bool): If true, the ``hook`` will be passed the kwargs\n\n given to the forward function.\n\n Default: ``False``\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.2.48 register\\_full\\_backward\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a backward hook on the module.\n\nThe hook will be called every time the gradients with respect to a module\n\nare computed, i.e. the hook will execute if and only if the gradients with\n\nrespect to module outputs are computed. The hook should have the following\n\nsignature::\n\n hook(module, grad_input, grad_output) -> tuple(Tensor) or None\n\nThe :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients\n\nwith respect to the inputs and outputs respectively. The hook should\n\nnot modify its arguments, but it can optionally return a new gradient with\n\nrespect to the input that will be used in place of :attr:`grad_input` in\n\nsubsequent computations. :attr:`grad_input` will only correspond to the inputs given\n\nas positional arguments and all kwarg arguments are ignored. Entries\n\nin :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor\n\narguments.\n\nFor technical reasons, when this hook is applied to a Module, its forward function will\n\nreceive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n\nof each Tensor returned by the Module's forward function.\n\n.. 
warning ::\n\n Modifying inputs or outputs inplace is not allowed when using backward hooks and\n\n will raise an error.\n\nArgs:\n\n hook (Callable): The user-defined hook to be registered.\n\n prepend (bool): If true, the provided ``hook`` will be fired before\n\n all existing ``backward`` hooks on this\n\n :class:`torch.nn.modules.Module`. Otherwise, the provided\n\n ``hook`` will be fired after all existing ``backward`` hooks on\n\n this :class:`torch.nn.modules.Module`. Note that global\n\n ``backward`` hooks registered with\n\n :func:`register_module_full_backward_hook` will fire before\n\n all hooks registered by this method.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.2.49 register\\_full\\_backward\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a backward pre-hook on the module.\n\nThe hook will be called every time the gradients for the module are computed.\n\nThe hook should have the following signature::\n\n hook(module, grad_output) -> tuple[Tensor] or None\n\nThe :attr:`grad_output` is a tuple. The hook should\n\nnot modify its arguments, but it can optionally return a new gradient with\n\nrespect to the output that will be used in place of :attr:`grad_output` in\n\nsubsequent computations. Entries in :attr:`grad_output` will be ``None`` for\n\nall non-Tensor arguments.\n\nFor technical reasons, when this hook is applied to a Module, its forward function will\n\nreceive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n\nof each Tensor returned by the Module's forward function.\n\n.. 
warning ::\n\n Modifying inputs inplace is not allowed when using backward hooks and\n\n will raise an error.\n\nArgs:\n\n hook (Callable): The user-defined hook to be registered.\n\n prepend (bool): If true, the provided ``hook`` will be fired before\n\n all existing ``backward_pre`` hooks on this\n\n :class:`torch.nn.modules.Module`. Otherwise, the provided\n\n ``hook`` will be fired after all existing ``backward_pre`` hooks\n\n on this :class:`torch.nn.modules.Module`. Note that global\n\n ``backward_pre`` hooks registered with\n\n :func:`register_module_full_backward_pre_hook` will fire before\n\n all hooks registered by this method.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.2.50 register\\_load\\_state\\_dict\\_post\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called.\n\nIt should have the following signature::\n\n hook(module, incompatible_keys) -> None\n\nThe ``module`` argument is the current module that this hook is registered\n\non, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting\n\nof attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``\n\nis a ``list`` of ``str`` containing the missing keys and\n\n``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.\n\nThe given incompatible_keys can be modified inplace if needed.\n\nNote that the checks performed when calling :func:`load_state_dict` with\n\n``strict=True`` are affected by modifications the hook makes to\n\n``missing_keys`` or ``unexpected_keys``, as expected. 
Additions to either\n\nset of keys will result in an error being thrown when ``strict=True``, and\n\nclearing out both missing and unexpected keys will avoid an error.\n\nReturns:\n\n :class:`torch.utils.hooks.RemovableHandle`:\n\n a handle that can be used to remove the added hook by calling\n\n ``handle.remove()``\n\n##### 4.2.51 register\\_load\\_state\\_dict\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called.\n\nIt should have the following signature::\n\n hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None\n\nArgs:\n\n hook (Callable): Callable hook that will be invoked before\n\n loading the state dict.\n\n##### 4.2.52 register\\_module\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAlias for :func:`add_module`.\n\n##### 4.2.53 register\\_parameter\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAdd a parameter to the module.\n\nThe parameter can be accessed as an attribute using given name.\n\nArgs:\n\n name (str): name of the parameter. The parameter can be accessed\n\n from this module using the given name\n\n param (Parameter or None): parameter to be added to the module. If\n\n ``None``, then operations that run on parameters, such as :attr:`cuda`,\n\n are ignored.
If ``None``, the parameter is **not** included in the\n\n module's :attr:`state_dict`.\n\n##### 4.2.54 register\\_state\\_dict\\_post\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a post-hook for the :meth:`~torch.nn.Module.state_dict` method.\n\nIt should have the following signature::\n\n hook(module, state_dict, prefix, local_metadata) -> None\n\nThe registered hooks can modify the ``state_dict`` in place.\n\n##### 4.2.55 register\\_state\\_dict\\_pre\\_hook\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRegister a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.\n\nIt should have the following signature::\n\n hook(module, prefix, keep_vars) -> None\n\nThe registered hooks can be used to perform pre-processing before the ``state_dict``\n\ncall is made.\n\n##### 4.2.56 requires\\_grad\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nChange whether autograd should record operations on parameters in this module.\n\nThis method sets the parameters' :attr:`requires_grad` attributes\n\nin-place.\n\nThis method is helpful for freezing part of the module for fine-tuning\n\nor training parts of a model individually (e.g., GAN training).\n\nSee :ref:`locally-disable-grad-doc` for a comparison between\n\n`.requires_grad_()` and several similar mechanisms that may be confused with it.\n\nArgs:\n\n requires_grad (bool): whether autograd should record operations on\n\n parameters in this module. Default: ``True``.\n\nReturns:\n\n Module: self\n\n##### 4.2.57 set\\_extra\\_state\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet extra state contained in the loaded `state_dict`.\n\nThis function is called from :func:`load_state_dict` to handle any extra state\n\nfound within the `state_dict`.
Implement this function and a corresponding\n\n:func:`get_extra_state` for your module if you need to store extra state within its\n\n`state_dict`.\n\nArgs:\n\n state (dict): Extra state from the `state_dict`\n\n##### 4.2.58 set\\_submodule\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet the submodule given by ``target`` if it exists, otherwise throw an error.\n\nFor example, let's say you have an ``nn.Module`` ``A`` that\n\nlooks like this:\n\n.. code-block:: text\n\n A(\n\n (net_b): Module(\n\n (net_c): Module(\n\n (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n\n )\n\n (linear): Linear(in_features=100, out_features=200, bias=True)\n\n )\n\n )\n\n(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested\n\nsubmodule ``net_b``, which itself has two submodules ``net_c``\n\nand ``linear``. ``net_c`` then has a submodule ``conv``.)\n\nTo override the ``Conv2d`` with a new submodule ``Linear``, you\n\nwould call\n\n``set_submodule(\"net_b.net_c.conv\", nn.Linear(33, 16))``.\n\nArgs:\n\n target: The fully-qualified string name of the submodule\n\n to look for. (See above example for how to specify a\n\n fully-qualified string.)\n\n module: The module to set the submodule to.\n\nRaises:\n\n ValueError: If the target string is empty\n\n AttributeError: If the target string references an invalid\n\n path or resolves to something that is not an\n\n ``nn.Module``\n\n##### 4.2.59 share\\_memory\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSee :meth:`torch.Tensor.share_memory_`.\n\n##### 4.2.60 state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReturn a dictionary containing references to the whole state of the module.\n\nBoth parameters and persistent buffers (e.g. running averages) are\n\nincluded. Keys are corresponding parameter and buffer names.\n\nParameters and buffers set to ``None`` are not included.\n\n.. note::\n\n The returned object is a shallow copy. It contains references\n\n to the module's parameters and buffers.\n\n..
warning::\n\n Currently ``state_dict()`` also accepts positional arguments for\n\n ``destination``, ``prefix`` and ``keep_vars`` in order. However,\n\n this is being deprecated and keyword arguments will be enforced in\n\n future releases.\n\n.. warning::\n\n Please avoid the use of argument ``destination`` as it is not\n\n designed for end-users.\n\nArgs:\n\n destination (dict, optional): If provided, the state of module will\n\n be updated into the dict and the same object is returned.\n\n Otherwise, an ``OrderedDict`` will be created and returned.\n\n Default: ``None``.\n\n prefix (str, optional): a prefix added to parameter and buffer\n\n names to compose the keys in state_dict. Default: ``''``.\n\n keep_vars (bool, optional): by default the :class:`~torch.Tensor` s\n\n returned in the state dict are detached from autograd. If it's\n\n set to ``True``, detaching will not be performed.\n\n Default: ``False``.\n\nReturns:\n\n dict:\n\n a dictionary containing a whole state of the module\n\nExample::\n\n >>> # xdoctest: +SKIP(\"undefined vars\")\n\n >>> module.state_dict().keys()\n\n ['bias', 'weight']\n\n##### 4.2.61 to\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove and/or cast the parameters and buffers.\n\nThis can be called as\n\n.. function:: to(device=None, dtype=None, non_blocking=False)\n\n :noindex:\n\n.. function:: to(dtype, non_blocking=False)\n\n :noindex:\n\n.. function:: to(tensor, non_blocking=False)\n\n :noindex:\n\n.. function:: to(memory_format=torch.channels_last)\n\n :noindex:\n\nIts signature is similar to :meth:`torch.Tensor.to`, but only accepts\n\nfloating point or complex :attr:`dtype`\\ s. In addition, this method will\n\nonly cast the floating point or complex parameters and buffers to :attr:`dtype`\n\n(if given). The integral parameters and buffers will be moved\n\n:attr:`device`, if that is given, but with dtypes unchanged. 
When\n\n:attr:`non_blocking` is set, it tries to convert/move asynchronously\n\nwith respect to the host if possible, e.g., moving CPU Tensors with\n\npinned memory to CUDA devices.\n\nSee below for examples.\n\n.. note::\n\n This method modifies the module in-place.\n\nArgs:\n\n device (:class:`torch.device`): the desired device of the parameters\n\n and buffers in this module\n\n dtype (:class:`torch.dtype`): the desired floating point or complex dtype of\n\n the parameters and buffers in this module\n\n tensor (torch.Tensor): Tensor whose dtype and device are the desired\n\n dtype and device for all parameters and buffers in this module\n\n memory_format (:class:`torch.memory_format`): the desired memory\n\n format for 4D parameters and buffers in this module (keyword\n\n only argument)\n\nReturns:\n\n Module: self\n\nExamples::\n\n >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n\n >>> linear = nn.Linear(2, 2)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1913, -0.3420],\n\n [-0.5113, -0.2325]])\n\n >>> linear.to(torch.double)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1913, -0.3420],\n\n [-0.5113, -0.2325]], dtype=torch.float64)\n\n >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)\n\n >>> gpu1 = torch.device(\"cuda:1\")\n\n >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1914, -0.3420],\n\n [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')\n\n >>> cpu = torch.device(\"cpu\")\n\n >>> linear.to(cpu)\n\n Linear(in_features=2, out_features=2, bias=True)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.1914, -0.3420],\n\n [-0.5112, -0.2324]], dtype=torch.float16)\n\n >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)\n\n >>> linear.weight\n\n Parameter containing:\n\n tensor([[ 0.3741+0.j, 0.2382+0.j],\n\n [ 
0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)\n\n >>> linear(torch.ones(3, 2, dtype=torch.cdouble))\n\n tensor([[0.6122+0.j, 0.1150+0.j],\n\n [0.6122+0.j, 0.1150+0.j],\n\n [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)\n\n##### 4.2.62 to\\_empty\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove the parameters and buffers to the specified device without copying storage.\n\nArgs:\n\n device (:class:`torch.device`): The desired device of the parameters\n\n and buffers in this module.\n\n recurse (bool): Whether parameters and buffers of submodules should\n\n be recursively moved to the specified device.\n\nReturns:\n\n Module: self\n\n##### 4.2.63 train\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSet the module in training mode.\n\nThis has an effect only on certain modules. See the documentation of\n\nparticular modules for details of their behaviors in training/evaluation\n\nmode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n\netc.\n\nArgs:\n\n mode (bool): whether to set training mode (``True``) or evaluation\n\n mode (``False``). Default: ``True``.\n\nReturns:\n\n Module: self\n\n##### 4.2.64 type\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCasts all parameters and buffers to :attr:`dst_type`.\n\n.. note::\n\n This method modifies the module in-place.\n\nArgs:\n\n dst_type (type or string): the desired type\n\nReturns:\n\n Module: self\n\n##### 4.2.65 xpu\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMove all model parameters and buffers to the XPU.\n\nThis also makes associated parameters and buffers different objects. So\n\nit should be called before constructing optimizer if the module will\n\nlive on XPU while being optimized.\n\n.. 
note::\n\n This method modifies the module in-place.\n\nArgs:\n\n device (int, optional): if specified, all parameters will be\n\n copied to that device\n\nReturns:\n\n Module: self\n\n##### 4.2.66 zero\\_grad\n\n[TOC](#table-of-contents)\n\n**Description**\n\nReset gradients of all model parameters.\n\nSee similar function under :class:`torch.optim.Optimizer` for more context.\n\nArgs:\n\n set_to_none (bool): instead of setting to zero, set the grads to None.\n\n See :meth:`torch.optim.Optimizer.zero_grad` for details.\n\n# 5 multi\\_transforms\n\n[TOC](#table-of-contents)\n\nThe module `rsp.ml.multi_transforms` is based on `torchvision.transforms`, which is made for single images. `rsp.ml.multi_transforms` extends this functionality by providing transformations for sequences of images, which is useful for video augmentation.\n\n## 5.1 BGR2GRAY : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nConverts a sequence of BGR images to grayscale images.\n\n\n### 5.1.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.1.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.2 BGR2RGB : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nConverts a sequence of BGR images to RGB images.\n\n\n### 5.2.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.2.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.3 Brightness : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform
is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and is meant to be subclassed.\n\n\n### 5.3.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.3.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.4 CenterCrop : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCrops images at the center after upscaling them. The image dimensions are kept the same.\n\n\n\n\n### 5.4.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.4.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| max_scale | float | Images are scaled randomly by a factor between 1.0 and max_scale before being cropped back to their original size. |\n## 5.5 Color : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation.
Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and is meant to be subclassed.\n\n\n### 5.5.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.5.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.6 Compose : builtins.object\n\n[TOC](#table-of-contents)\n\n**Description**\n\nComposes several MultiTransforms together.\n\n**Example**\n\n```python\nimport rsp.ml.multi_transforms as t\n\n# convert to grayscale, then scale each frame by 0.5\ntransforms = t.Compose([\n    t.BGR2GRAY(),\n    t.Scale(0.5)\n])\n```\n### 5.6.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n### 5.6.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| children | List[MultiTransform] | List of MultiTransforms to compose. |\n## 5.7 GaussianNoise : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation.
Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and is meant to be subclassed.\n\n\n### 5.7.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.7.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.8 MultiTransform : builtins.object\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and is meant to be subclassed.\n\n\n### 5.8.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.8.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.9 Normalize : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nNormalizes images with a per-channel mean and standard deviation.
Given `mean = (mean[1], ..., mean[n])` and `std = (std[1], ..., std[n])` for `n` channels, this transform normalizes each channel of the input `torch.Tensor`, i.e., `output[channel] = (input[channel] - mean[channel]) / std[channel]`.\n\n> Based on `torchvision.transforms.Normalize`.\n\n\n### 5.9.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.9.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| mean | List[float] | Sequence of means, one per channel. |\n| std | List[float] | Sequence of standard deviations, one per channel. |\n| inplace | bool | Set to True to perform this operation in-place. |\n## 5.10 RGB2BGR : BGR2RGB\n\n[TOC](#table-of-contents)\n\n**Description**\n\nConverts a sequence of RGB images to BGR images.\n\n\n### 5.10.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.10.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.11 RandomCrop : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCrops images at a random location after upscaling them.
The image dimensions are kept the same.\n\n\n\n\n### 5.11.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.11.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| max_scale | float | Images are scaled randomly by a factor between 1.0 and max_scale before being cropped back to their original size. |\n## 5.12 RandomHorizontalFlip : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and is meant to be subclassed.\n\n\n### 5.12.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.12.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.13 RandomVerticalFlip : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation.
Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and is meant to be subclassed.\n\n\n### 5.13.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.13.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.14 RemoveBackgroundAI : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and is meant to be subclassed.\n\n\n### 5.14.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.14.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.15 ReplaceBackground : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTransformation for background replacement based on HSV values. Supports depth background replacement.
Backgrounds have to be passed as a list of tuples of RGB and depth images.\n\n**Example**\n\n```python\nimport cv2 as cv\nimport rsp.ml.multi_transforms as multi_transforms\nfrom rsp.ml.dataset import TUCRID\n\nUSE_DEPTH_DATA = False\nbackgrounds = TUCRID.load_backgrounds(USE_DEPTH_DATA)\ntransforms_train = multi_transforms.Compose([\n    multi_transforms.ReplaceBackground(\n        backgrounds=backgrounds,\n        hsv_filter=[(69, 87, 139, 255, 52, 255)],\n        p=0.8\n    ),\n    multi_transforms.Stack()\n])\ntucrid = TUCRID('train', load_depth_data=USE_DEPTH_DATA, transforms=transforms_train)\n\n# iterate over all sequences and display every frame\nfor X, T in tucrid:\n    for x in X:\n        # reorder the tensor from (C, H, W) to (H, W, C) for OpenCV\n        img = x.permute(1, 2, 0).numpy()\n\n        cv.imshow('img', img)\n        cv.waitKey(30)\n```\n### 5.15.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.15.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTransformation for background replacement based on HSV values. Supports depth background replacement. Backgrounds have to be passed as a list of tuples of RGB and depth images.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| backgrounds | List[np.array] | List of background images |\n| hsv_filter | List[tuple[int, int, int, int, int, int]] | List of HSV filters |\n| p | float, default = 1. | Probability of applying the transformation |\n| rotate | float, default = 5 | Maximum rotation angle |\n| max_scale | float, default = 2 | Maximum scaling factor |\n| max_noise | float, default = 0.002 | Maximum noise level |\n## 5.16 Resize : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image.
It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.\n\n\n### 5.16.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.16.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.17 Rotate : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRandomly rotates images.\n\n**Equations**\n\n$angle = -max\\_angle + 2 \\cdot random() \\cdot max\\_angle$\n\n\n### 5.17.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.17.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| max_angle | float | Maximum rotation in degrees (-max_angle <= angle <= max_angle) |\n| auto_scale | bool, default = True | If activated, the image is rescaled to avoid black margins |\n## 5.18 Satturation : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. 
Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.\n\n\n### 5.18.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.18.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.19 Scale : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.\n\n\n### 5.19.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.19.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.20 Stack : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nMultiTransform is an extension to keep the same transformation over a sequence of images instead of initializing a new transformation for every single image. It is inspired by `torchvision.transforms` and could be used for video augmentation. 
Use `rsp.ml.multi_transforms.Compose` to combine multiple image sequence transformations.\n\n> **Note** `rsp.ml.multi_transforms.MultiTransform` is a base class and should be inherited.\n\n\n### 5.20.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.20.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.21 ToCVImage : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nConverts a `torch.Tensor` to an OpenCV image by permuting its dimensions (d0, d1, d2) -> (d1, d2, d0) and converting the `torch.Tensor` to `numpy`.\n\n\n### 5.21.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.21.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.22 ToNumpy : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nConverts a `torch.Tensor` to `numpy`.\n\n\n### 5.22.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.22.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.23 ToPILImage : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nConverts a sequence of images to a sequence of `PIL.Image`.\n\n\n### 5.23.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description 
|\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.23.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n## 5.24 ToTensor : MultiTransform\n\n[TOC](#table-of-contents)\n\n**Description**\n\nConverts a sequence of images to a `torch.Tensor`.\n\n\n### 5.24.1 \\_\\_call\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nCall self as a function.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| input | torch.Tensor<br>List[PIL.Image]<br>List[numpy.array] | Sequence of images |\n### 5.24.2 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nInitializes a new instance.\n\n# 6 run\n\n[TOC](#table-of-contents)\n\nThe module `rsp.ml.run` provides tools for storing, loading and visualizing data while training models with PyTorch.\n\n## 6.1 Run : builtins.object\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRun class to store and manage training runs.\n\n**Example**\n\n```python\nfrom rsp.ml.run import Run\nimport rsp.ml.metrics as m\n\nmetrics = [\n    m.top_1_accuracy\n]\nconfig = {\n    m.top_1_accuracy.__name__: {\n        'ymin': 0,\n        'ymax': 1\n    }\n}\nrun = Run(id='run0001', metrics=metrics, config=config, ignore_outliers_in_chart_scaling=True)\n\nfor epoch in range(100):\n    # here goes some training code, giving us inputs, predictions and targets\n    acc = m.top_1_accuracy(predictions, targets)\n    run.append(m.top_1_accuracy.__name__, 'train', acc)\n```\n### 6.1.1 \\_\\_init\\_\\_\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRun class to store and manage training runs.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| id | str, default = None | ID of the run. If None, a new ID is generated |\n| moving_average_epochs | int, default = 1 | Number of epochs to average over |\n| metrics | list, default = None | List of metrics to compute. 
Each metric should be a function that takes Y and T as input. |\n| device | str, default = None | PyTorch device to run on |\n| ignore_outliers_in_chart_scaling | bool, default = False | Ignore outliers when scaling charts |\n| config | dict, default = {} | Configuration dictionary. Keys are metric names and values are dictionaries with keys 'ymin' and 'ymax' |\n### 6.1.2 append\n\n[TOC](#table-of-contents)\n\n**Description**\n\nAppend value to key in phase.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| key | str | Key to append to |\n| phase | str | Phase to append to |\n| value | float | Value to append |\n### 6.1.3 get\\_avg\n\n[TOC](#table-of-contents)\n\n**Description**\n\nGet the last average value of key in phase.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| key | str | Key to get |\n| phase | str | Phase to get from |\n\n**Returns**\n\nLast average value of key in phase. If key is not in data, returns np.nan : value : float\n\n### 6.1.4 get\\_val\n\n[TOC](#table-of-contents)\n\n**Description**\n\nGet the last value of key in phase.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| key | str | Key to get |\n| phase | str | Phase to get from |\n\n**Returns**\n\nLast value of key in phase. 
If key is not in data, returns np.nan : value : float\n\n### 6.1.5 len\n\n[TOC](#table-of-contents)\n\n**Description**\n\nGet the length of the longest phase.\n\n### 6.1.6 load\\_best\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nLoad best state_dict from runs/{id}/{fname}\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| model | torch.nn.Module | Model to load state_dict into |\n| fname | str, default = 'state_dict.pt' | Filename to load from |\n| verbose | bool, default = False | Print loaded file |\n### 6.1.7 load\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nLoad state_dict from runs/{id}/{fname}\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| model | torch.nn.Module | Model to load state_dict into |\n| fname | str, default = None | Filename to load from |\n### 6.1.8 pickle\\_dump\n\n[TOC](#table-of-contents)\n\n**Description**\n\nPickle model to runs/{id}/{fname}\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| model | torch.nn.Module | Model to pickle |\n| fname | str, default = 'model.pkl' | Filename to save to |\n### 6.1.9 pickle\\_load\n\n[TOC](#table-of-contents)\n\n**Description**\n\nLoad model from runs/{id}/{fname}\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| fname | str, default = 'model.pkl' | Filename to load from |\n### 6.1.10 plot\n\n[TOC](#table-of-contents)\n\n**Description**\n\nPlot all keys to runs/{id}/plot/{key}.jpg\n\n### 6.1.11 recalculate\\_moving\\_average\n\n[TOC](#table-of-contents)\n\n**Description**\n\nRecalculate the moving average.\n\n### 6.1.12 save\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSave data to runs/{id}/data.json\n\n### 6.1.13 save\\_best\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSave state_dict if new_acc is better than the previous best.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| 
state_dict | dict | State dict to save |\n| new_acc | float | New accuracy |\n| epoch | int, default = None | Epoch to save |\n| fname | str, default = 'state_dict.pt' | Filename to save to |\n### 6.1.14 save\\_state\\_dict\n\n[TOC](#table-of-contents)\n\n**Description**\n\nSave state_dict to runs/{id}/{fname}\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| state_dict | dict | State dict to save |\n| fname | str, default = 'state_dict.pt' | Filename to save to |\n### 6.1.15 train\\_epoch\n\n[TOC](#table-of-contents)\n\n**Description**\n\nTrain one epoch.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| dataloader | DataLoader | DataLoader to train on |\n| model | torch.nn.Module | Model to train |\n| optimizer | torch.optim.Optimizer | Optimizer to use |\n| criterion | torch.nn.Module | Criterion to use |\n| num_batches | int, default = None | Number of batches to train on. If None, train on all batches |\n| return_YT | bool, default = False | Append Y and T to results |\n\n**Returns**\n\nDictionary with results : results : dict\n\n### 6.1.16 validate\\_epoch\n\n[TOC](#table-of-contents)\n\n**Description**\n\nValidate one epoch.\n\n**Parameters**\n\n| Name | Type | Description |\n|------|------|-------------|\n| dataloader | DataLoader | DataLoader to validate on |\n| model | torch.nn.Module | Model to validate |\n| optimizer | torch.optim.Optimizer | Optimizer to use |\n| criterion | torch.nn.Module | Criterion to use |\n| num_batches | int, default = None | Number of batches to validate on. If None, validate on all batches |\n| return_YT | bool, default = False | Append Y and T to results |\n\n**Returns**\n\nDictionary with results : results : dict\n\n",
"bugtrack_url": null,
"license": "MIT",
"summary": "Machine Learning",
"version": "0.0.133",
"project_urls": {
"Homepage": "https://github.com/SchulzR97/rsp-ml"
},
"split_keywords": [
"python",
" machine learning"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "923d2d52d1b537f62604093856918b80079ea4ce9ca14675b7150d70aa625564",
"md5": "d1e5667be1021ad201bf752a82b54442",
"sha256": "f7b0c9c9264fc2280e971baa66bfa8a1ba77395b83b79d0b3a71491929273d3f"
},
"downloads": -1,
"filename": "rsp_ml-0.0.133-py3-none-any.whl",
"has_sig": false,
"md5_digest": "d1e5667be1021ad201bf752a82b54442",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": null,
"size": 59971,
"upload_time": "2025-03-18T12:03:53",
"upload_time_iso_8601": "2025-03-18T12:03:53.507956Z",
"url": "https://files.pythonhosted.org/packages/92/3d/2d52d1b537f62604093856918b80079ea4ce9ca14675b7150d70aa625564/rsp_ml-0.0.133-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "27517ccf46a8f2464d7cca4688e6c9bdf31b542a818b905524dd8c55201f29b4",
"md5": "2fbea2074db5d8b511f86e1337bfde6d",
"sha256": "2d96cb6bd2e9f64617ea1eba973d454899b41386219a5f7374f5affd8746fc43"
},
"downloads": -1,
"filename": "rsp_ml-0.0.133.tar.gz",
"has_sig": false,
"md5_digest": "2fbea2074db5d8b511f86e1337bfde6d",
"packagetype": "sdist",
"python_version": "source",
"requires_python": null,
"size": 113748,
"upload_time": "2025-03-18T12:03:59",
"upload_time_iso_8601": "2025-03-18T12:03:59.107578Z",
"url": "https://files.pythonhosted.org/packages/27/51/7ccf46a8f2464d7cca4688e6c9bdf31b542a818b905524dd8c55201f29b4/rsp_ml-0.0.133.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-03-18 12:03:59",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "SchulzR97",
"github_project": "rsp-ml",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "rsp-ml"
}