flax-pilot

Name: flax-pilot
Version: 0.1.11
Home page: https://github.com/NITHISHM2410/flax-pilot
Summary: A Simplistic trainer for Flax
Upload time: 2024-08-05 11:36:53
Maintainer: None
Docs URL: None
Author: Nithish M
Requires Python: >=3.6
License: None
Requirements: jax, flax, orbax-checkpoint, mergedeep, functools, typing, optax, tqdm
Travis-CI: No Travis.
Coveralls test coverage: No coveralls.

# Flax-Pilot

Flax-Pilot aims to simplify writing training loops for Google's Flax framework. I started this project as a newcomer to Flax to deepen my understanding, so it is a beginner's exploration of building
efficient training workflows and will need further expertise to refine and extend. Planned features include training with multiple optimizers, more metric modules, callbacks, and more complex training
loops. Flax-Pilot supports distributed training across multiple devices.

**As of 27 July 2024, the trainer is available as a package: [![PyPI version](https://img.shields.io/pypi/v/flax-pilot.svg)](https://pypi.org/project/flax-pilot/)**

## How to Use?

### 🛠️ Write a flax.linen Module

```python
import flax.linen as nn
class CNN(nn.Module):
    @nn.compact
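    def __call__(self, x, deterministic):
        # x: images of shape (batch, H, W, C); deterministic=True disables dropout (evaluation).
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)  # SAME padding by default
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # halves H and W
        x = x.reshape((x.shape[0], -1))  # flatten to (batch, features)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic=deterministic)(x)
        x = nn.Dense(features=10)(x)  # logits for 10 classes
        return x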
```
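
As a sanity check on the module above, its shape flow and parameter count for the 28x28x1 input used later in this README can be worked out by hand. This sketch assumes Flax's defaults (`padding='SAME'` for `nn.Conv`, `'VALID'` for `nn.avg_pool`, biases enabled in `nn.Conv` and `nn.Dense`). `cnn_shapes_and_params` is a hypothetical helper doing plain arithmetic, not a Flax call; activations and dropout add no parameters.

```python
def cnn_shapes_and_params(h=28, w=28, c=1, filters=32, k=3, hidden=256, classes=10):
    """Hand-computed shapes and parameter counts for the CNN above (assumes Flax defaults)."""
    conv_params = k * k * c * filters + filters      # kernel + bias; SAME padding keeps h x w
    ph, pw = (h - 2) // 2 + 1, (w - 2) // 2 + 1      # 2x2 window, stride 2, VALID padding
    flat = ph * pw * filters                         # features after the reshape
    dense1_params = flat * hidden + hidden           # Dense(256): kernel + bias
    dense2_params = hidden * classes + classes       # Dense(10): kernel + bias
    total = conv_params + dense1_params + dense2_params
    return {"pooled_hw": (ph, pw), "flat": flat, "total_params": total}


print(cnn_shapes_and_params())
# {'pooled_hw': (14, 14), 'flat': 6272, 'total_params': 1608778}
```

Almost all parameters sit in the first `Dense` layer, because flattening the 14x14x32 feature map yields 6272 inputs.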

### 🔧 Define Optimizer, Input Shapes, and Dict of Loss & Metric Trackers
*Loss trackers (**lt**) take a **scalar loss** value and average it throughout training.*<br>
*Metric trackers (**mt**) take **y_true, y_pred**, compute a metric score, and average it throughout training.*<br>
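
The difference between the two tracker kinds can be sketched in plain Python. `RunningMean`, `MetricMean`, and `accuracy` below are hypothetical names, not fpilot's `BasicTrackers` API, and the sketch assumes every update counts equally (a mean of per-step values).

```python
class RunningMean:
    """Hypothetical stand-in for an 'lt' tracker: averages scalar values across updates."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += float(value)
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total, self.count = 0.0, 0


class MetricMean(RunningMean):
    """Hypothetical stand-in for an 'mt' tracker: scores (y_true, y_pred), then averages scores."""

    def __init__(self, metric_fn):
        super().__init__()
        self.metric_fn = metric_fn

    def update(self, y_true, y_pred):
        super().update(self.metric_fn(y_true, y_pred))


def accuracy(y_true, y_pred):
    """Toy metric: fraction of matching labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


lt = RunningMean()
for step_loss in [0.9, 0.7, 0.5]:
    lt.update(step_loss)          # lt receives one scalar per step

mt = MetricMean(accuracy)
mt.update([0, 1, 2, 3], [0, 1, 2, 0])  # mt receives (y_true, y_pred) per step
mt.update([0, 1], [0, 1])
print(lt.result(), mt.result())
```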

```python
import optax as tx

opt = tx.adam(0.0001)  # Adam optimizer, learning rate 1e-4
input_shape = {'x': (1, 28, 28, 1)}  # model input 'x': (batch, height, width, channels)

from fpilot import BasicTrackers as tr

# Create tracker instances.
loss_metric_tracker_dict = {
    'lt': {'loss': tr.Mean()},
    'mt': {'F1': tr.F1Score(threshold=0.6, num_classes=10, average='macro')}
}
```
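
The `F1Score(threshold=0.6, num_classes=10, average='macro')` tracker is configured above, but its internals are not documented here. One plausible reading, shown as a plain-Python sketch: each class score is binarized at `threshold`, per-class F1 is computed against one-hot labels, and `'macro'` takes the unweighted mean over classes. How fpilot applies the threshold (for example, to logits vs. probabilities) and how it scores classes with no positives are assumptions here, and `macro_f1` is a hypothetical helper.

```python
def macro_f1(y_true, y_pred, threshold=0.6, num_classes=10):
    """Hypothetical macro-F1 over one-hot labels and per-class scores."""
    f1_per_class = []
    for c in range(num_classes):
        tp = fp = fn = 0
        for t_row, p_row in zip(y_true, y_pred):
            is_true = t_row[c] == 1
            is_pred = p_row[c] >= threshold  # assumption: score >= threshold means "predicted"
            tp += is_true and is_pred
            fp += (not is_true) and is_pred
            fn += is_true and (not is_pred)
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        f1_per_class.append(2 * tp / denom if denom else 0.0)  # assumption: empty class scores 0
    return sum(f1_per_class) / num_classes  # 'macro': unweighted mean over classes


y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
y_pred = [[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.5, 0.3, 0.2], [0.3, 0.6, 0.1]]
print(macro_f1(y_true, y_pred, threshold=0.6, num_classes=3))  # 4/9, about 0.444
```

Macro averaging weights every class equally, so a single class that is never predicted (class 2 here) pulls the score down regardless of how rare it is.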

### 🧮 Create loss_fn
`loss_fn` must accept the parameters shown in the code below and return two values: a scalar loss and a dict of loss and metric values.<br>

The key names **lt** and **mt** must not be changed anywhere, because the training loops depend on them. Subkey names such as **loss** and **F1** can be changed freely,
but they must match between **loss_metric_tracker_dict** and **loss_metric_value_dict**.<br>
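
Why the subkeys must match can be seen in a minimal sketch of how a training loop might route `loss_metric_value_dict` entries to trackers by name. `update_trackers` and `Recorder` are hypothetical; fpilot's actual loop may differ, but any loop that looks trackers up by subkey fails the same way when a name is missing.

```python
def update_trackers(tracker_dict, value_dict):
    """Hypothetical sketch: route loss_fn outputs to trackers with the same 'lt'/'mt' subkeys."""
    for name, value in value_dict["lt"].items():
        tracker_dict["lt"][name].update(value)           # scalar loss values
    for name, (y_true, y_pred) in value_dict["mt"].items():
        tracker_dict["mt"][name].update(y_true, y_pred)  # metric inputs


class Recorder:
    """Toy tracker that records whatever it receives."""

    def __init__(self):
        self.seen = []

    def update(self, *args):
        self.seen.append(args)


trackers = {"lt": {"loss": Recorder()}, "mt": {"F1": Recorder()}}
update_trackers(trackers, {"lt": {"loss": 0.42}, "mt": {"F1": ([1], [1])}})
print(trackers["lt"]["loss"].seen)  # [(0.42,)]

try:
    update_trackers(trackers, {"lt": {"xent": 0.42}, "mt": {}})  # subkey not in trackers
except KeyError as err:
    print("no tracker named", err)
```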
```python
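import optax as tx

# Only the 1st return value (the scalar loss) is differentiated, w.r.t. the 1st parameter (params);
# the dict is passed through as auxiliary output for the trackers.
def loss_fn(params, apply, sample, deterministic, det_key, step):
    x, y = sample  # y: one-hot labels, since softmax_cross_entropy expects a distribution
    yp = apply(params, x, deterministic=deterministic, rngs={'dropout': det_key})  # logits
    loss = tx.softmax_cross_entropy(logits=yp, labels=y).mean()  # logits first, labels second
    loss_metric_value_dict = {'lt': {'loss': loss}, 'mt': {'F1': (y, yp)}}
    return loss, loss_metric_value_dict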
```
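
For reference, `optax.softmax_cross_entropy` takes `logits` first and `labels` (a probability distribution, typically one-hot) second, and returns one loss per example, which `loss_fn` then averages with `.mean()`. The dependency-free sketch below reproduces that computation for small Python lists; the helper name mirrors optax's only for readability and is not the optax implementation.

```python
import math


def softmax_cross_entropy(logits, labels):
    """Per-example cross-entropy: -sum_c labels[c] * log_softmax(logits)[c]."""
    losses = []
    for z, y in zip(logits, labels):
        m = max(z)                                           # subtract max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        log_probs = [v - log_sum_exp for v in z]             # log softmax
        losses.append(-sum(t * lp for t, lp in zip(y, log_probs)))
    return losses


logits = [[0.0] * 10, [2.0] + [0.0] * 9]
labels = [[1.0] + [0.0] * 9, [1.0] + [0.0] * 9]              # one-hot, class 0
per_example = softmax_cross_entropy(logits, labels)
loss = sum(per_example) / len(per_example)                   # the `.mean()` in loss_fn
print(per_example, loss)  # first entry equals math.log(10) for uniform logits
```

Swapping the two arguments would treat the one-hot labels as logits, which still produces a number but not the intended loss.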

### 🏋️ Create Trainer Instance

```python
from fpilot import Trainer

trainer = Trainer(CNN(), input_shape, opt, loss_fn, loss_metric_tracker_dict)
```
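
The comment in `loss_fn` (only the first return value is differentiated, with respect to the first parameter) matches JAX's `jax.value_and_grad(fn, has_aux=True)` convention, which returns `((value, aux), grads)`. Whether the `Trainer` calls it exactly this way is an assumption. To show the contract without requiring JAX, the sketch below uses a hypothetical finite-difference stand-in, `value_and_grad_has_aux`, on a toy squared-error loss (`toy_loss_fn`, also hypothetical).

```python
def value_and_grad_has_aux(fn, eps=1e-6):
    """Numerical stand-in for jax.value_and_grad(fn, has_aux=True) on a list of float params.

    Only fn's first return value is differentiated, and only w.r.t. fn's first argument;
    the second return value (aux) is passed through untouched.
    """
    def wrapped(params, *args):
        loss, aux = fn(params, *args)
        grads = []
        for i in range(len(params)):
            up = list(params)
            up[i] += eps
            down = list(params)
            down[i] -= eps
            # Central difference on the first return value only.
            grads.append((fn(up, *args)[0] - fn(down, *args)[0]) / (2 * eps))
        return (loss, aux), grads
    return wrapped


def toy_loss_fn(params, x, y):
    """Hypothetical loss_fn-shaped function: returns (scalar loss, dict of tracked values)."""
    w, b = params
    loss = (w * x + b - y) ** 2
    return loss, {"lt": {"loss": loss}}


(loss, aux), grads = value_and_grad_has_aux(toy_loss_fn)([2.0, 1.0], 3.0, 5.0)
print(loss, aux, grads)
```

Real JAX computes exact gradients by automatic differentiation; the finite difference here only illustrates which value is differentiated and what is passed through.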

### 📈 Train the Model & Evaluate
```python
train_ds = ...  # tf.data.Dataset as a NumPy iterator
val_ds = ...  # tf.data.Dataset as a NumPy iterator
epochs = 10  # illustrative value
train_steps, val_steps = 10000, 1000  # steps per epoch
ckpt_path = "/saved/model/model_1"  # If None, no checkpoints are saved during training.

trainer.train(epochs, train_ds, val_ds, train_steps, val_steps, ckpt_path)
```
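
Because `train_steps` and `val_steps` fix the number of batches per epoch, the NumPy iterators passed to `train` presumably need to keep yielding batches across epochs (for a `tf.data.Dataset`, typically via `.repeat()` before `.as_numpy_iterator()`). That requirement is an assumption about how the trainer consumes the iterators. The plain-Python sketch below, built around the hypothetical helper `batches_per_epoch`, shows what happens when a fixed step count is drawn from a finite iterator versus a repeating one.

```python
import itertools


def batches_per_epoch(data_iter, epochs, steps):
    """Hypothetical sketch: count batches consumed when each epoch takes `steps` from one iterator."""
    counts = []
    for _ in range(epochs):
        counts.append(sum(1 for _ in itertools.islice(data_iter, steps)))
    return counts


finite = iter(range(5))                  # like a dataset without .repeat()
repeating = itertools.cycle(range(5))    # like a dataset with .repeat()
print(batches_per_epoch(finite, epochs=3, steps=4))     # [4, 1, 0]
print(batches_per_epoch(repeating, epochs=3, steps=4))  # [4, 4, 4]
```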

## Demo
Review the `examples` folder for training tutorials. The `vae-gan-cfg-using-pretrained` notebook shows how to use
the trainer as an installed Python package, while the other notebooks use the trainer from a git clone of the repository.
For the simplest setup, start with `vae-gan-cfg-using-pretrained`.


            

Raw data

{
    "_id": null,
    "home_page": "https://github.com/NITHISHM2410/flax-pilot",
    "name": "flax-pilot",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.6",
    "maintainer_email": null,
    "keywords": null,
    "author": "Nithish M",
    "author_email": "nithishm2206@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/f7/30/39aee1fab9f9906dc44c4c3635537694ad6240ab7dc9d74ff005bb39032e/flax_pilot-0.1.11.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "A Simplistic trainer for Flax",
    "version": "0.1.11",
    "project_urls": {
        "Homepage": "https://github.com/NITHISHM2410/flax-pilot"
    },
    "split_keywords": [],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "14edf919f48ccad55dd5708fc2e3e5fbce8f0d82192c426ba9014e382f067ff1",
                "md5": "1f836f4394d714208d3ec25c6af1c9e6",
                "sha256": "5208c583f84e297f004c9bf80f1cda3c24ebf4143f585b27f9cdd91066a384ef"
            },
            "downloads": -1,
            "filename": "flax_pilot-0.1.11-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "1f836f4394d714208d3ec25c6af1c9e6",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.6",
            "size": 12728,
            "upload_time": "2024-08-05T11:36:51",
            "upload_time_iso_8601": "2024-08-05T11:36:51.944805Z",
            "url": "https://files.pythonhosted.org/packages/14/ed/f919f48ccad55dd5708fc2e3e5fbce8f0d82192c426ba9014e382f067ff1/flax_pilot-0.1.11-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "f73039aee1fab9f9906dc44c4c3635537694ad6240ab7dc9d74ff005bb39032e",
                "md5": "989750612103649c8f4c81a6564112f3",
                "sha256": "2d4021cbd13919e54dfcaf7694184f9d4a90a10d8bef74f0a504f575ef95a89d"
            },
            "downloads": -1,
            "filename": "flax_pilot-0.1.11.tar.gz",
            "has_sig": false,
            "md5_digest": "989750612103649c8f4c81a6564112f3",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.6",
            "size": 10997,
            "upload_time": "2024-08-05T11:36:53",
            "upload_time_iso_8601": "2024-08-05T11:36:53.700589Z",
            "url": "https://files.pythonhosted.org/packages/f7/30/39aee1fab9f9906dc44c4c3635537694ad6240ab7dc9d74ff005bb39032e/flax_pilot-0.1.11.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-08-05 11:36:53",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "NITHISHM2410",
    "github_project": "flax-pilot",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "requirements": [
        {
            "name": "jax",
            "specs": [
                [
                    ">=",
                    "0.4.26"
                ]
            ]
        },
        {
            "name": "flax",
            "specs": [
                [
                    ">=",
                    "0.8.4"
                ]
            ]
        },
        {
            "name": "orbax-checkpoint",
            "specs": [
                [
                    ">=",
                    "0.5.15"
                ]
            ]
        },
        {
            "name": "mergedeep",
            "specs": []
        },
        {
            "name": "functools",
            "specs": []
        },
        {
            "name": "typing",
            "specs": []
        },
        {
            "name": "optax",
            "specs": []
        },
        {
            "name": "tqdm",
            "specs": []
        }
    ],
    "lcname": "flax-pilot"
}
        