atai


Nameatai JSON
Version 0.0.6 PyPI version JSON
download
home_pagehttps://github.com/frenio/atai
SummaryAtomic AI – An attempt at a minimalist, flexible deep learning framework for diverse models.
upload_time2024-07-22 02:51:54
maintainerNone
docs_urlNone
authorFrenio Redeker
requires_python>=3.7
licenseApache Software License 2.0
keywords nbdev jupyter notebook python
VCS
bugtrack_url
requirements No requirements were recorded.
Travis-CI No Travis.
coveralls test coverage No coveralls.
            # atai


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

Atomic AI is a flexible, minimalist deep neural network training
framework based on Jeremy Howard’s
[miniai](https://github.com/fastai/course22p2/tree/master) from the
[fast.ai 2022](https://course.fast.ai/Lessons/part2.html) course.

## Install

``` sh
pip install atai
```

## How to use

``` python
from atai.core import *
```

The following example demonstrates how the Atomic AI training framework
can be used to train a custom model that predicts protein solubility.

### Imports

``` python
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import init
from torch import optim

from torcheval.metrics import BinaryAccuracy, BinaryAUROC
from torcheval.metrics.functional import binary_auroc, binary_accuracy
from torchmetrics.classification import BinaryMatthewsCorrCoef
from torchmetrics.functional.classification import binary_matthews_corrcoef

import fastcore.all as fc
from functools import partial
```

### Load Protein Solubility

This example uses the dataset from the
[DeepSol](https://doi.org/10.1093/bioinformatics/bty166) paper by
Khurana *et al.* which was obtained at
<https://zenodo.org/records/1162886>. It consists of amino acid
sequences of peptides along with solubility labels that are `1` if the
peptide is soluble and `0` if the peptide is insoluble.

``` python
train_sqs = open('sol_data/train_src', 'r').read().splitlines()
train_tgs = list(map(int, open('sol_data/train_tgt', 'r').read().splitlines()))

valid_sqs = open('sol_data/val_src', 'r').read().splitlines()
valid_tgs = list(map(int, open('sol_data/val_tgt', 'r').read().splitlines()))

train_sqs[:2], train_tgs[:2]
```

    (['GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK',
      'MAHHHHHHMSFFRMKRRLNFVVKRGIEELWENSFLDNNVDMKKIEYSKTGDAWPCVLLRKKSFEDLHKLYYICLKEKNKLLGEQYFHLQNSTKMLQHGRLKKVKLTMKRILTVLSRRAIHDQCLRAKDMLKKQEEREFYEIQKFKLNEQLLCLKHKMNILKKYNSFSLEQISLTFSIKKIENKIQQIDIILNPLRKETMYLLIPHFKYQRKYSDLPGFISWKKQNIIALRNNMSKLHRLY'],
     [1, 0])

``` python
len(train_sqs), len(train_tgs), len(valid_sqs), len(valid_tgs)
```

    (62478, 62478, 6942, 6942)

### Data Preparation

Create a sorted list of amino acid sequences `aas` including an empty
string for padding and determine the size of the vocabulary.

``` python
aas = sorted(list(set("".join(train_sqs))) + [""])
vocab_size = len(aas)
aas, vocab_size
```

    (['',
      'A',
      'C',
      'D',
      'E',
      'F',
      'G',
      'H',
      'I',
      'K',
      'L',
      'M',
      'N',
      'P',
      'Q',
      'R',
      'S',
      'T',
      'V',
      'W',
      'Y'],
     21)

Create dictionaries that translate between string and integer
representations of amino acids and define the corresponding `encode` and
`decode` functions.

``` python
str2int = {aa:i for i, aa in enumerate(aas)}
int2str = {i:aa for i, aa in enumerate(aas)}
encode = lambda s: [str2int[aa] for aa in s]
decode = lambda l: ''.join([int2str[i] for i in l])

print(encode("AYWCCCGGGHH"))
print(decode(encode("AYWCCCGGGHH")))
```

    [1, 20, 19, 2, 2, 2, 6, 6, 6, 7, 7]
    AYWCCCGGGHH

Figure out what the range of lengths of amino acid sequences in the
dataset is.

``` python
train_lens = list(map(len, train_sqs))
min(train_lens), max(train_lens)
```

    (19, 1691)

Create a function that drops all sequences above a chosen threshold and
also returns a list of indices of the sequences that meet the threshold
that can be used to obtain the correct labels.

``` python
def drop_long_sqs(sqs, threshold=1200):
    new_sqs = []
    idx = []
    for i, sq in enumerate(sqs):
        if len(sq) <= threshold:
            new_sqs.append(sq)
            idx.append(i)
    return new_sqs, idx
```

Drop all sequences above your chosen threshold.

``` python
trnsqs, trnidx = drop_long_sqs(train_sqs, threshold=200)
vldsqs, vldidx = drop_long_sqs(valid_sqs, threshold=200)
```

``` python
len(trnidx), len(vldidx)
```

    (18066, 1971)

``` python
max(map(len, trnsqs))
```

    200

Create a function for zero padding all sequences.

``` python
def zero_pad(sq, length=1200):
    new_sq = sq.copy()
    if len(new_sq) < length:
        new_sq.extend([0] * (length-len(new_sq)))
    return new_sq
```

Now encode and zero pad all sequences and make sure that it worked out
correctly.

``` python
trn = list(map(encode, trnsqs))
vld = list(map(encode, vldsqs))
print(f"Length of the first two sequences before zero padding: {len(trn[0])}, {len(trn[1])}")
trn = list(map(partial(zero_pad, length=200), trn))
vld = list(map(partial(zero_pad, length=200), vld))
print(f"Length of the first two sequences after zero padding:  {len(trn[0])}, {len(trn[1])}");
```

    Length of the first two sequences before zero padding: 116, 135
    Length of the first two sequences after zero padding:  200, 200

Convert the data to `torch.tensor`s unsing `dtype=torch.int64` and check
for correctness.

``` python
trntns = torch.tensor(trn, dtype=torch.int64)
vldtns = torch.tensor(vld, dtype=torch.int64)
trntns.shape, trntns[0]
```

    (torch.Size([18066, 200]),
     tensor([11,  9,  1, 10,  2, 10, 10, 10, 10, 13, 18, 10,  6, 10, 10, 18, 16, 16,  9, 17, 10,  2, 16, 11,  4,  4,  1,  8, 12,  4, 15,  8, 14,
              4, 18,  1,  6, 16, 10,  8,  5, 15,  1,  8, 16, 16,  8,  6, 10,  4,  2, 14, 16, 18, 17, 16, 15,  6,  3, 10,  1, 17,  2, 13, 15,  6,
              5,  1, 18, 17,  6,  2, 17,  2,  6, 16,  1,  2,  6, 16, 19,  3, 18, 15,  1,  4, 17, 17,  2,  7,  2, 14,  2,  1,  6, 11,  3, 19, 17,
              6,  1, 15,  2,  2, 15, 18, 14, 13, 10,  4,  7,  7,  7,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0]))

``` python
trntns.shape, vldtns.shape
```

    (torch.Size([18066, 200]), torch.Size([1971, 200]))

Obtain the correct labels using the lists of indices obtained from the
`drop_long_sqs` function and convert the lists of labels to tensors in
`torch.float32` format.

``` python
trnlbs = torch.tensor(train_tgs, dtype=torch.float32)[trnidx]
vldlbs = torch.tensor(valid_tgs, dtype=torch.float32)[vldidx]
trnlbs.shape, vldlbs.shape
```

    (torch.Size([18066]), torch.Size([1971]))

Calculate the ratios of soluble peptides in the train and valid data.

``` python
trnlbs.sum().item()/trnlbs.shape[0], vldlbs.sum().item()/vldlbs.shape[0]
```

    (0.4722129967895494, 0.4657534246575342)

These ratios tell us that there are slightly less than half soluble
proteins in the training an validation data, and slightly more than half
in the test set.

### Dataset and DataLoaders

Turn train and valid data into datasets using the
[`Dataset`](https://frenio.github.io/atai/core.html#dataset) class.

``` python
trnds = Dataset(trntns, trnlbs)
vldds = Dataset(vldtns, vldlbs)
trnds[0]
```

    (tensor([11,  9,  1, 10,  2, 10, 10, 10, 10, 13, 18, 10,  6, 10, 10, 18, 16, 16,  9, 17, 10,  2, 16, 11,  4,  4,  1,  8, 12,  4, 15,  8, 14,
              4, 18,  1,  6, 16, 10,  8,  5, 15,  1,  8, 16, 16,  8,  6, 10,  4,  2, 14, 16, 18, 17, 16, 15,  6,  3, 10,  1, 17,  2, 13, 15,  6,
              5,  1, 18, 17,  6,  2, 17,  2,  6, 16,  1,  2,  6, 16, 19,  3, 18, 15,  1,  4, 17, 17,  2,  7,  2, 14,  2,  1,  6, 11,  3, 19, 17,
              6,  1, 15,  2,  2, 15, 18, 14, 13, 10,  4,  7,  7,  7,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0]),
     tensor(0.))

Use the [`get_dls`](https://frenio.github.io/atai/core.html#get_dls)
function to obtain the dataloaders from the train and valid datasets.

``` python
dls = get_dls(trnds, vldds, bs=32)
next(iter(dls.train))[0][:2], next(iter(dls.train))[1][:2]
```

    (tensor([[11,  1,  7,  7,  7,  7,  7,  7, 11, 11, 13, 15, 16,  4, 12, 14,  9,  4,  4,  4,  4, 19,  4, 10,  7, 13, 10,  5, 10, 16,  9,  8, 13,
              12,  9,  9,  3,  8,  3,  9, 12, 13,  1, 10, 16,  1, 10,  8, 17, 10,  8, 12,  4,  4,  4,  4,  9,  4, 18,  5, 16, 20,  4, 13, 15, 15,
               9, 12, 10,  9,  9,  9,  8, 17,  4,  9,  6, 14,  8,  8, 20,  9,  9,  3, 15, 15, 12, 10, 20,  4, 13, 20, 14, 14, 12,  9, 12,  3,  9,
               8, 12, 20,  5,  4,  9,  9, 12,  5,  6,  6, 12,  1,  3,  8, 16,  9,  4,  4,  4, 18, 10,  3, 18,  4, 11,  3,  4,  4,  6, 17, 17, 18,
              17, 17,  1,  4, 14,  6,  6,  3,  7, 16, 16, 14, 12, 18, 16, 12, 12, 14,  4,  1, 17,  3, 14, 17, 16,  8,  6,  4, 18, 10, 18,  2, 10,
              16, 11, 11, 12, 10,  9, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
               0,  0],
             [11, 10,  1, 11, 18, 12,  9, 18, 10,  3, 19,  8, 15, 16, 10,  5, 19,  9,  4,  4, 11,  4, 10, 17, 10, 18,  6, 10, 14, 12, 16,  6,  9,
              17, 17,  5, 18, 12, 18,  8,  1, 16,  6, 14,  5, 17,  4,  3, 11,  8, 13, 17, 18,  6,  5, 12, 11, 15,  9,  8, 17,  9,  6, 12, 18, 17,
               8,  9, 10, 19,  3,  8,  6,  6, 14, 13, 15,  5, 15, 16, 11, 19,  4, 15, 20,  2, 15,  6, 18, 12,  1,  8, 18,  5, 11, 18,  3,  1,  1,
               3,  4,  4,  9, 10,  4,  1, 16, 15, 12,  4, 10, 11, 14, 10, 10,  3,  9, 13, 14, 10,  3,  1,  8, 13, 18, 10, 18, 10,  6, 12,  9,  9,
               3, 10, 13,  6,  1, 10,  3,  4, 15, 14, 10,  8,  4, 15, 11, 12, 10, 16, 16,  8, 14, 12, 15,  4,  8,  2,  2, 20, 16,  8, 16,  2,  9,
               4,  9,  4, 12,  8,  3,  8, 17, 10, 14, 19, 10,  8,  3,  7, 16,  9,  1, 14, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
               0,  0]]),
     tensor([1., 0.]))

### Design Your Model

Let’s create a tiny model (~10k parameters) that uses a sequence of
1-dimensional convolutional layers with skip connections, kaiming he
initialization, leaky relus, batchnorm, and dropout.

First, obtain a single batch from `dls` to help design the model.

``` python
idx = next(iter(dls.train))[0] ## a single batch
idx, idx.shape
```

    (tensor([[11,  9, 17,  ...,  0,  0,  0],
             [16, 12,  1,  ...,  0,  0,  0],
             [16,  1, 14,  ...,  0,  0,  0],
             ...,
             [11,  1,  7,  ...,  0,  0,  0],
             [11,  3, 13,  ...,  0,  0,  0],
             [16, 17, 12,  ...,  0,  0,  0]]),
     torch.Size([32, 200]))

#### Custom Modules

``` python
def conv1d(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None):
    if bias is None: bias = not isinstance(norm, (nn.BatchNorm1d,nn.BatchNorm2d,nn.BatchNorm3d))
    layers = [nn.Conv1d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias)]
    if norm: layers.append(norm(nf))
    if act: layers.append(act())
    return nn.Sequential(*layers)

def _conv1d_block(ni, nf, stride, act=nn.ReLU, norm=None, ks=3):
    return nn.Sequential(conv1d(ni, nf, stride=1, act=act, norm=norm, ks=ks),
                         conv1d(nf, nf, stride=stride, act=None, norm=norm, ks=ks))

class ResBlock1d(nn.Module):
    def __init__(self, ni, nf, stride=1, ks=3, act=nn.ReLU, norm=None):
        super().__init__()
        self.convs = _conv1d_block(ni, nf, stride=stride, ks=ks, act=act, norm=norm)
        self.idconv = fc.noop if ni==nf else conv1d(ni, nf, stride=1, ks=1, act=None)
        self.pool = fc.noop if stride==1 else nn.AvgPool1d(stride, ceil_mode=True)
        self.act = act()

    def forward(self, x): return self.act(self.convs(x) + self.pool(self.idconv(x)))
```

The following module switches the rank order from BLC to BCL.

``` python
class Reshape(nn.Module):
    def forward(self, x): 
        B, L, C = x.shape
        return x.view(B, C, L)
```

#### Model Architecture

``` python
lr = 1e-2
epochs = 30
n_embd = 16
dls = get_dls(trnds, vldds, bs=32)
act_genrelu = partial(GeneralRelu, leak=0.1, sub=0.4)

model = nn.Sequential(nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape(),
                      ResBlock1d(n_embd, 2, ks=15, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(2, 4, ks=13, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=11, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=9, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 8, ks=7, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 8, ks=5, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 16, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(16, 32, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      nn.Flatten(1, -1),
                      nn.Linear(32, 1),
                      nn.Flatten(0, -1),
                      nn.Sigmoid())
model(idx).shape
```

    torch.Size([32])

``` python
iw = partial(init_weights, leaky=0.1)
model = model.apply(iw)
metrics = MetricsCB(BinaryAccuracy(), BinaryMatthewsCorrCoef(), BinaryAUROC())
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), ProgressCB(plot=False), metrics, astats]
learn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
learn.lr_find(start_lr=1e-4, gamma=1.05, av_over=3, max_mult=5)
```

    Parameters total: 10175

<style>
    /* Turns off some styling */
    progress {
        /* gets rid of default border in Firefox and Opera. */
        border: none;
        /* Needs to be in here for Safari polyfill so background images work as expected. */
        background-size: auto;
    }
    progress:not([value]), progress:not([value])::-webkit-progress-bar {
        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
    }
    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
        background: #F44336;
    }
</style>

    <div>
      <progress value='0' class='' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>
      0.00% [0/10 00:00&lt;?]
    </div>
    &#10;
&#10;    <div>
      <progress value='502' class='' max='565' style='width:300px; height:20px; vertical-align: middle;'></progress>
      88.85% [502/565 00:32&lt;00:04 0.919]
    </div>
    &#10;
![](index_files/figure-commonmark/cell-25-output-4.png)

This is a pretty noisy training set, so the learning rate finder does
not work very well. Yet it is possible to get a somewhat informative
result using the `av_over` keyword argument that tells
[`lr_find`](https://frenio.github.io/atai/core.html#lr_find) to average
over the specified number of batches for each learning rate tested. It
also helps to dial the `gamma` value down from its default value of
`1.3`.

### Training

``` python
learn.fit(epochs)
```

<style>
    /* Turns off some styling */
    progress {
        /* gets rid of default border in Firefox and Opera. */
        border: none;
        /* Needs to be in here for Safari polyfill so background images work as expected. */
        background-size: auto;
    }
    progress:not([value]), progress:not([value])::-webkit-progress-bar {
        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
    }
    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
        background: #F44336;
    }
</style>

<div>

| BinaryAccuracy | BinaryMatthewsCorrCoef | BinaryAUROC | loss  | epoch | train |
|----------------|------------------------|-------------|-------|-------|-------|
| 0.510          | 0.008                  | 0.507       | 0.722 | 0     | train |
| 0.527          | -0.033                 | 0.494       | 0.691 | 0     | eval  |
| 0.515          | 0.004                  | 0.506       | 0.695 | 1     | train |
| 0.534          | 0.003                  | 0.520       | 0.691 | 1     | eval  |
| 0.537          | 0.054                  | 0.526       | 0.692 | 2     | train |
| 0.530          | 0.058                  | 0.550       | 0.693 | 2     | eval  |
| 0.553          | 0.090                  | 0.551       | 0.688 | 3     | train |
| 0.562          | 0.108                  | 0.575       | 0.684 | 3     | eval  |
| 0.589          | 0.170                  | 0.607       | 0.671 | 4     | train |
| 0.619          | 0.234                  | 0.663       | 0.647 | 4     | eval  |
| 0.622          | 0.247                  | 0.630       | 0.653 | 5     | train |
| 0.643          | 0.302                  | 0.676       | 0.637 | 5     | eval  |
| 0.627          | 0.260                  | 0.642       | 0.650 | 6     | train |
| 0.649          | 0.313                  | 0.675       | 0.628 | 6     | eval  |
| 0.633          | 0.272                  | 0.647       | 0.644 | 7     | train |
| 0.650          | 0.309                  | 0.648       | 0.637 | 7     | eval  |
| 0.634          | 0.276                  | 0.652       | 0.640 | 8     | train |
| 0.650          | 0.314                  | 0.683       | 0.620 | 8     | eval  |
| 0.637          | 0.284                  | 0.646       | 0.641 | 9     | train |
| 0.651          | 0.320                  | 0.685       | 0.617 | 9     | eval  |
| 0.639          | 0.287                  | 0.661       | 0.636 | 10    | train |
| 0.645          | 0.308                  | 0.671       | 0.647 | 10    | eval  |
| 0.638          | 0.281                  | 0.663       | 0.636 | 11    | train |
| 0.646          | 0.298                  | 0.690       | 0.620 | 11    | eval  |
| 0.641          | 0.286                  | 0.670       | 0.632 | 12    | train |
| 0.648          | 0.296                  | 0.685       | 0.619 | 12    | eval  |
| 0.641          | 0.290                  | 0.668       | 0.632 | 13    | train |
| 0.662          | 0.323                  | 0.691       | 0.618 | 13    | eval  |
| 0.643          | 0.290                  | 0.675       | 0.630 | 14    | train |
| 0.641          | 0.286                  | 0.677       | 0.627 | 14    | eval  |
| 0.643          | 0.289                  | 0.676       | 0.630 | 15    | train |
| 0.659          | 0.311                  | 0.699       | 0.616 | 15    | eval  |
| 0.644          | 0.291                  | 0.677       | 0.629 | 16    | train |
| 0.652          | 0.314                  | 0.690       | 0.613 | 16    | eval  |
| 0.646          | 0.296                  | 0.675       | 0.626 | 17    | train |
| 0.645          | 0.282                  | 0.694       | 0.622 | 17    | eval  |
| 0.642          | 0.288                  | 0.678       | 0.626 | 18    | train |
| 0.648          | 0.332                  | 0.677       | 0.639 | 18    | eval  |
| 0.645          | 0.292                  | 0.685       | 0.625 | 19    | train |
| 0.634          | 0.260                  | 0.698       | 0.615 | 19    | eval  |
| 0.649          | 0.302                  | 0.689       | 0.621 | 20    | train |
| 0.651          | 0.344                  | 0.710       | 0.617 | 20    | eval  |
| 0.648          | 0.299                  | 0.685       | 0.624 | 21    | train |
| 0.660          | 0.315                  | 0.700       | 0.614 | 21    | eval  |
| 0.648          | 0.297                  | 0.691       | 0.620 | 22    | train |
| 0.563          | 0.168                  | 0.679       | 0.672 | 22    | eval  |
| 0.651          | 0.303                  | 0.690       | 0.620 | 23    | train |
| 0.654          | 0.330                  | 0.710       | 0.611 | 23    | eval  |
| 0.650          | 0.304                  | 0.691       | 0.620 | 24    | train |
| 0.668          | 0.344                  | 0.711       | 0.599 | 24    | eval  |
| 0.654          | 0.311                  | 0.692       | 0.617 | 25    | train |
| 0.649          | 0.294                  | 0.698       | 0.620 | 25    | eval  |
| 0.650          | 0.301                  | 0.690       | 0.617 | 26    | train |
| 0.642          | 0.320                  | 0.697       | 0.611 | 26    | eval  |
| 0.650          | 0.303                  | 0.688       | 0.620 | 27    | train |
| 0.663          | 0.334                  | 0.708       | 0.625 | 27    | eval  |
| 0.652          | 0.308                  | 0.694       | 0.616 | 28    | train |
| 0.672          | 0.356                  | 0.711       | 0.597 | 28    | eval  |
| 0.651          | 0.304                  | 0.695       | 0.617 | 29    | train |
| 0.659          | 0.322                  | 0.702       | 0.603 | 29    | eval  |

</div>

### Inspect Activations

``` python
dls = get_dls(trnds, vldds, bs=256)

model = nn.Sequential(nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape(),
                      ResBlock1d(n_embd, 2, ks=15, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(2, 4, ks=13, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=11, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 4, ks=9, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(4, 8, ks=7, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 8, ks=5, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(8, 16, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      ResBlock1d(16, 32, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),
                      nn.Flatten(1, -1),
                      nn.Linear(32, 1),
                      nn.Flatten(0, -1),
                      nn.Sigmoid())

model = model.apply(iw)
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), astats]
learn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)
print(f"Parameters total: {sum(p.nelement() for p in model.parameters())}")
learn.fit(1)
```

    Parameters total: 10175

``` python
astats.color_dim()
```

![](index_files/figure-commonmark/cell-28-output-1.png)

``` python
astats.plot_stats()
```

![](index_files/figure-commonmark/cell-29-output-1.png)

``` python
astats.dead_chart()
```

![](index_files/figure-commonmark/cell-30-output-1.png)

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/frenio/atai",
    "name": "atai",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.7",
    "maintainer_email": null,
    "keywords": "nbdev jupyter notebook python",
    "author": "Frenio Redeker",
    "author_email": "f.redeker@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/10/30/b6fb53713292c22a347dcc8a38d242f3a028f8c64ffe16f9e61c573cc46a/atai-0.0.6.tar.gz",
    "platform": null,
    "description": "# atai\n\n\n<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->\n\nAtomic AI is a flexible, minimalist deep neural network training\nframework based on Jeremy Howard\u2019s\n[miniai](https://github.com/fastai/course22p2/tree/master) from the\n[fast.ai 2022](https://course.fast.ai/Lessons/part2.html) course.\n\n## Install\n\n``` sh\npip install atai\n```\n\n## How to use\n\n``` python\nfrom atai.core import *\n```\n\nThe following example demonstrates how the Atomic AI training framework\ncan be used to train a custom model that predicts protein solubility.\n\n### Imports\n\n``` python\nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom torch.nn import init\nfrom torch import optim\n\nfrom torcheval.metrics import BinaryAccuracy, BinaryAUROC\nfrom torcheval.metrics.functional import binary_auroc, binary_accuracy\nfrom torchmetrics.classification import BinaryMatthewsCorrCoef\nfrom torchmetrics.functional.classification import binary_matthews_corrcoef\n\nimport fastcore.all as fc\nfrom functools import partial\n```\n\n### Load Protein Solubility\n\nThis example uses the dataset from the\n[DeepSol](https://doi.org/10.1093/bioinformatics/bty166) paper by\nKhurana *et al.* which was obtained at\n<https://zenodo.org/records/1162886>. It consists of amino acid\nsequences of peptides along with solubility labels that are `1` if the\npeptide is soluble and `0` if the peptide is insoluble.\n\n``` python\ntrain_sqs = open('sol_data/train_src', 'r').read().splitlines()\ntrain_tgs = list(map(int, open('sol_data/train_tgt', 'r').read().splitlines()))\n\nvalid_sqs = open('sol_data/val_src', 'r').read().splitlines()\nvalid_tgs = list(map(int, open('sol_data/val_tgt', 'r').read().splitlines()))\n\ntrain_sqs[:2], train_tgs[:2]\n```\n\n    (['GMILKTNLFGHTYQFKSITDVLAKANEEKSGDRLAGVAAESAEERVAAKVVLSKMTLGDLRNNPVVPYETDEVTRIIQDQVNDRIHDSIKNWTVEELREWILDHKTTDADIKRVARGLTSEIIAAVTKLMSNLDLIYGAKKIRVIAHANTTIGLPGTFSARLQPNHPTDDPDGILASLMEGLTYGIGDAVIGLNPVDDSTDSVVRLLNKFEEFRSKWDVPTQTCVLAHVKTQMEAMRRGAPTGLVFQSIAGSEKGNTAFGFDGATIEEARQLALQSGAATGPNVMYFETGQGSELSSDAHFGVDQVTMEARCYGFAKKFDPFLVNTVVGFIGPEYLYDSKQVIRAGLEDHFMGKLTGISMGCDVCYTNHMKADQNDVENLSVLLTAAGCNFIMGIPHGDDVMLNYQTTGYHETATLRELFGLKPIKEFDQWMEKMGFSENGKLTSRAGDASIFLK',\n      'MAHHHHHHMSFFRMKRRLNFVVKRGIEELWENSFLDNNVDMKKIEYSKTGDAWPCVLLRKKSFEDLHKLYYICLKEKNKLLGEQYFHLQNSTKMLQHGRLKKVKLTMKRILTVLSRRAIHDQCLRAKDMLKKQEEREFYEIQKFKLNEQLLCLKHKMNILKKYNSFSLEQISLTFSIKKIENKIQQIDIILNPLRKETMYLLIPHFKYQRKYSDLPGFISWKKQNIIALRNNMSKLHRLY'],\n     [1, 0])\n\n``` python\nlen(train_sqs), len(train_tgs), len(valid_sqs), len(valid_tgs)\n```\n\n    (62478, 62478, 6942, 6942)\n\n### Data Preparation\n\nCreate a sorted list of amino acid sequences `aas` including an empty\nstring for padding and determine the size of the vocabulary.\n\n``` python\naas = sorted(list(set(\"\".join(train_sqs))) + [\"\"])\nvocab_size = len(aas)\naas, vocab_size\n```\n\n    (['',\n      'A',\n      'C',\n      'D',\n      'E',\n      'F',\n      'G',\n      'H',\n      'I',\n      'K',\n      'L',\n      'M',\n      'N',\n      'P',\n      'Q',\n      'R',\n      'S',\n      'T',\n      'V',\n      'W',\n      'Y'],\n     21)\n\nCreate dictionaries that translate between string and integer\nrepresentations of amino acids and define the corresponding `encode` and\n`decode` functions.\n\n``` python\nstr2int = {aa:i for i, aa in enumerate(aas)}\nint2str = {i:aa for i, aa in enumerate(aas)}\nencode = lambda s: [str2int[aa] for aa in s]\ndecode = lambda l: ''.join([int2str[i] for i in l])\n\nprint(encode(\"AYWCCCGGGHH\"))\nprint(decode(encode(\"AYWCCCGGGHH\")))\n```\n\n    [1, 20, 19, 2, 2, 2, 6, 6, 6, 7, 7]\n    AYWCCCGGGHH\n\nFigure out what the range of lengths of amino acid sequences in the\ndataset is.\n\n``` python\ntrain_lens = list(map(len, train_sqs))\nmin(train_lens), max(train_lens)\n```\n\n    (19, 1691)\n\nCreate a function that drops all sequences above a chosen threshold and\nalso returns a list of indices of the sequences that meet the threshold\nthat can be used to obtain the correct labels.\n\n``` python\ndef drop_long_sqs(sqs, threshold=1200):\n    new_sqs = []\n    idx = []\n    for i, sq in enumerate(sqs):\n        if len(sq) <= threshold:\n            new_sqs.append(sq)\n            idx.append(i)\n    return new_sqs, idx\n```\n\nDrop all sequences above your chosen threshold.\n\n``` python\ntrnsqs, trnidx = drop_long_sqs(train_sqs, threshold=200)\nvldsqs, vldidx = drop_long_sqs(valid_sqs, threshold=200)\n```\n\n``` python\nlen(trnidx), len(vldidx)\n```\n\n    (18066, 1971)\n\n``` python\nmax(map(len, trnsqs))\n```\n\n    200\n\nCreate a function for zero padding all sequences.\n\n``` python\ndef zero_pad(sq, length=1200):\n    new_sq = sq.copy()\n    if len(new_sq) < length:\n        new_sq.extend([0] * (length-len(new_sq)))\n    return new_sq\n```\n\nNow encode and zero pad all sequences and make sure that it worked out\ncorrectly.\n\n``` python\ntrn = list(map(encode, trnsqs))\nvld = list(map(encode, vldsqs))\nprint(f\"Length of the first two sequences before zero padding: {len(trn[0])}, {len(trn[1])}\")\ntrn = list(map(partial(zero_pad, length=200), trn))\nvld = list(map(partial(zero_pad, length=200), vld))\nprint(f\"Length of the first two sequences after zero padding:  {len(trn[0])}, {len(trn[1])}\");\n```\n\n    Length of the first two sequences before zero padding: 116, 135\n    Length of the first two sequences after zero padding:  200, 200\n\nConvert the data to `torch.tensor`s unsing `dtype=torch.int64` and check\nfor correctness.\n\n``` python\ntrntns = torch.tensor(trn, dtype=torch.int64)\nvldtns = torch.tensor(vld, dtype=torch.int64)\ntrntns.shape, trntns[0]\n```\n\n    (torch.Size([18066, 200]),\n     tensor([11,  9,  1, 10,  2, 10, 10, 10, 10, 13, 18, 10,  6, 10, 10, 18, 16, 16,  9, 17, 10,  2, 16, 11,  4,  4,  1,  8, 12,  4, 15,  8, 14,\n              4, 18,  1,  6, 16, 10,  8,  5, 15,  1,  8, 16, 16,  8,  6, 10,  4,  2, 14, 16, 18, 17, 16, 15,  6,  3, 10,  1, 17,  2, 13, 15,  6,\n              5,  1, 18, 17,  6,  2, 17,  2,  6, 16,  1,  2,  6, 16, 19,  3, 18, 15,  1,  4, 17, 17,  2,  7,  2, 14,  2,  1,  6, 11,  3, 19, 17,\n              6,  1, 15,  2,  2, 15, 18, 14, 13, 10,  4,  7,  7,  7,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n              0,  0]))\n\n``` python\ntrntns.shape, vldtns.shape\n```\n\n    (torch.Size([18066, 200]), torch.Size([1971, 200]))\n\nObtain the correct labels using the lists of indices obtained from the\n`drop_long_sqs` function and convert the lists of labels to tensors in\n`torch.float32` format.\n\n``` python\ntrnlbs = torch.tensor(train_tgs, dtype=torch.float32)[trnidx]\nvldlbs = torch.tensor(valid_tgs, dtype=torch.float32)[vldidx]\ntrnlbs.shape, vldlbs.shape\n```\n\n    (torch.Size([18066]), torch.Size([1971]))\n\nCalculate the ratios of soluble peptides in the train and valid data.\n\n``` python\ntrnlbs.sum().item()/trnlbs.shape[0], vldlbs.sum().item()/vldlbs.shape[0]\n```\n\n    (0.4722129967895494, 0.4657534246575342)\n\nThese ratios tell us that there are slightly less than half soluble\nproteins in the training an validation data, and slightly more than half\nin the test set.\n\n### Dataset and DataLoaders\n\nTurn train and valid data into datasets using the\n[`Dataset`](https://frenio.github.io/atai/core.html#dataset) class.\n\n``` python\ntrnds = Dataset(trntns, trnlbs)\nvldds = Dataset(vldtns, vldlbs)\ntrnds[0]\n```\n\n    (tensor([11,  9,  1, 10,  2, 10, 10, 10, 10, 13, 18, 10,  6, 10, 10, 18, 16, 16,  9, 17, 10,  2, 16, 11,  4,  4,  1,  8, 12,  4, 15,  8, 14,\n              4, 18,  1,  6, 16, 10,  8,  5, 15,  1,  8, 16, 16,  8,  6, 10,  4,  2, 14, 16, 18, 17, 16, 15,  6,  3, 10,  1, 17,  2, 13, 15,  6,\n              5,  1, 18, 17,  6,  2, 17,  2,  6, 16,  1,  2,  6, 16, 19,  3, 18, 15,  1,  4, 17, 17,  2,  7,  2, 14,  2,  1,  6, 11,  3, 19, 17,\n              6,  1, 15,  2,  2, 15, 18, 14, 13, 10,  4,  7,  7,  7,  7,  7,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n              0,  0]),\n     tensor(0.))\n\nUse the [`get_dls`](https://frenio.github.io/atai/core.html#get_dls)\nfunction to obtain the dataloaders from the train and valid datasets.\n\n``` python\ndls = get_dls(trnds, vldds, bs=32)\nnext(iter(dls.train))[0][:2], next(iter(dls.train))[1][:2]\n```\n\n    (tensor([[11,  1,  7,  7,  7,  7,  7,  7, 11, 11, 13, 15, 16,  4, 12, 14,  9,  4,  4,  4,  4, 19,  4, 10,  7, 13, 10,  5, 10, 16,  9,  8, 13,\n              12,  9,  9,  3,  8,  3,  9, 12, 13,  1, 10, 16,  1, 10,  8, 17, 10,  8, 12,  4,  4,  4,  4,  9,  4, 18,  5, 16, 20,  4, 13, 15, 15,\n               9, 12, 10,  9,  9,  9,  8, 17,  4,  9,  6, 14,  8,  8, 20,  9,  9,  3, 15, 15, 12, 10, 20,  4, 13, 20, 14, 14, 12,  9, 12,  3,  9,\n               8, 12, 20,  5,  4,  9,  9, 12,  5,  6,  6, 12,  1,  3,  8, 16,  9,  4,  4,  4, 18, 10,  3, 18,  4, 11,  3,  4,  4,  6, 17, 17, 18,\n              17, 17,  1,  4, 14,  6,  6,  3,  7, 16, 16, 14, 12, 18, 16, 12, 12, 14,  4,  1, 17,  3, 14, 17, 16,  8,  6,  4, 18, 10, 18,  2, 10,\n              16, 11, 11, 12, 10,  9, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n               0,  0],\n             [11, 10,  1, 11, 18, 12,  9, 18, 10,  3, 19,  8, 15, 16, 10,  5, 19,  9,  4,  4, 11,  4, 10, 17, 10, 18,  6, 10, 14, 12, 16,  6,  9,\n              17, 17,  5, 18, 12, 18,  8,  1, 16,  6, 14,  5, 17,  4,  3, 11,  8, 13, 17, 18,  6,  5, 12, 11, 15,  9,  8, 17,  9,  6, 12, 18, 17,\n               8,  9, 10, 19,  3,  8,  6,  6, 14, 13, 15,  5, 15, 16, 11, 19,  4, 15, 20,  2, 15,  6, 18, 12,  1,  8, 18,  5, 11, 18,  3,  1,  1,\n               3,  4,  4,  9, 10,  4,  1, 16, 15, 12,  4, 10, 11, 14, 10, 10,  3,  9, 13, 14, 10,  3,  1,  8, 13, 18, 10, 18, 10,  6, 12,  9,  9,\n               3, 10, 13,  6,  1, 10,  3,  4, 15, 14, 10,  8,  4, 15, 11, 12, 10, 16, 16,  8, 14, 12, 15,  4,  8,  2,  2, 20, 16,  8, 16,  2,  9,\n               4,  9,  4, 12,  8,  3,  8, 17, 10, 14, 19, 10,  8,  3,  7, 16,  9,  1, 14, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n               0,  0]]),\n     tensor([1., 0.]))\n\n### Design Your Model\n\nLet\u2019s create a tiny model (~10k parameters) that uses a sequence of\n1-dimensional convolutional layers with skip connections, kaiming he\ninitialization, leaky relus, batchnorm, and dropout.\n\nFirst, obtain a single batch from `dls` to help design the model.\n\n``` python\nidx = next(iter(dls.train))[0] ## a single batch\nidx, idx.shape\n```\n\n    (tensor([[11,  9, 17,  ...,  0,  0,  0],\n             [16, 12,  1,  ...,  0,  0,  0],\n             [16,  1, 14,  ...,  0,  0,  0],\n             ...,\n             [11,  1,  7,  ...,  0,  0,  0],\n             [11,  3, 13,  ...,  0,  0,  0],\n             [16, 17, 12,  ...,  0,  0,  0]]),\n     torch.Size([32, 200]))\n\n#### Custom Modules\n\n``` python\ndef conv1d(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None):\n    if bias is None: bias = not isinstance(norm, (nn.BatchNorm1d,nn.BatchNorm2d,nn.BatchNorm3d))\n    layers = [nn.Conv1d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias)]\n    if norm: layers.append(norm(nf))\n    if act: layers.append(act())\n    return nn.Sequential(*layers)\n\ndef _conv1d_block(ni, nf, stride, act=nn.ReLU, norm=None, ks=3):\n    return nn.Sequential(conv1d(ni, nf, stride=1, act=act, norm=norm, ks=ks),\n                         conv1d(nf, nf, stride=stride, act=None, norm=norm, ks=ks))\n\nclass ResBlock1d(nn.Module):\n    def __init__(self, ni, nf, stride=1, ks=3, act=nn.ReLU, norm=None):\n        super().__init__()\n        self.convs = _conv1d_block(ni, nf, stride=stride, ks=ks, act=act, norm=norm)\n        self.idconv = fc.noop if ni==nf else conv1d(ni, nf, stride=1, ks=1, act=None)\n        self.pool = fc.noop if stride==1 else nn.AvgPool1d(stride, ceil_mode=True)\n        self.act = act()\n\n    def forward(self, x): return self.act(self.convs(x) + self.pool(self.idconv(x)))\n```\n\nThe following module switches the rank order from BLC to BCL.\n\n``` python\nclass Reshape(nn.Module):\n    def forward(self, x): \n        B, L, C = x.shape\n        return x.view(B, C, L)\n```\n\n#### Model Architecture\n\n``` python\nlr = 1e-2\nepochs = 30\nn_embd = 16\ndls = get_dls(trnds, vldds, bs=32)\nact_genrelu = partial(GeneralRelu, leak=0.1, sub=0.4)\n\nmodel = nn.Sequential(nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape(),\n                      ResBlock1d(n_embd, 2, ks=15, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(2, 4, ks=13, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(4, 4, ks=11, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(4, 4, ks=9, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(4, 8, ks=7, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(8, 8, ks=5, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(8, 16, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(16, 32, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      nn.Flatten(1, -1),\n                      nn.Linear(32, 1),\n                      nn.Flatten(0, -1),\n                      nn.Sigmoid())\nmodel(idx).shape\n```\n\n    torch.Size([32])\n\n``` python\niw = partial(init_weights, leaky=0.1)\nmodel = model.apply(iw)\nmetrics = MetricsCB(BinaryAccuracy(), BinaryMatthewsCorrCoef(), BinaryAUROC())\nastats = ActivationStats(fc.risinstance(GeneralRelu))\ncbs = [DeviceCB(), ProgressCB(plot=False), metrics, astats]\nlearn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)\nprint(f\"Parameters total: {sum(p.nelement() for p in model.parameters())}\")\nlearn.lr_find(start_lr=1e-4, gamma=1.05, av_over=3, max_mult=5)\n```\n\n    Parameters total: 10175\n\n<style>\n    /* Turns off some styling */\n    progress {\n        /* gets rid of default border in Firefox and Opera. */\n        border: none;\n        /* Needs to be in here for Safari polyfill so background images work as expected. */\n        background-size: auto;\n    }\n    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n    }\n    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n        background: #F44336;\n    }\n</style>\n\n    <div>\n      <progress value='0' class='' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n      0.00% [0/10 00:00&lt;?]\n    </div>\n    &#10;\n&#10;    <div>\n      <progress value='502' class='' max='565' style='width:300px; height:20px; vertical-align: middle;'></progress>\n      88.85% [502/565 00:32&lt;00:04 0.919]\n    </div>\n    &#10;\n![](index_files/figure-commonmark/cell-25-output-4.png)\n\nThis is a pretty noisy training set, so the learning rate finder does\nnot work very well. Yet it is possible to get a somewhat informative\nresult using the `av_over` keyword argument that tells\n[`lr_find`](https://frenio.github.io/atai/core.html#lr_find) to average\nover the specified number of batches for each learning rate tested. It\nalso helps to dial the `gamma` value down from its default value of\n`1.3`.\n\n### Training\n\n``` python\nlearn.fit(epochs)\n```\n\n<style>\n    /* Turns off some styling */\n    progress {\n        /* gets rid of default border in Firefox and Opera. */\n        border: none;\n        /* Needs to be in here for Safari polyfill so background images work as expected. */\n        background-size: auto;\n    }\n    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n    }\n    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n        background: #F44336;\n    }\n</style>\n\n<div>\n\n| BinaryAccuracy | BinaryMatthewsCorrCoef | BinaryAUROC | loss  | epoch | train |\n|----------------|------------------------|-------------|-------|-------|-------|\n| 0.510          | 0.008                  | 0.507       | 0.722 | 0     | train |\n| 0.527          | -0.033                 | 0.494       | 0.691 | 0     | eval  |\n| 0.515          | 0.004                  | 0.506       | 0.695 | 1     | train |\n| 0.534          | 0.003                  | 0.520       | 0.691 | 1     | eval  |\n| 0.537          | 0.054                  | 0.526       | 0.692 | 2     | train |\n| 0.530          | 0.058                  | 0.550       | 0.693 | 2     | eval  |\n| 0.553          | 0.090                  | 0.551       | 0.688 | 3     | train |\n| 0.562          | 0.108                  | 0.575       | 0.684 | 3     | eval  |\n| 0.589          | 0.170                  | 0.607       | 0.671 | 4     | train |\n| 0.619          | 0.234                  | 0.663       | 0.647 | 4     | eval  |\n| 0.622          | 0.247                  | 0.630       | 0.653 | 5     | train |\n| 0.643          | 0.302                  | 0.676       | 0.637 | 5     | eval  |\n| 0.627          | 0.260                  | 0.642       | 0.650 | 6     | train |\n| 0.649          | 0.313                  | 0.675       | 0.628 | 6     | eval  |\n| 0.633          | 0.272                  | 0.647       | 0.644 | 7     | train |\n| 0.650          | 0.309                  | 0.648       | 0.637 | 7     | eval  |\n| 0.634          | 0.276                  | 0.652       | 0.640 | 8     | train |\n| 0.650          | 0.314                  | 0.683       | 0.620 | 8     | eval  |\n| 0.637          | 0.284                  | 0.646       | 0.641 | 9     | train |\n| 0.651          | 0.320                  | 0.685       | 0.617 | 9     | eval  |\n| 0.639          | 0.287                  | 0.661       | 0.636 | 10    | train |\n| 0.645          | 0.308                  | 0.671       | 0.647 | 10    | eval  |\n| 0.638          | 0.281                  | 0.663       | 0.636 | 11    | train |\n| 0.646          | 0.298                  | 0.690       | 0.620 | 11    | eval  |\n| 0.641          | 0.286                  | 0.670       | 0.632 | 12    | train |\n| 0.648          | 0.296                  | 0.685       | 0.619 | 12    | eval  |\n| 0.641          | 0.290                  | 0.668       | 0.632 | 13    | train |\n| 0.662          | 0.323                  | 0.691       | 0.618 | 13    | eval  |\n| 0.643          | 0.290                  | 0.675       | 0.630 | 14    | train |\n| 0.641          | 0.286                  | 0.677       | 0.627 | 14    | eval  |\n| 0.643          | 0.289                  | 0.676       | 0.630 | 15    | train |\n| 0.659          | 0.311                  | 0.699       | 0.616 | 15    | eval  |\n| 0.644          | 0.291                  | 0.677       | 0.629 | 16    | train |\n| 0.652          | 0.314                  | 0.690       | 0.613 | 16    | eval  |\n| 0.646          | 0.296                  | 0.675       | 0.626 | 17    | train |\n| 0.645          | 0.282                  | 0.694       | 0.622 | 17    | eval  |\n| 0.642          | 0.288                  | 0.678       | 0.626 | 18    | train |\n| 0.648          | 0.332                  | 0.677       | 0.639 | 18    | eval  |\n| 0.645          | 0.292                  | 0.685       | 0.625 | 19    | train |\n| 0.634          | 0.260                  | 0.698       | 0.615 | 19    | eval  |\n| 0.649          | 0.302                  | 0.689       | 0.621 | 20    | train |\n| 0.651          | 0.344                  | 0.710       | 0.617 | 20    | eval  |\n| 0.648          | 0.299                  | 0.685       | 0.624 | 21    | train |\n| 0.660          | 0.315                  | 0.700       | 0.614 | 21    | eval  |\n| 0.648          | 0.297                  | 0.691       | 0.620 | 22    | train |\n| 0.563          | 0.168                  | 0.679       | 0.672 | 22    | eval  |\n| 0.651          | 0.303                  | 0.690       | 0.620 | 23    | train |\n| 0.654          | 0.330                  | 0.710       | 0.611 | 23    | eval  |\n| 0.650          | 0.304                  | 0.691       | 0.620 | 24    | train |\n| 0.668          | 0.344                  | 0.711       | 0.599 | 24    | eval  |\n| 0.654          | 0.311                  | 0.692       | 0.617 | 25    | train |\n| 0.649          | 0.294                  | 0.698       | 0.620 | 25    | eval  |\n| 0.650          | 0.301                  | 0.690       | 0.617 | 26    | train |\n| 0.642          | 0.320                  | 0.697       | 0.611 | 26    | eval  |\n| 0.650          | 0.303                  | 0.688       | 0.620 | 27    | train |\n| 0.663          | 0.334                  | 0.708       | 0.625 | 27    | eval  |\n| 0.652          | 0.308                  | 0.694       | 0.616 | 28    | train |\n| 0.672          | 0.356                  | 0.711       | 0.597 | 28    | eval  |\n| 0.651          | 0.304                  | 0.695       | 0.617 | 29    | train |\n| 0.659          | 0.322                  | 0.702       | 0.603 | 29    | eval  |\n\n</div>\n\n### Inspect Activations\n\n``` python\ndls = get_dls(trnds, vldds, bs=256)\n\nmodel = nn.Sequential(nn.Embedding(vocab_size, n_embd, padding_idx=0), Reshape(),\n                      ResBlock1d(n_embd, 2, ks=15, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(2, 4, ks=13, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(4, 4, ks=11, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(4, 4, ks=9, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(4, 8, ks=7, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(8, 8, ks=5, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(8, 16, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      ResBlock1d(16, 32, ks=3, stride=2, norm=nn.BatchNorm1d, act=act_genrelu), nn.Dropout(0.1),\n                      nn.Flatten(1, -1),\n                      nn.Linear(32, 1),\n                      nn.Flatten(0, -1),\n                      nn.Sigmoid())\n\nmodel = model.apply(iw)\nastats = ActivationStats(fc.risinstance(GeneralRelu))\ncbs = [DeviceCB(), astats]\nlearn = TrainLearner(model, dls, F.binary_cross_entropy, lr=lr, cbs=cbs, opt_func=torch.optim.AdamW)\nprint(f\"Parameters total: {sum(p.nelement() for p in model.parameters())}\")\nlearn.fit(1)\n```\n\n    Parameters total: 10175\n\n``` python\nastats.color_dim()\n```\n\n![](index_files/figure-commonmark/cell-28-output-1.png)\n\n``` python\nastats.plot_stats()\n```\n\n![](index_files/figure-commonmark/cell-29-output-1.png)\n\n``` python\nastats.dead_chart()\n```\n\n![](index_files/figure-commonmark/cell-30-output-1.png)\n",
    "bugtrack_url": null,
    "license": "Apache Software License 2.0",
    "summary": "Atomic AI \u2013 An attempt at a minimalist, flexible deep learning framework for diverse models.",
    "version": "0.0.6",
    "project_urls": {
        "Homepage": "https://github.com/frenio/atai"
    },
    "split_keywords": [
        "nbdev",
        "jupyter",
        "notebook",
        "python"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "765ac586200963bc35c8c41cdb61bf4de4c2961801ce7c21317958f7f0e4853f",
                "md5": "a98d2bb984de629204bc3869008b9d24",
                "sha256": "dd812447705a6b50462321b59f981d048703f10321bd9d5b431c4ce5bda546ae"
            },
            "downloads": -1,
            "filename": "atai-0.0.6-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "a98d2bb984de629204bc3869008b9d24",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.7",
            "size": 21244,
            "upload_time": "2024-07-22T02:51:52",
            "upload_time_iso_8601": "2024-07-22T02:51:52.548418Z",
            "url": "https://files.pythonhosted.org/packages/76/5a/c586200963bc35c8c41cdb61bf4de4c2961801ce7c21317958f7f0e4853f/atai-0.0.6-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "1030b6fb53713292c22a347dcc8a38d242f3a028f8c64ffe16f9e61c573cc46a",
                "md5": "36b35cae1d6e3bd971127ecfd357f51d",
                "sha256": "c748aae83d33540f8af51cba3db0216a76698402c3b36620b863e799df750410"
            },
            "downloads": -1,
            "filename": "atai-0.0.6.tar.gz",
            "has_sig": false,
            "md5_digest": "36b35cae1d6e3bd971127ecfd357f51d",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.7",
            "size": 28495,
            "upload_time": "2024-07-22T02:51:54",
            "upload_time_iso_8601": "2024-07-22T02:51:54.262543Z",
            "url": "https://files.pythonhosted.org/packages/10/30/b6fb53713292c22a347dcc8a38d242f3a028f8c64ffe16f9e61c573cc46a/atai-0.0.6.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-07-22 02:51:54",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "frenio",
    "github_project": "atai",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "atai"
}
        
Elapsed time: 0.25968s