image-classification-jax


- Name: image-classification-jax
- Version: 0.1.4
- Summary: Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.
- Upload time: 2024-12-18 23:59:49
- Author: Evan Walters
- Requires Python: >=3.9
- Keywords: python, machine learning, optimization, jax
- Requirements: jax, flax, einops, optax, numpy, tensorflow-cpu, tensorflow-datasets, wandb, psgd-jax

# image-classification-jax

Run image classification experiments in JAX with ViT or ResNet models on the cifar10, cifar100, imagenette, and imagenet datasets.

The package is meant to be simple but of good quality. It includes:
- ViT with QK normalization, SwiGLU, and empty registers
- PaLM-style z-loss (https://arxiv.org/pdf/2204.02311)
- ability to use schedule-free from `optax.contrib`
- datasets currently implemented include cifar10, cifar100, imagenette, and imagenet
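
For orientation, the z-loss from the PaLM paper adds a small penalty on the squared log of the softmax normalizer, which encourages the normalizer to stay close to 1. The following is a minimal sketch of that idea in JAX, not this package's implementation; the function name `cross_entropy_with_z_loss` and the `z_loss_weight` parameter are assumptions made for illustration (the PaLM paper uses a coefficient of 1e-4).

```python
import jax
import jax.numpy as jnp


def cross_entropy_with_z_loss(logits, labels, z_loss_weight=1e-4):
    """Mean softmax cross-entropy plus a PaLM-style z-loss term.

    Hypothetical helper for illustration only.
    logits: float array of shape (batch, n_classes)
    labels: integer array of shape (batch,)
    """
    # log Z, the log of the softmax normalizer, one value per example
    log_z = jax.nn.logsumexp(logits, axis=-1)
    # log-probabilities of every class
    log_probs = logits - log_z[..., None]
    # negative log-likelihood of the true class
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    # auxiliary penalty that pulls log Z toward 0
    z_loss = z_loss_weight * jnp.square(log_z)
    return jnp.mean(nll + z_loss)
```

With all-zero logits over four classes, both the cross-entropy and `log_z` equal `log(4)`, so the penalty contributes `1e-4 * log(4)**2` on top of the usual loss.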

There is currently no model sharding, only data parallelism: each batch is split automatically across devices, so every device processes `batch_size/n_devices` examples.
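
To make the `batch_size/n_devices` split concrete, here is a small sketch of how a batch can be reshaped so that its leading axis indexes devices. It is an illustration under assumptions, not the package's code: `split_batch` is a hypothetical helper, and whether the package itself requires the batch size to be divisible by the device count is not stated in this README.

```python
import numpy as np


def split_batch(batch, n_devices):
    """Reshape (batch_size, ...) into (n_devices, batch_size // n_devices, ...).

    Hypothetical helper for illustration only.
    """
    batch_size = batch.shape[0]
    if batch_size % n_devices != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by n_devices={n_devices}"
        )
    per_device = batch_size // n_devices
    return batch.reshape(n_devices, per_device, *batch.shape[1:])


# 64 CIFAR-sized images spread over 8 devices, 8 images per device
images = np.zeros((64, 32, 32, 3), dtype=np.float32)
per_device_images = split_batch(images, n_devices=8)
```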


## Installation

```bash
pip install image-classification-jax
```


## Usage

Set your wandb key either in your Python script or through the command line:
```bash
export WANDB_API_KEY=<your_key>
```
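
To set the key from a Python script instead, one option is to assign the same environment variable before the experiment starts. This is a minimal sketch; `<your_key>` is a placeholder to replace with your own key.

```python
import os

# Set the key before wandb is initialized so the environment variable is picked up.
os.environ["WANDB_API_KEY"] = "<your_key>"
```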

Use `run_experiment` to run an experiment. Here's how you could run an experiment with
the PSGD affine optimizer wrapped with schedule-free:

```python
import optax
from image_classification_jax.run_experiment import run_experiment
from psgd_jax.affine import affine

base_lr = 0.001
warmup = 256
lr = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, base_lr, warmup),
        optax.constant_schedule(base_lr),
    ],
    boundaries=[warmup],
)

psgd_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    affine(
        lr,
        preconditioner_update_probability=1.0,
        b1=0.0,
        weight_decay=0.0,
        max_size_triangular=0,
        max_skew_triangular=0,
        precond_init_scale=1.0,
    ),
)

optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)

run_experiment(
    log_to_wandb=True,
    wandb_entity="",
    wandb_project="image_classification_jax",
    wandb_config_update={  # extra logging info for wandb
        "optimizer": "psgd_affine",
        "lr": base_lr,
        "warmup": warmup,
        "b1": 0.95,
        "schedule_free": True,
    },
    global_seed=100,
    dataset="cifar10",
    batch_size=64,
    n_epochs=10,
    optimizer=optimizer,
    compute_in_bfloat16=False,
    apply_z_loss=True,
    model_type="vit",
    n_layers=4,
    enc_dim=64,
    n_heads=4,
    n_empty_registers=0,
    dropout_rate=0.0,
    using_schedule_free=True,  # set to True if optimizer wrapped with schedule_free
)
```
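
The `lr` schedule in the example ramps linearly from 0 to `base_lr` over the first `warmup` steps and then stays constant. The plain-Python sketch below reproduces those values without optax to show what the schedule evaluates to; `warmup_then_constant` is a hypothetical name and this is not optax's implementation.

```python
def warmup_then_constant(step, base_lr=0.001, warmup=256):
    """Learning rate at `step`: linear warmup to base_lr, then constant.

    Hypothetical helper for illustration only.
    """
    # fraction of the warmup completed, capped at 1.0 once warmup is over
    progress = min(step, warmup) / warmup
    return base_lr * progress


# learning rate at the start, halfway through warmup, at the boundary, and later
lrs = [warmup_then_constant(s) for s in (0, 128, 256, 1000)]
```

Halfway through warmup the rate is half of `base_lr`, and from step 256 onward it equals `base_lr`.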


            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "image-classification-jax",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": null,
    "keywords": "python, machine learning, optimization, jax",
    "author": "Evan Walters",
    "author_email": null,
    "download_url": "https://files.pythonhosted.org/packages/8a/e3/edc9fc4f3274e37e83aef6effa4393a9cb3948b0da5414df1e67c7eb6d44/image_classification_jax-0.1.4.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet.",
    "version": "0.1.4",
    "project_urls": {
        "homepage": "https://github.com/evanatyourservice/image-classification-jax",
        "repository": "https://github.com/evanatyourservice/image-classification-jax"
    },
    "split_keywords": [
        "python",
        " machine learning",
        " optimization",
        " jax"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "a94ce1d8f73faf937f9750607d0395bfdcf49f63ad1332c33b5b16b876b1c496",
                "md5": "72656b7a7ec8a601f155f687466aacf4",
                "sha256": "1e2e004ee4fa18ce385d5eb1eac8edde66ae8d23b1dd0ad758e8e0b0c326d899"
            },
            "downloads": -1,
            "filename": "image_classification_jax-0.1.4-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "72656b7a7ec8a601f155f687466aacf4",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 21286,
            "upload_time": "2024-12-18T23:59:47",
            "upload_time_iso_8601": "2024-12-18T23:59:47.049346Z",
            "url": "https://files.pythonhosted.org/packages/a9/4c/e1d8f73faf937f9750607d0395bfdcf49f63ad1332c33b5b16b876b1c496/image_classification_jax-0.1.4-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "8ae3edc9fc4f3274e37e83aef6effa4393a9cb3948b0da5414df1e67c7eb6d44",
                "md5": "e980e0cca17ffc864537810e7b82e801",
                "sha256": "35db2de3ffa15dd6b29304118ea3f0b34510bc8bd413233331252fe24458064f"
            },
            "downloads": -1,
            "filename": "image_classification_jax-0.1.4.tar.gz",
            "has_sig": false,
            "md5_digest": "e980e0cca17ffc864537810e7b82e801",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 17890,
            "upload_time": "2024-12-18T23:59:49",
            "upload_time_iso_8601": "2024-12-18T23:59:49.165128Z",
            "url": "https://files.pythonhosted.org/packages/8a/e3/edc9fc4f3274e37e83aef6effa4393a9cb3948b0da5414df1e67c7eb6d44/image_classification_jax-0.1.4.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-12-18 23:59:49",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "evanatyourservice",
    "github_project": "image-classification-jax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "requirements": [
        {
            "name": "jax",
            "specs": []
        },
        {
            "name": "flax",
            "specs": []
        },
        {
            "name": "einops",
            "specs": []
        },
        {
            "name": "optax",
            "specs": []
        },
        {
            "name": "numpy",
            "specs": []
        },
        {
            "name": "tensorflow-cpu",
            "specs": []
        },
        {
            "name": "tensorflow-datasets",
            "specs": []
        },
        {
            "name": "wandb",
            "specs": []
        },
        {
            "name": "psgd-jax",
            "specs": []
        }
    ],
    "lcname": "image-classification-jax"
}