nanodl


| Field | Value |
| --- | --- |
| Name | nanodl |
| Version | 1.2.1.dev1 |
| home_page | https://github.com/hmunachi/nanodl |
| Summary | A Jax-based library for designing and training transformer models from scratch. |
| upload_time | 2024-03-12 09:51:11 |
| maintainer | |
| docs_url | None |
| author | Henry Ndubuaku |
| requires_python | >=3.7 |
| license | |
| keywords | transformers jax machine learning deep learning pytorch tensorflow |
| VCS | |
| bugtrack_url | |
| requirements | jax jaxlib flax optax einops |
| Travis-CI | No Travis. |
| coveralls test coverage | No coveralls. |

<p align="center">
  <img src="assets/logo.jpg" alt="Alt text"/>
</p>

# A Jax-based library for designing and training transformer models from scratch.

![License](https://img.shields.io/github/license/hmunachi/nanodl?style=flat-square) ![Stars](https://img.shields.io/github/stars/hmunachi/nanodl?style=social) ![Forks](https://img.shields.io/github/forks/hmunachi/nanodl?style=social) ![Issues](https://img.shields.io/github/issues/hmunachi/nanodl?style=flat-square) [![LinkedIn](https://img.shields.io/badge/-LinkedIn-blue?style=flat-square&logo=linkedin&logoColor=white)](https://www.linkedin.com//company/80434055) [![Twitter](https://img.shields.io/twitter/follow/hmunachii?style=social)](https://twitter.com/hmunachii)




[**Overview**](#overview)
| [**Quick install**](#quick-install)
| [**What does NanoDL look like?**](#what-does-nanodl-look-like)
| [**Documentation**](https://nanodl.readthedocs.io/)

Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/)

## Overview
Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:

- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers, including RLHF, so developers can efficiently train large-scale models on multiple GPUs or TPUs without the need for manual training loops.
- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and Swin attention, allowing for more flexible model development.
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to scikit-learn on GPU.
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
- Random number helpers in Jax that seed from the current timestamp by default, so they do not need the verbose key-handling code.
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc.
- Each model is contained in a single file with no external dependencies, so the source code can also be easily used. 
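To give a flavour of one of the custom layers listed above, here is a minimal sketch of rotary position embeddings (RoPE) in plain `jax.numpy`. It is an illustration only and not NanoDL's implementation: the function name `rotary_embedding`, the `base` default, and the split-in-half pairing of features are assumptions made for this example.

```python
import jax.numpy as jnp


def rotary_embedding(x, base=10000.0):
    """Hypothetical helper: apply RoPE to x of shape (seq_len, dim), dim even."""
    seq_len, dim = x.shape
    half = dim // 2
    # One rotation frequency per pair of features.
    freqs = 1.0 / (base ** (jnp.arange(half) / half))
    # Angle for every (position, frequency) combination.
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


x = jnp.ones((4, 8))
rotated = rotary_embedding(x)
print(rotated.shape)
```

Because each feature pair is only rotated, position 0 is left unchanged and the norm of every row is preserved.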

Feedback on any of our discussion, issue and pull request threads is welcome! Please report any feature requests, issues, questions or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at ndubuakuhenry@gmail.com.

## What's New in version 1.2.1.dev1

- Google's Gemma architecture.
- Reward model wrapper and data-parallel distributed reward trainer.
- Random number helpers in Jax that seed from the current timestamp by default, so they do not need the verbose key-handling code (examples shown in next sections).

There are experimental features (like MAMBA architecture and RLHF) in the repo which are not available via the package, pending tests.

## Quick install

You will need Python 3.9 or later and working installations of [JAX](https://github.com/google/jax/blob/main/README.md),
[FLAX](https://github.com/google/flax/blob/main/README.md), and
[OPTAX](https://github.com/google-deepmind/optax/blob/main/README.md).
GPU support is needed to run training; without it, you can only create and test models.
Models can be designed and tested on CPUs, but the trainers are all Distributed Data-Parallel and require 1 to N GPUs/TPUs. For the CPU-only version of JAX:

```
pip install --upgrade pip # To support manylinux2010 wheels.
pip install jax flax optax
```

Then, install nanodl from PyPI:

```
pip install nanodl
```
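Since the trainers are data-parallel, it can be useful to confirm which backend and devices JAX actually sees before training. The snippet below is a small sketch using only standard JAX calls; the variable names are illustrative.

```python
import jax

backend = jax.default_backend()          # platform name of the default backend
devices = jax.devices()                  # all devices visible to JAX
num_local = jax.local_device_count()     # devices attached to this host

print(backend, num_local)
print(devices)
```

On a CPU-only installation this should report a single CPU device, which is enough to design and test models but not to exercise multi-device training.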

## What does nanodl look like?

We provide various example usages of the nanodl API.

```py
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with actual tokenised data
data = jnp.ones((101, max_length), dtype=jnp.int32)

# Shift to create next-token prediction dataset
dummy_inputs = data[:, :-1]
dummy_targets = data[:, 1:]

# Create dataset and dataloader
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, 
                        batch_size=batch_size, 
                        shuffle=True, 
                        drop_last=False)

# How to loop through dataloader
for batch in dataloader:
    x, y = batch
    print(x.shape, y.shape)
    break

# model parameters
hyperparams = {
    'num_layers': 1,
    'hidden_dim': 256,
    'num_heads': 2,
    'feedforward_dim': 256,
    'dropout': 0.1,
    'vocab_size': 1000,
    'embed_dim': 256,
    'max_length': max_length,
    'start_token': 0,
    'end_token': 50,
}

# Initialize model
model = GPT4(**hyperparams)
rngs = jax.random.PRNGKey(0)
rngs, dropout_rng = jax.random.split(rngs)
params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']

# Call as you would a Jax/Flax model
outputs = model.apply({'params': params}, 
                      dummy_inputs, 
                      rngs={'dropout': dropout_rng})
print(outputs.shape)

# Training on data
trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
trainer.train(train_loader=dataloader, 
              num_epochs=2, 
              val_loader=dataloader)

print(trainer.evaluate(dataloader))

# Generating from a start token
start_tokens = jnp.array([[123, 456]])

# Remember to load the trained parameters 
params = trainer.load_params('params.pkl')
outputs = model.apply({'params': params},
                      start_tokens,
                      rngs={'dropout': jax.random.PRNGKey(2)}, 
                      method=model.generate)
print(outputs) 
```
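The shift used in the example above to build a next-token prediction dataset can be seen on a tiny sequence. This is a sketch with made-up token ids: each input position is paired with the token that follows it, so both arrays are one token shorter than the original sequence.

```python
import jax.numpy as jnp

# Made-up token ids for illustration only.
tokens = jnp.array([[5, 6, 7, 8]])

inputs = tokens[:, :-1]    # every token except the last
targets = tokens[:, 1:]    # every token except the first

print(inputs.tolist(), targets.tolist())
```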

Vision example

```py
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
key = jax.random.PRNGKey(0)
input_shape = (101, image_size, image_size, 3)
images = jax.random.normal(key, input_shape)

# Use your own images
dataset = ArrayDataset(images) 
dataloader = DataLoader(dataset, 
                        batch_size=batch_size, 
                        shuffle=True, 
                        drop_last=False) 

# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
params = diffusion_model.init(key, images)
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)

# Training on your data
# Note: saved params are often different from training weights, use the saved params for generation
trainer = DiffusionDataParallelTrainer(diffusion_model, 
                                       input_shape=images.shape, 
                                       weights_filename='params.pkl', 
                                       learning_rate=1e-4)
trainer.train(dataloader, 10, dataloader)
print(trainer.evaluate(dataloader))

# Generate some samples
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params}, 
                                         num_images=5, 
                                         diffusion_steps=5, 
                                         method=diffusion_model.generate)
print(generated_images.shape)
```

Audio example

```py
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import Whisper, WhisperDataParallelTrainer

# Dummy data parameters
batch_size = 8
max_length = 50
embed_dim = 256 
vocab_size = 1000 

# Generate data: replace with actual tokenised/quantised data
dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_inputs = jnp.ones((101, max_length, embed_dim))

dataset = ArrayDataset(dummy_inputs, 
                       dummy_targets)

dataloader = DataLoader(dataset, 
                        batch_size=batch_size, 
                        shuffle=True, 
                        drop_last=False)

# model parameters
hyperparams = {
    'num_layers': 1,
    'hidden_dim': 256,
    'num_heads': 2,
    'feedforward_dim': 256,
    'dropout': 0.1,
    'vocab_size': vocab_size,
    'embed_dim': embed_dim,
    'max_length': max_length,
    'start_token': 0,
    'end_token': 50,
}

# Initialize model
model = Whisper(**hyperparams)
rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
params = model.init(rngs, dummy_inputs, dummy_targets)['params']
outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)
print(outputs.shape)

# Training on your data
trainer = WhisperDataParallelTrainer(model, 
                                     dummy_inputs.shape, 
                                     dummy_targets.shape, 
                                     'params.pkl')
trainer.train(dataloader, 2, dataloader)

# Sample inference
params = trainer.load_params('params.pkl')

# for more than one sample, use model.generate_batch
transcripts = model.apply({'params': params}, 
                          dummy_inputs[:1], 
                          rngs=rngs, 
                          method=model.generate)

print(transcripts)
```
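The examples above feed 101 dummy rows through a `DataLoader` with `batch_size=8` and `drop_last=False`. Assuming `drop_last` has its conventional meaning (keep or discard a final, smaller batch), the batch counts work out as follows. This is plain arithmetic and does not call NanoDL.

```python
import math

num_examples, batch_size = 101, 8

# Number of full batches and the size of the leftover partial batch.
full_batches, last_batch_size = divmod(num_examples, batch_size)

batches_kept = math.ceil(num_examples / batch_size)   # drop_last=False
batches_dropped = num_examples // batch_size          # drop_last=True

print(full_batches, last_batch_size, batches_kept, batches_dropped)
```

Under that assumption, the final batch has fewer rows than `batch_size`, which matters if the rows are later split evenly across devices.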

Reward Model example for RLHF

```py
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with actual tokenised data
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)

# Create dataset and dataloader
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset, 
                        batch_size=batch_size, 
                        shuffle=True, 
                        drop_last=False)

# model parameters
hyperparams = {
    'num_layers': 1,
    'hidden_dim': 256,
    'num_heads': 2,
    'feedforward_dim': 256,
    'dropout': 0.1,
    'vocab_size': 1000,
    'embed_dim': 256,
    'max_length': max_length,
    'start_token': 0,
    'end_token': 50,
    'num_groups': 2,
    'window_size': 5,
    'shift_size': 2
}

# Initialize reward model from Mistral
model = Mistral(**hyperparams)
reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)

# Train the reward model
trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')
trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')

# Call as you would a regular Flax model
rngs = jax.random.PRNGKey(0)
rngs, dropout_rng = jax.random.split(rngs)
rewards = reward_model.apply({'params': params}, 
                    dummy_chosen, 
                    rngs={'dropout': dropout_rng})

print(rewards.shape)
```
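Reward models trained on chosen/rejected pairs are commonly fitted with a pairwise preference loss that pushes the reward of the chosen sample above that of the rejected one. The sketch below shows that common objective in plain JAX. It is an assumption for illustration; the source does not state which loss `RewardDataParallelTrainer` uses, and the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def pairwise_preference_loss(chosen_rewards, rejected_rewards):
    """Hypothetical helper: mean of -log(sigmoid(r_chosen - r_rejected))."""
    return -jnp.mean(jax.nn.log_sigmoid(chosen_rewards - rejected_rewards))


# Equal rewards: the model cannot tell the pair apart.
tied = pairwise_preference_loss(jnp.array([0.0]), jnp.array([0.0]))
# Chosen sample scored well above the rejected one.
separated = pairwise_preference_loss(jnp.array([2.0]), jnp.array([-2.0]))

print(float(tied), float(separated))
```

The loss shrinks as the gap between chosen and rejected rewards grows, and equals log 2 when the two rewards are tied.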

PCA example

```py
import jax
from nanodl import PCA

data = jax.random.normal(jax.random.key(0), (1000, 10))
pca = PCA(n_components=2)
pca.fit(data)
transformed_data = pca.transform(data)
original_data = pca.inverse_transform(transformed_data)
X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```
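For intuition about what `fit`, `transform` and `inverse_transform` compute, PCA can be sketched directly with an SVD in `jax.numpy`. This is a generic textbook formulation under the assumption of mean-centred data, not NanoDL's source code.

```python
import jax
import jax.numpy as jnp

data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))

# "fit": centre the data and take the top principal directions.
mean = data.mean(axis=0)
centered = data - mean
# Rows of vt are principal directions, ordered by decreasing singular value.
_, _, vt = jnp.linalg.svd(centered, full_matrices=False)
components = vt[:2]

# "transform": project onto the principal directions.
transformed = centered @ components.T
# "inverse_transform": map back to the original feature space.
reconstructed = transformed @ components + mean

print(transformed.shape, reconstructed.shape)
```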

NanoDL provides a random module which abstracts away Jax's intricacies.
By default it uses the current timestamp as the seed, so each call gives different pseudo-random values without explicit key handling.

```py
import jax
import nanodl

# Jax example
key = jax.random.PRNGKey(0)
jax_array = jax.random.uniform(key, shape=(3, 3))

# NanoDL example
jax_array = nanodl.uniform(shape=(3, 3))

# For reproducibility, use seed
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```
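The idea of seeding from the current timestamp can be sketched in plain JAX as below. The helper `timestamp_seeded_uniform` is hypothetical and written only for this illustration; it is not NanoDL's implementation, and the modulo used to keep the seed within 32 bits is an assumption.

```python
import time

import jax


def timestamp_seeded_uniform(shape, seed=None):
    """Hypothetical helper: uniform samples, seeded from the clock if no seed is given."""
    if seed is None:
        # Derive a seed from the current time, kept within the int32 range.
        seed = time.time_ns() % (2**31 - 1)
    key = jax.random.PRNGKey(seed)
    return jax.random.uniform(key, shape=shape)


a = timestamp_seeded_uniform((3, 3), seed=0)
b = timestamp_seeded_uniform((3, 3), seed=0)
c = timestamp_seeded_uniform((3, 3))  # differs from run to run
print(a.shape, c.shape)
```

A fixed seed reproduces the same array, while the clock-derived seed trades reproducibility for convenience.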

This is the first iteration of this project, so some roughness is expected and contributions are highly encouraged! Follow the recommended steps:

- Raise the issue/discussion to get second opinions
- Fork the repository
- Create a branch
- Make your changes without changing the design patterns
- Write tests for your changes if necessary
- Install locally with `pip install -e .`
- Run tests with `python -m unittest discover -s tests`
- Then submit a pull request from your branch.

Contributions can be made in various forms:

- Writing documentation.
- Fixing bugs.
- Implementing papers.
- Writing high-coverage tests.
- Optimizing existing code.
- Experimenting and submitting real-world examples to the examples section.
- Reporting bugs.
- Responding to reported issues.

To follow up or share thoughts, use [this form](https://forms.gle/vwveb9SKdPYywHx9A).

## Sponsorships

The name "NanoDL" stands for Nano Deep Learning. Models are exploding in size, which gate-keeps
experts and companies with limited resources from building flexible models without prohibitive costs.
Following the success of Phi models, the long-term goal is to build and train nano versions of all available models,
while ensuring they compete with the original models in performance, with total 
number of parameters not exceeding 1B. Trained weights will be made available via this library.
Any form of sponsorship, funding, grants or contribution will help with training resources.
You can sponsor via the tag on the user profile, or reach out via ndubuakuhenry@gmail.com.

## Citing nanodl

To cite this repository:

```
@software{nanodl2024github,
  author = {Henry Ndubuaku},
  title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.},
  url = {http://github.com/hmunachi/nanodl},
  version = {1.0.1dev},
  year = {2024},
}
```

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/hmunachi/nanodl",
    "name": "nanodl",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.7",
    "maintainer_email": "",
    "keywords": "transformers jax machine learning deep learning pytorch tensorflow",
    "author": "Henry Ndubuaku",
    "author_email": "ndubuakuhenry@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/27/91/516f0ede680e3ffd07e30d2b85cfe9901be4feb32d321cb5a4de0aa57da9/nanodl-1.2.1.dev1.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "",
    "summary": "A Jax-based library for designing and training transformer models from scratch.",
    "version": "1.2.1.dev1",
    "project_urls": {
        "Homepage": "https://github.com/hmunachi/nanodl"
    },
    "split_keywords": [
        "transformers",
        "jax",
        "machine",
        "learning",
        "deep",
        "learning",
        "pytorch",
        "tensorflow"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "562fc8e1615144c02b1ad7bb644c777e73a9f029d503b7c3db362dda89126447",
                "md5": "f26d07cf5fbf7f9b83c2dc82dbfba15d",
                "sha256": "079cb0fb210b094ce85bec0afec63e366105fa2662d2c88d0a8b279676c7e179"
            },
            "downloads": -1,
            "filename": "nanodl-1.2.1.dev1-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "f26d07cf5fbf7f9b83c2dc82dbfba15d",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.7",
            "size": 130893,
            "upload_time": "2024-03-12T09:51:10",
            "upload_time_iso_8601": "2024-03-12T09:51:10.192899Z",
            "url": "https://files.pythonhosted.org/packages/56/2f/c8e1615144c02b1ad7bb644c777e73a9f029d503b7c3db362dda89126447/nanodl-1.2.1.dev1-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "2791516f0ede680e3ffd07e30d2b85cfe9901be4feb32d321cb5a4de0aa57da9",
                "md5": "5152d43de4f83408b11d55007a1c0e28",
                "sha256": "e1215445501842753afd0c37f7e8c6ae0a0e7b0d9d779c7438b6553abc808339"
            },
            "downloads": -1,
            "filename": "nanodl-1.2.1.dev1.tar.gz",
            "has_sig": false,
            "md5_digest": "5152d43de4f83408b11d55007a1c0e28",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.7",
            "size": 103525,
            "upload_time": "2024-03-12T09:51:11",
            "upload_time_iso_8601": "2024-03-12T09:51:11.391328Z",
            "url": "https://files.pythonhosted.org/packages/27/91/516f0ede680e3ffd07e30d2b85cfe9901be4feb32d321cb5a4de0aa57da9/nanodl-1.2.1.dev1.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-03-12 09:51:11",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "hmunachi",
    "github_project": "nanodl",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [
        {
            "name": "jax",
            "specs": []
        },
        {
            "name": "jaxlib",
            "specs": []
        },
        {
            "name": "flax",
            "specs": []
        },
        {
            "name": "optax",
            "specs": []
        },
        {
            "name": "einops",
            "specs": []
        }
    ],
    "lcname": "nanodl"
}
        