trainplotkit


Name: trainplotkit
Version: 0.1.0
Home page: None
Summary: Create live subplots in your notebook that update while training a PyTorch model
Upload time: 2025-01-21 22:00:08
Maintainer: None
Docs URL: None
Author: None
Requires Python: >=3.9
License: MIT
Keywords: pytorch, torch, deep learning, neural network, training, visualization, interactive, dashboard, jupyter, notebook, plotly
Requirements: No requirements were recorded.
Travis-CI: No Travis.
Coveralls test coverage: No coveralls.

# trainplotkit
Create live subplots in your notebook that update while training a PyTorch model

# Features
* Extensible framework for adding subplots updated in real time in your notebook during training
* Interaction between subplots after training has completed
  * Click on one subplot to select an epoch / sample and update other subplots dynamically
* Supports custom training loops and high-level training libraries like [pytorch_lightning](https://github.com/Lightning-AI/pytorch-lightning) and [fastai](https://github.com/fastai/fastai)
  * Coming soon: adapters for even more seamless integration with high-level training libraries
* All graph interactions provided by [plotly](https://plotly.com/python/)
* Built-in subplots:
  * Training curves
  * Custom metric vs epoch
  * Validation loss for individual samples (scatter plot)
  * Input image corresponding to selected sample
  * Class probabilities corresponding to selected sample
  * Coming soon: colourful dimension plot from [fastai course Lesson 16](https://course.fast.ai/Lessons/lesson16.html) 1:14:30 for visualizing activation stats

# Use cases
* Quickly identifying and explaining outlier samples in a dataset
* Quickly developing visualizations to improve your understanding of a model and/or training process
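
The outlier use case can be illustrated without the library. The sketch below is not trainplotkit code: it assumes cross-entropy as the per-sample loss, and the logits, targets, and helper names (`softmax`, `per_sample_cross_entropy`, `rank_outliers`) are made up for illustration. It computes one loss per validation sample and ranks samples by it, which is roughly the information a per-sample validation loss scatter plot lets you read off visually.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def per_sample_cross_entropy(batch_logits, targets):
    """Return one cross-entropy loss per sample (no averaging over the batch)."""
    return [-math.log(softmax(logits)[t]) for logits, t in zip(batch_logits, targets)]


def rank_outliers(losses, top_k=3):
    """Return (sample_index, loss) pairs for the top_k highest losses, largest first."""
    ranked = sorted(enumerate(losses), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k]


# Hypothetical logits for four validation samples over three classes
batch_logits = [
    [4.0, 0.5, 0.1],  # confident and correct
    [0.2, 3.0, 0.1],  # confident and correct
    [3.5, 0.3, 0.2],  # confident but wrong: candidate outlier or label error
    [1.0, 1.1, 0.9],  # uncertain prediction
]
targets = [0, 1, 2, 1]

losses = per_sample_cross_entropy(batch_logits, targets)
outliers = rank_outliers(losses, top_k=2)
print(outliers)  # highest-loss samples first
```

Samples that are predicted confidently but wrongly end up at the top of the ranking, which is why they are natural candidates for closer inspection.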

# Installation
```
pip install git+https://github.com/d112358/trainplotkit.git
```

# Usage example
```python
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import datasets, models
from torcheval.metrics import MulticlassAccuracy
from trainplotkit.plotgrid import PlotGrid
from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data preparation
transform    = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
train_data   = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
valid_data   = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, num_workers=15, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=64, num_workers=15, shuffle=False)
num_classes  = len(valid_data.classes)

# Model setup
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Plots
batch_loss_fn = nn.CrossEntropyLoss(reduction='none')
probs_fn = lambda preds: torch.softmax(preds, dim=1)
sps = [
    TrainingCurveSP(colspan=2), 
    ValidLossSP(batch_loss_fn, remember_past_epochs=True, colspan=2), 
    ImageSP(valid_data, class_names=valid_data.classes, rowspan=2),
    MetricSP("Accuracy", MulticlassAccuracy(), colspan=2), 
    ClassProbsSP(probs_fn, remember_past_epochs=True, class_names=valid_data.classes, colspan=2),
]
pg = PlotGrid(num_grid_cols=5, subplots=sps)
pg.show()

# Training and validation loop
for epoch in range(4):
    # Training
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        pg.after_batch(training=True, inputs=images, targets=labels, predictions=outputs, loss=loss)
    pg.after_epoch(training=True)

    # Validation
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)  # loss for this validation batch, not the last training batch
            val_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            pg.after_batch(training=False, inputs=images, targets=labels, predictions=outputs, loss=loss)
    pg.after_epoch(training=False)
pg.after_fit()
```
![Usage example](https://github.com/d112358/trainplotkit/raw/main/images/usage_example.png)
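
For readers adapting the example to their own training loop, the following sketch isolates the order in which the hooks are called. `RecordingGrid` is a hypothetical stand-in written for this illustration, not part of trainplotkit; it only mirrors the method names and keyword arguments used in the example above and records each call. The batches and loss values are dummy placeholders.

```python
class RecordingGrid:
    """Hypothetical stand-in for a plot grid: it only records the hook calls it receives."""

    def __init__(self):
        self.calls = []

    def after_batch(self, training, inputs, targets, predictions, loss):
        self.calls.append(("batch", training, loss))

    def after_epoch(self, training):
        self.calls.append(("epoch", training))

    def after_fit(self):
        self.calls.append(("fit",))


def run_loop(grid, num_epochs, train_batches, valid_batches):
    """Drive the hooks in the same order as the usage example, using dummy data."""
    for _ in range(num_epochs):
        for inputs, targets in train_batches:
            predictions, loss = inputs, 0.5  # placeholders for model output and batch loss
            grid.after_batch(training=True, inputs=inputs, targets=targets,
                             predictions=predictions, loss=loss)
        grid.after_epoch(training=True)

        for inputs, targets in valid_batches:
            predictions, loss = inputs, 0.25  # placeholders for the validation batch
            grid.after_batch(training=False, inputs=inputs, targets=targets,
                             predictions=predictions, loss=loss)
        grid.after_epoch(training=False)
    grid.after_fit()


grid = RecordingGrid()
run_loop(grid, num_epochs=2,
         train_batches=[([1], [0]), ([2], [1])],
         valid_batches=[([3], [1])])
print(grid.calls)
```

The pattern is one `after_batch` per batch, one `after_epoch` at the end of each training and validation pass, and a single `after_fit` once all epochs are done.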

# License
This repository is released under the MIT license. See [LICENSE](https://github.com/d112358/trainplotkit/blob/main/LICENSE) for additional details.

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "trainplotkit",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": null,
    "keywords": "pytorch, torch, deep learning, neural network, training, visualization, interactive, dashboard, jupyter, notebook, plotly",
    "author": null,
    "author_email": "Dirk Oosthuizen <dirk.jj.oosthuizen@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/bc/37/697ad3787b88180aa2e2d23b3d1b7ddca2ab0165aa1f725ba917de5e5db9/trainplotkit-0.1.0.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "Create live subplots in your notebook that update while training a PyTorch model",
    "version": "0.1.0",
    "project_urls": {
        "Bug Tracker": "https://github.com/d112358/trainplotkit/issues",
        "Repository": "https://github.com/d112358/trainplotkit.git"
    },
    "split_keywords": [
        "pytorch",
        " torch",
        " deep learning",
        " neural network",
        " training",
        " visualization",
        " interactive",
        " dashboard",
        " jupyter",
        " notebook",
        " plotly"
    ],
    "urls": [
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "9ff0d3319f004f497a6c914f31babf69354055db57afabbac87fcb26dcb4979e",
                "md5": "764b3c2f947b025a474244300385823a",
                "sha256": "32d8f4498fb0caa85134297c6effd2ccc9e6f764e9fa17233d9ebf64c46ce7e8"
            },
            "downloads": -1,
            "filename": "trainplotkit-0.1.0-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "764b3c2f947b025a474244300385823a",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 14139,
            "upload_time": "2025-01-21T22:00:06",
            "upload_time_iso_8601": "2025-01-21T22:00:06.203608Z",
            "url": "https://files.pythonhosted.org/packages/9f/f0/d3319f004f497a6c914f31babf69354055db57afabbac87fcb26dcb4979e/trainplotkit-0.1.0-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "bc37697ad3787b88180aa2e2d23b3d1b7ddca2ab0165aa1f725ba917de5e5db9",
                "md5": "0898a45d835e44b41fdfb12db92f4087",
                "sha256": "d395ea56062e5ef8c801bf76f98a4ea849b02a35359ae9d62f5daef6302033b5"
            },
            "downloads": -1,
            "filename": "trainplotkit-0.1.0.tar.gz",
            "has_sig": false,
            "md5_digest": "0898a45d835e44b41fdfb12db92f4087",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 13666,
            "upload_time": "2025-01-21T22:00:08",
            "upload_time_iso_8601": "2025-01-21T22:00:08.800584Z",
            "url": "https://files.pythonhosted.org/packages/bc/37/697ad3787b88180aa2e2d23b3d1b7ddca2ab0165aa1f725ba917de5e5db9/trainplotkit-0.1.0.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-01-21 22:00:08",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "d112358",
    "github_project": "trainplotkit",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "lcname": "trainplotkit"
}
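
The `digests` entries in the metadata above can be used to check a downloaded file. The following is a minimal sketch using only the standard library; the helper names and the local file path in the comment are assumptions for illustration, while the expected hash is the `sha256` value listed above for `trainplotkit-0.1.0.tar.gz`.

```python
import hashlib


def sha256_of_file(path, chunk_size=65536):
    """Compute the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_digest(path, expected_hex):
    """Return True if the file's SHA-256 digest equals the expected hex string."""
    return sha256_of_file(path) == expected_hex.lower()


# sha256 listed for trainplotkit-0.1.0.tar.gz in the metadata above
EXPECTED_SDIST_SHA256 = "d395ea56062e5ef8c801bf76f98a4ea849b02a35359ae9d62f5daef6302033b5"

# Example call (the local path is hypothetical):
# matches_digest("trainplotkit-0.1.0.tar.gz", EXPECTED_SDIST_SHA256)
```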
        