# image-classification-jax
Run image classification experiments in JAX with ViT and ResNet models on CIFAR-10, CIFAR-100, Imagenette, and ImageNet.
Meant to be simple but high quality. Includes:
- ViT with QK normalization, SwiGLU, and empty registers
- PaLM-style z-loss (https://arxiv.org/pdf/2204.02311); see the sketch after this list
- Ability to use schedule-free from `optax.contrib`
- Ability to use PSGD optimizers from `psgd-jax`, with optional Hessian calculation
- Datasets currently implemented: `cifar10`, `cifar100`, `imagenette`, and `imagenet`
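The z-loss adds a small penalty on the squared log-partition function of the softmax, which keeps logit magnitudes from drifting during training. A minimal sketch of the idea (this helper is illustrative, not the repo's exact implementation):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits, one_hot_labels, z_loss_weight=1e-4):
    """Softmax cross-entropy plus the PaLM-style z-loss term (illustrative)."""
    log_z = logsumexp(logits, axis=-1)      # log of the softmax normalizer Z
    log_probs = logits - log_z[..., None]   # log-softmax
    ce = -jnp.sum(one_hot_labels * log_probs, axis=-1)
    # z-loss = weight * log^2(Z): nudges log Z toward 0 so logits stay bounded
    return ce + z_loss_weight * jnp.square(log_z)
```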
There is currently no model sharding, only data parallelism (each batch is automatically split into `batch_size / n_devices` shards, one per device).
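In practice this means `batch_size` should be divisible by the number of local devices. Conceptually the split looks like the following (a sketch of the general pmap-style layout, not the repo's exact code):

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
batch = jnp.zeros((64, 32, 32, 3))  # e.g. one batch of 64 CIFAR-10 images

# Reshape the leading axis to (n_devices, batch_size // n_devices, ...),
# the layout pmap-style data parallelism expects; batch_size must divide
# evenly by the device count.
sharded = batch.reshape(n_devices, -1, *batch.shape[1:])
```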
## Installation
```bash
pip install image-classification-jax
```
## Usage
Set your Weights & Biases API key either in your Python script or on the command line:
```bash
export WANDB_API_KEY=<your_key>
```
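Or set it from Python before the run starts (equivalent to the shell export above):

```python
import os

# Must be set before wandb initializes; placeholder value shown.
os.environ["WANDB_API_KEY"] = "<your_key>"
```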
Use `run_experiment` to run an experiment. Here's how you could run one with the PSGD
affine optimizer wrapped with schedule-free:
```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine
base_lr = 0.001
warmup = 256

# linear warmup to base_lr, then constant
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

psgd_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    affine(
        lr,
        preconditioner_update_probability=1.0,
        b1=0.0,
        weight_decay=0.0,
        max_size_triangular=0,
        max_skew_triangular=0,
        precond_init_scale=1.0,
    ),
)

# wrap the base optimizer with schedule-free
optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra logging info for wandb
        "optimizer": "psgd_affine",
        "lr": base_lr,
        "warmup": warmup,
        "b1": 0.95,
        "schedule_free": True,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=0.0001,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # set to True if the optimizer is wrapped with schedule_free
    psgd_calc_hessian=False,  # set to True to calculate the Hessian and pass it to PSGD
    psgd_precond_update_prob=1.0,
)
```
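For comparison, a plainer run with a stock optax optimizer might look like the sketch below; it assumes the keyword arguments omitted here (e.g. `n_empty_registers`, `dropout_rate`) fall back to defaults:

```python
import optax
from image_classification_jax.run_experiment import run_experiment

run_experiment(
    log_to_wandb=False,
    global_seed=0,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optax.adamw(learning_rate=3e-4, weight_decay=1e-4),
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    using_schedule_free=False,  # no schedule-free wrapper around this optimizer
)
```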
### TODO:
- [ ] Add SAM, ASAM, Momentum-SAM
- [ ] Add loss landscape flatness logging
- [ ] Add logging for optimizer output norm, hessian norm