# trainplotkit
Create live subplots in your notebook that update while training a PyTorch model
# Features
* Extensible framework for adding subplots updated in real time in your notebook during training
* Interaction between subplots after training has completed
  * Click on one subplot to select an epoch / sample and update other subplots dynamically
* Supports custom training loops and high-level training libraries like [pytorch_lightning](https://github.com/Lightning-AI/pytorch-lightning) and [fastai](https://github.com/fastai/fastai)
  * Coming soon: adapters for even more seamless integration with high-level training libraries
* All graph interactions provided by [plotly](https://plotly.com/python/)
* Built-in subplots:
  * Training curves
  * Custom metric vs epoch
  * Validation loss for individual samples (scatter plot)
  * Input image corresponding to selected sample
  * Class probabilities corresponding to selected sample
  * Coming soon: colourful dimension plot from [fastai course Lesson 16](https://course.fast.ai/Lessons/lesson16.html) (at 1:14:30) for visualizing activation stats
# Use cases
* Quickly identifying and explaining outlier samples in a dataset
* Quickly developing visualizations to improve your understanding of a model and/or training process
# Installation
```
pip install git+https://github.com/d112358/trainplotkit.git
```
# Usage example
```python
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import datasets, models
from torcheval.metrics import MulticlassAccuracy
from trainplotkit.plotgrid import PlotGrid
from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data preparation
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
train_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
valid_data = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
# Adjust num_workers to the number of CPU cores available on your machine.
train_loader = DataLoader(train_data, batch_size=64, num_workers=15, shuffle=True)
# shuffle=False keeps validation batches in dataset order; presumably this lets the
# per-sample subplots (ValidLossSP / ImageSP) map points back to valid_data indices - confirm.
valid_loader = DataLoader(valid_data, batch_size=64, num_workers=15, shuffle=False)
num_classes = len(valid_data.classes)

# Model setup: pretrained ResNet-18 with a new classification head for CIFAR-10
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Plots
# reduction='none' yields one loss value per sample, as needed for a per-sample scatter plot.
batch_loss_fn = nn.CrossEntropyLoss(reduction='none')
probs_fn = lambda preds: torch.softmax(preds, dim=1)
sps = [
    TrainingCurveSP(colspan=2),
    ValidLossSP(batch_loss_fn, remember_past_epochs=True, colspan=2),
    ImageSP(valid_data, class_names=valid_data.classes, rowspan=2),
    MetricSP("Accuracy", MulticlassAccuracy(), colspan=2),
    ClassProbsSP(probs_fn, remember_past_epochs=True, class_names=valid_data.classes, colspan=2),
]
pg = PlotGrid(num_grid_cols=5, subplots=sps)
pg.show()

# Training and validation loop
num_epochs = 4
for epoch in range(num_epochs):
    # Training
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        pg.after_batch(training=True, inputs=images, targets=labels, predictions=outputs, loss=loss)
    pg.after_epoch(training=True)

    # Validation
    model.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Compute this batch's validation loss; reusing `loss` from the training
            # loop would hand PlotGrid a stale training value for every validation batch.
            loss = criterion(outputs, labels)
            pg.after_batch(training=False, inputs=images, targets=labels, predictions=outputs, loss=loss)
    pg.after_epoch(training=False)
pg.after_fit()
```
![Usage example](https://github.com/d112358/trainplotkit/raw/main/images/usage_example.png)
# License
This repository is released under the MIT license. See [LICENSE](https://github.com/d112358/trainplotkit/blob/main/LICENSE) for additional details.
Raw data
{
"_id": null,
"home_page": null,
"name": "trainplotkit",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "pytorch, torch, deep learning, neural network, training, visualization, interactive, dashboard, jupyter, notebook, plotly",
"author": null,
"author_email": "Dirk Oosthuizen <dirk.jj.oosthuizen@gmail.com>",
"download_url": "https://files.pythonhosted.org/packages/bc/37/697ad3787b88180aa2e2d23b3d1b7ddca2ab0165aa1f725ba917de5e5db9/trainplotkit-0.1.0.tar.gz",
"platform": null,
"description": "# trainplotkit\nCreate live subplots in your notebook that update while training a PyTorch model\n\n# Features\n* Extensible framework for adding subplots updated in real time in your notebook during training\n* Interaction between subplots after training has completed\n * Click on one subplot to select an epoch / sample and update other subplots dynamically\n* Supports custom training loops and high-level training libraries like [pytorch_lightning](https://github.com/Lightning-AI/pytorch-lightning) and [fastai](https://github.com/fastai/fastai)\n * Coming soon: adapters for even more seamless integration with high-level training libraries\n* All graph interactions provided by [plotly](https://plotly.com/python/)\n* Built-in subplots:\n * Training curves\n * Custom metric vs epoch\n * Validation loss for individual samples (scatter plot)\n * Input image corresponding to selected sample\n * Class probililities corresponding to selected sample\n * Coming soon: colourful dimension plot from [fastai course Lesson 16](https://course.fast.ai/Lessons/lesson16.html) 1:14:30 for visualizing activation stats\n\n# Use cases\n* Quickly identifying and explaining outlier samples in a dataset\n* Quickly developing visualizations to improve your understanding of a model and/or training process\n\n# Installation\n```\npip install git+https://github.com/d112358/trainplotkit.git\n```\n\n# Usage example\n```python\nimport torch\nimport torchvision.transforms as T\nfrom torch.utils.data import DataLoader\nfrom torch import nn, optim\nfrom torchvision import datasets, models\nfrom torcheval.metrics import MulticlassAccuracy\nfrom trainplotkit.plotgrid import PlotGrid\nfrom trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP\n\n# Set device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Using device: {device}\")\n\n# Data preparation\ntransform = T.Compose([T.ToTensor(), 
T.Normalize((0.5,), (0.5,))])\ntrain_data = datasets.CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\nvalid_data = datasets.CIFAR10(root=\"./data\", train=False, download=True, transform=transform)\ntrain_loader = DataLoader(train_data, batch_size=64, num_workers=15, shuffle=True)\nvalid_loader = DataLoader(valid_data, batch_size=64, num_workers=15, shuffle=False)\nnum_classes = len(valid_data.classes)\n\n# Model setup\nmodel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)\nmodel.fc = nn.Linear(model.fc.in_features, num_classes)\nmodel = model.to(device)\ncriterion = nn.CrossEntropyLoss()\noptimizer = optim.Adam(model.parameters(), lr=0.0001)\n\n# Plots\nbatch_loss_fn = nn.CrossEntropyLoss(reduction='none')\nprobs_fn = lambda preds: torch.softmax(preds, dim=1)\nsps = [\n TrainingCurveSP(colspan=2), \n ValidLossSP(batch_loss_fn, remember_past_epochs=True, colspan=2), \n ImageSP(valid_data, class_names=valid_data.classes, rowspan=2),\n MetricSP(\"Accuracy\", MulticlassAccuracy(), colspan=2), \n ClassProbsSP(probs_fn, remember_past_epochs=True, class_names=valid_data.classes, colspan=2),\n]\npg = PlotGrid(num_grid_cols=5, subplots=sps)\npg.show()\n\n# Training and validation loop\nfor epoch in range(4):\n # Training\n model.train()\n for images, labels in train_loader:\n images, labels = images.to(device), labels.to(device)\n optimizer.zero_grad()\n outputs = model(images)\n loss = criterion(outputs, labels)\n loss.backward()\n optimizer.step()\n pg.after_batch(training=True, inputs=images, targets=labels, predictions=outputs, loss=loss)\n pg.after_epoch(training=True)\n\n # Validation\n model.eval()\n val_loss, correct = 0, 0\n with torch.no_grad():\n for images, labels in valid_loader:\n images, labels = images.to(device), labels.to(device)\n outputs = model(images)\n val_loss += criterion(outputs, labels).item()\n correct += (outputs.argmax(1) == labels).sum().item()\n pg.after_batch(training=False, inputs=images, targets=labels, 
predictions=outputs, loss=loss)\n pg.after_epoch(training=False)\npg.after_fit()\n```\n![Usage example](https://github.com/d112358/trainplotkit/raw/main/images/usage_example.png)\n\n# License\nThis repository is released under the MIT license. See [LICENSE](https://github.com/d112358/trainplotkit/blob/main/LICENSE) for additional details.\n",
"bugtrack_url": null,
"license": "MIT",
"summary": "Create live subplots in your notebook that update while training a PyTorch model",
"version": "0.1.0",
"project_urls": {
"Bug Tracker": "https://github.com/d112358/trainplotkit/issues",
"Repository": "https://github.com/d112358/trainplotkit.git"
},
"split_keywords": [
"pytorch",
" torch",
" deep learning",
" neural network",
" training",
" visualization",
" interactive",
" dashboard",
" jupyter",
" notebook",
" plotly"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "9ff0d3319f004f497a6c914f31babf69354055db57afabbac87fcb26dcb4979e",
"md5": "764b3c2f947b025a474244300385823a",
"sha256": "32d8f4498fb0caa85134297c6effd2ccc9e6f764e9fa17233d9ebf64c46ce7e8"
},
"downloads": -1,
"filename": "trainplotkit-0.1.0-py3-none-any.whl",
"has_sig": false,
"md5_digest": "764b3c2f947b025a474244300385823a",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 14139,
"upload_time": "2025-01-21T22:00:06",
"upload_time_iso_8601": "2025-01-21T22:00:06.203608Z",
"url": "https://files.pythonhosted.org/packages/9f/f0/d3319f004f497a6c914f31babf69354055db57afabbac87fcb26dcb4979e/trainplotkit-0.1.0-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "bc37697ad3787b88180aa2e2d23b3d1b7ddca2ab0165aa1f725ba917de5e5db9",
"md5": "0898a45d835e44b41fdfb12db92f4087",
"sha256": "d395ea56062e5ef8c801bf76f98a4ea849b02a35359ae9d62f5daef6302033b5"
},
"downloads": -1,
"filename": "trainplotkit-0.1.0.tar.gz",
"has_sig": false,
"md5_digest": "0898a45d835e44b41fdfb12db92f4087",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 13666,
"upload_time": "2025-01-21T22:00:08",
"upload_time_iso_8601": "2025-01-21T22:00:08.800584Z",
"url": "https://files.pythonhosted.org/packages/bc/37/697ad3787b88180aa2e2d23b3d1b7ddca2ab0165aa1f725ba917de5e5db9/trainplotkit-0.1.0.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-01-21 22:00:08",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "d112358",
"github_project": "trainplotkit",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "trainplotkit"
}