# image-classification-jax
Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.
Meant to be simple but of good quality. It includes:
- ViT with qk normalization, swiglu, empty registers
- PaLM-style z-loss (https://arxiv.org/pdf/2204.02311)
- ability to use schedule-free from `optax.contrib`
- datasets currently implemented include cifar10, cifar100, imagenette, and imagenet
Currently there is no model sharding, only data parallelism (each batch is automatically split across devices, `batch_size/n_devices` per device).
## Installation
```bash
pip install image-classification-jax
```
## Usage
Set your wandb key either in your Python script or through the command line:
```bash
export WANDB_API_KEY=<your_key>
```
Use `run_experiment` to run an experiment. Here's how you could run an experiment with
the PSGD affine optimizer wrapped with schedule-free:
```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine
base_lr = 0.001
warmup = 256
lr = optax.join_schedules(
schedules=[
optax.linear_schedule(0.0, base_lr, warmup),
optax.constant_schedule(base_lr),
],
boundaries=[warmup],
)
psgd_opt = optax.chain(
optax.clip_by_global_norm(1.0),
affine(
lr,
preconditioner_update_probability=1.0,
b1=0.0,
weight_decay=0.0,
max_size_triangular=0,
max_skew_triangular=0,
precond_init_scale=1.0,
),
)
optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)
run_experiment(
log_to_wandb=True,
wandb_entity="",
wandb_project="image_classification_jax",
wandb_config_update={ # extra logging info for wandb
"optimizer": "psgd_affine",
"lr": base_lr,
"warmup": warmup,
"b1": 0.95,
"schedule_free": True,
},
global_seed=100,
dataset="cifar10",
batch_size=64,
n_epochs=10,
optimizer=optimizer,
compute_in_bfloat16=False,
apply_z_loss=True,
model_type="vit",
n_layers=4,
enc_dim=64,
n_heads=4,
n_empty_registers=0,
dropout_rate=0.0,
using_schedule_free=True, # set to True if optimizer wrapped with schedule_free
)
```
Raw data
{
"_id": null,
"home_page": null,
"name": "image-classification-jax",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "python, machine learning, optimization, jax",
"author": "Evan Walters",
"author_email": null,
"download_url": "https://files.pythonhosted.org/packages/8a/e3/edc9fc4f3274e37e83aef6effa4393a9cb3948b0da5414df1e67c7eb6d44/image_classification_jax-0.1.4.tar.gz",
"platform": null,
"description": "# image-classification-jax\n\nRun image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.\n\nMeant to be simple but good quality. Includes:\n- ViT with qk normalization, swiglu, empty registers\n- Palm style z-loss (https://arxiv.org/pdf/2204.02311)\n- ability to use schedule-free from `optax.contrib`\n- datasets currently implemented include cifar10, cifar100, imagenette, and imagenet\n\nCurrently no model sharding, only data parallelism (automatically splits batch `batch_size/n_devices`).\n\n\n## Installation\n\n```bash\npip install image-classification-jax\n```\n\n\n## Usage\n\nSet your wandb key either in your python script or through command line:\n```bash\nexport WANDB_API_KEY=<your_key>\n```\n\nUse `run_experiment` to run an experiment. Here's how you could run an experiment with\nPSGD affine optimizer wrapped with schedule-free:\n\n```python\nimport optax\nfrom image_classification_jax.run_experiment import run_experiment\nfrom psgd_jax.affine import affine\n\nbase_lr = 0.001\nwarmup = 256\nlr = optax.join_schedules(\n schedules=[\n optax.linear_schedule(0.0, base_lr, warmup),\n optax.constant_schedule(base_lr),\n ],\n boundaries=[warmup],\n)\n\npsgd_opt = optax.chain(\n optax.clip_by_global_norm(1.0),\n affine(\n lr,\n preconditioner_update_probability=1.0,\n b1=0.0,\n weight_decay=0.0,\n max_size_triangular=0,\n max_skew_triangular=0,\n precond_init_scale=1.0,\n ),\n)\n\noptimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)\n\nrun_experiment(\n log_to_wandb=True,\n wandb_entity=\"\",\n wandb_project=\"image_classification_jax\",\n wandb_config_update={ # extra logging info for wandb\n \"optimizer\": \"psgd_affine\",\n \"lr\": base_lr,\n \"warmup\": warmup,\n \"b1\": 0.95,\n \"schedule_free\": True,\n },\n global_seed=100,\n dataset=\"cifar10\",\n batch_size=64,\n n_epochs=10,\n optimizer=optimizer,\n compute_in_bfloat16=False,\n apply_z_loss=True,\n model_type=\"vit\",\n 
n_layers=4,\n enc_dim=64,\n n_heads=4,\n n_empty_registers=0,\n dropout_rate=0.0,\n using_schedule_free=True, # set to True if optimizer wrapped with schedule_free\n)\n```\n\n",
"bugtrack_url": null,
"license": null,
"summary": "Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.",
"version": "0.1.4",
"project_urls": {
"homepage": "https://github.com/evanatyourservice/image-classification-jax",
"repository": "https://github.com/evanatyourservice/image-classification-jax"
},
"split_keywords": [
"python",
" machine learning",
" optimization",
" jax"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "a94ce1d8f73faf937f9750607d0395bfdcf49f63ad1332c33b5b16b876b1c496",
"md5": "72656b7a7ec8a601f155f687466aacf4",
"sha256": "1e2e004ee4fa18ce385d5eb1eac8edde66ae8d23b1dd0ad758e8e0b0c326d899"
},
"downloads": -1,
"filename": "image_classification_jax-0.1.4-py3-none-any.whl",
"has_sig": false,
"md5_digest": "72656b7a7ec8a601f155f687466aacf4",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 21286,
"upload_time": "2024-12-18T23:59:47",
"upload_time_iso_8601": "2024-12-18T23:59:47.049346Z",
"url": "https://files.pythonhosted.org/packages/a9/4c/e1d8f73faf937f9750607d0395bfdcf49f63ad1332c33b5b16b876b1c496/image_classification_jax-0.1.4-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "8ae3edc9fc4f3274e37e83aef6effa4393a9cb3948b0da5414df1e67c7eb6d44",
"md5": "e980e0cca17ffc864537810e7b82e801",
"sha256": "35db2de3ffa15dd6b29304118ea3f0b34510bc8bd413233331252fe24458064f"
},
"downloads": -1,
"filename": "image_classification_jax-0.1.4.tar.gz",
"has_sig": false,
"md5_digest": "e980e0cca17ffc864537810e7b82e801",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 17890,
"upload_time": "2024-12-18T23:59:49",
"upload_time_iso_8601": "2024-12-18T23:59:49.165128Z",
"url": "https://files.pythonhosted.org/packages/8a/e3/edc9fc4f3274e37e83aef6effa4393a9cb3948b0da5414df1e67c7eb6d44/image_classification_jax-0.1.4.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-12-18 23:59:49",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "evanatyourservice",
"github_project": "image-classification-jax",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"requirements": [
{
"name": "jax",
"specs": []
},
{
"name": "flax",
"specs": []
},
{
"name": "einops",
"specs": []
},
{
"name": "optax",
"specs": []
},
{
"name": "numpy",
"specs": []
},
{
"name": "tensorflow-cpu",
"specs": []
},
{
"name": "tensorflow-datasets",
"specs": []
},
{
"name": "wandb",
"specs": []
},
{
"name": "psgd-jax",
"specs": []
}
],
"lcname": "image-classification-jax"
}