image-classification-jax

Name: image-classification-jax
Version: 0.1.3
Summary: Run image classification experiments in JAX with ViT or ResNet models on the CIFAR-10, CIFAR-100, Imagenette, and ImageNet datasets.
Upload time: 2024-10-07 05:00:08
Author: Evan Walters
Requires Python: >=3.9
Keywords: python, machine learning, optimization, jax
Repository: https://github.com/evanatyourservice/image-classification-jax
Requirements: jax, flax, einops, optax, numpy, tensorflow-cpu, tensorflow-datasets, wandb
# image-classification-jax

Run image classification experiments in JAX with ViT or ResNet models on the CIFAR-10, CIFAR-100, Imagenette, and ImageNet datasets.

Meant to be simple but high quality. Includes:
- ViT with QK normalization, SwiGLU, and empty registers
- PaLM-style z-loss (https://arxiv.org/pdf/2204.02311); see the sketch after this list
- support for the schedule-free wrapper from `optax.contrib`
- support for PSGD optimizers from `psgd-jax`, including the Hessian calculation
- datasets: CIFAR-10, CIFAR-100, Imagenette, and ImageNet
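
For reference, a minimal sketch of the PaLM-style z-loss (illustrative names, not this package's internals; the PaLM paper uses a weight of 1e-4):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits, one_hot_labels, z_loss_weight=1e-4):
    """Cross-entropy plus z-loss, which penalizes log(Z)^2 where Z is the
    softmax normalizer, discouraging logits from drifting to large values."""
    log_z = logsumexp(logits, axis=-1)      # log of the partition function Z
    log_probs = logits - log_z[..., None]   # log-softmax
    ce = -jnp.sum(one_hot_labels * log_probs, axis=-1)
    return jnp.mean(ce + z_loss_weight * jnp.square(log_z))
```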

Currently there is no model sharding, only data parallelism (the batch is automatically split across devices as `batch_size / n_devices`); the sketch below shows the idea.
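
A minimal sketch of that split (illustrative, not the package's exact code):

```python
import jax

n_devices = jax.local_device_count()

def shard_batch(batch):
    # reshape each array's leading axis to [n_devices, batch_size // n_devices, ...]
    # so every device receives an equal slice of the batch
    return jax.tree_util.tree_map(
        lambda x: x.reshape(n_devices, -1, *x.shape[1:]), batch
    )
```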


## Installation

```bash
pip install image-classification-jax
```
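
Or install the latest revision straight from the repository:

```bash
pip install git+https://github.com/evanatyourservice/image-classification-jax.git
```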


## Usage

Set your wandb API key, either in your Python script or through the command line:
```bash
export WANDB_API_KEY=<your_key>
```
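
Or equivalently from Python, before wandb is initialized:

```python
import os

os.environ["WANDB_API_KEY"] = "<your_key>"
```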

Use `run_experiment` to run an experiment. Here's how you could run one with the
PSGD affine optimizer wrapped with schedule-free:

```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine

base_lr = 0.001
warmup = 256
# linear warmup to base_lr over the first `warmup` steps, then constant
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

# clip gradients, then PSGD affine; inner momentum (b1) is disabled because
# the schedule-free wrapper below supplies its own momentum
psgd_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    affine(
        lr,
        preconditioner_update_probability=1.0,
        b1=0.0,
        weight_decay=0.0,
        max_size_triangular=0,
        max_skew_triangular=0,
        precond_init_scale=1.0,
    ),
)

# schedule-free wrapper: takes the lr schedule and momentum (b1=0.95) itself
optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra logging info for wandb
        "optimizer": "psgd_affine",
        "lr": base_lr,
        "warmup": warmup,
        "b1": 0.95,
        "schedule_free": True,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    l2_regularization=0.0001,
    randomize_l2_reg=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # set to True if optimizer is wrapped with schedule_free
    psgd_calc_hessian=False,  # set to True to compute the Hessian and pass it to PSGD
    psgd_precond_update_prob=1.0,
)
```
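
If you use schedule-free outside `run_experiment`, note that optax keeps separate averaged parameters for evaluation, retrieved with `optax.contrib.schedule_free_eval_params`. A self-contained toy example (plain SGD on a quadratic, just to show the API):

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, -2.0])
# the base optimizer should not use momentum; schedule-free supplies it via b1
opt = optax.contrib.schedule_free(optax.sgd(0.1), learning_rate=0.1, b1=0.95)
state = opt.init(params)

loss_fn = lambda p: jnp.sum(p**2)
for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    updates, state = opt.update(grads, state, params)  # schedule-free needs params
    params = optax.apply_updates(params, updates)

# evaluate with the averaged parameters, not the raw training params
eval_params = optax.contrib.schedule_free_eval_params(state, params)
```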


### TODO:

- [ ] Add SAM, ASAM, Momentum-SAM
- [ ] Add loss landscape flatness logging
- [ ] Add logging for optimizer output norm, hessian norm

            
