* **Name**: mlx-image
* **Version**: 0.1.8
* **Summary**: Apple MLX image models library
* **Author**: Riccardo Musmeci
* **Requires Python**: `<4.0,>=3.10`
* **Upload time**: 2024-11-24 18:08:28

# **mlx-image**
Image models based on [Apple MLX framework](https://github.com/ml-explore/mlx) for Apple Silicon machines.

## **Why?**

The Apple MLX framework is a great tool for running machine learning models on Apple Silicon machines.

This repository converts image models from timm/torchvision to the Apple MLX framework. The weights are only converted from `.pth` to `.npz`/`.safetensors`; the models **are not retrained**.
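To illustrate what converting the weights involves, here is a minimal NumPy sketch. It assumes the PyTorch state dict has already been turned into a dict of NumPy arrays (for example via `tensor.numpy()`). It relies on one real layout difference: PyTorch stores `Conv2d` kernels as `(out, in, kH, kW)`, while MLX's `Conv2d` expects `(out, kH, kW, in)`. Linear weights use `(out, in)` in both frameworks, so they are copied unchanged. The name `convert_state_dict` and the rule "every 4-D tensor is a conv kernel" are assumptions made for illustration. This is not `mlx-image`'s actual converter.

```python
import io

import numpy as np


def convert_state_dict(torch_weights: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Hypothetical sketch: re-layout PyTorch weights for MLX (no retraining)."""
    mlx_weights = {}
    for name, array in torch_weights.items():
        if array.ndim == 4:
            # Assumption: every 4-D tensor is a Conv2d kernel.
            # PyTorch (out, in, kH, kW) -> MLX (out, kH, kW, in)
            array = np.transpose(array, (0, 2, 3, 1))
        # Linear weights, biases and norm statistics keep their layout.
        mlx_weights[name] = np.ascontiguousarray(array)
    return mlx_weights


if __name__ == "__main__":
    # Toy stand-in for a real state dict (shapes of ResNet-18's stem conv and head).
    fake_torch_weights = {
        "conv1.weight": np.zeros((64, 3, 7, 7), dtype=np.float32),
        "fc.weight": np.zeros((1000, 512), dtype=np.float32),
        "fc.bias": np.zeros((1000,), dtype=np.float32),
    }
    converted = convert_state_dict(fake_torch_weights)
    print(converted["conv1.weight"].shape)  # (64, 7, 7, 3)
    # np.savez writes a .npz archive (here into memory instead of a file).
    buffer = io.BytesIO()
    np.savez(buffer, **converted)
```

To produce `.safetensors` instead, the converted arrays could be wrapped with `mx.array` and written with MLX's `mx.save_safetensors`.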

## How to install

```bash
pip install mlx-image
```

## Models

Model weights are available on the [`mlx-vision`](https://huggingface.co/mlx-vision) community on HuggingFace.

To load a model with pre-trained weights:
```python
from mlxim.model import create_model

# loading weights from HuggingFace (https://huggingface.co/mlx-vision/resnet18-mlxim)
model = create_model("resnet18") # pretrained weights loaded from HF

# loading weights from local file
model = create_model("resnet18", weights="path/to/resnet18/model.safetensors")
```
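Before you call a loaded model on an image, the input must use the layout that MLX convolutions expect. That layout is channels-last `(N, H, W, C)`, unlike PyTorch's `(N, C, H, W)`. The sketch below makes two assumptions: the converted models keep the standard ImageNet mean/std normalization of the original timm/torchvision weights, and the image has already been resized (e.g. to 224×224). `preprocess` is a hypothetical helper, not part of `mlxim`. The MLX call is shown only as a comment.

```python
import numpy as np

# Standard ImageNet statistics (assumed to match the original timm/torchvision weights).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image: np.ndarray) -> np.ndarray:
    """Turn an (H, W, 3) uint8 RGB image into a (1, H, W, 3) float32 batch."""
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB image, got {image.shape}")
    x = image.astype(np.float32) / 255.0    # scale pixel values to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD  # per-channel normalization, broadcast over H and W
    return x[np.newaxis, ...]               # add the batch dimension; layout stays NHWC


# Hypothetical usage with a model from create_model (requires MLX):
#   import mlx.core as mx
#   model.eval()                        # inference mode (e.g. BatchNorm uses running stats)
#   logits = model(mx.array(preprocess(image)))
#   predicted_class = mx.argmax(logits, axis=-1)
```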

To list all available models:

```python
from mlxim.model import list_models
list_models()
```

### Supported models

List of all models available in `mlx-image`:

* **ResNet**: resnet18, resnet34, resnet50, resnet101, resnet152, wide_resnet50_2, wide_resnet101_2
* **ViT**:
    * **supervised**: vit_base_patch16_224, vit_base_patch16_224.swag_lin, vit_base_patch16_384.swag_e2e, vit_base_patch32_224, vit_large_patch16_224, vit_large_patch16_224.swag_lin, vit_large_patch16_512.swag_e2e, vit_huge_patch14_224.swag_lin, vit_huge_patch14_518.swag_e2e
    
    * **DINO v1**: vit_base_patch16_224.dino, vit_small_patch16_224.dino, vit_small_patch8_224.dino, vit_base_patch8_224.dino

    * **DINO v2**: vit_small_patch14_518.dinov2, vit_base_patch14_518.dinov2, vit_large_patch14_518.dinov2
* **Swin**: swin_tiny_patch4_window7_224, swin_small_patch4_window7_224, swin_base_patch4_window7_224, swin_v2_tiny_patch4_window8_256, swin_v2_small_patch4_window8_256, swin_v2_base_patch4_window8_256
* **RegNet**: regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, regnet_x_16gf, regnet_x_32gf, regnet_y_400mf, regnet_y_800mf, regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, regnet_y_128gf

> **Warning**: The `regnet_y_128gf` model couldn't be tested due to computational limitations.
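The model names appear to follow timm's naming convention. Each name starts with the architecture. ViT and Swin names add `patch<P>`, and Swin names also add `window<W>`. The expected input resolution is the trailing number, and an optional `.tag` marks the pre-training variant (`.dino`, `.dinov2`, `.swag_lin`, `.swag_e2e`). If that convention holds, a small hypothetical helper can recover the input size and the number of patch tokens from a name. For example, `vit_base_patch16_224` gives (224 / 16)² = 196 tokens. `parse_model_name` and `ModelName` are illustrative names, not part of `mlxim`.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelName:
    arch: str                  # e.g. "vit_base"
    tag: Optional[str]         # pre-training variant after the dot, e.g. "dino"
    patch: Optional[int]       # patch size (ViT/Swin only)
    window: Optional[int]      # attention window size (Swin only)
    resolution: Optional[int]  # expected input height/width

    @property
    def num_patches(self) -> Optional[int]:
        """Patch tokens for a square input, excluding any class token (first stage for Swin)."""
        if self.patch is None or self.resolution is None:
            return None
        return (self.resolution // self.patch) ** 2


def parse_model_name(name: str) -> ModelName:
    """Hypothetical parser for timm-style names such as 'vit_base_patch16_224.dino'."""
    base, _, tag = name.partition(".")
    patch = re.search(r"_patch(\d+)", base)
    window = re.search(r"_window(\d+)", base)
    # Only names with a patch size carry a trailing resolution;
    # in ResNet/RegNet names (e.g. wide_resnet50_2) the trailing number means something else.
    resolution = re.search(r"_(\d+)$", base) if patch else None
    return ModelName(
        arch=base[: patch.start()] if patch else base,
        tag=tag or None,
        patch=int(patch.group(1)) if patch else None,
        window=int(window.group(1)) if window else None,
        resolution=int(resolution.group(1)) if resolution else None,
    )
```

Under this reading, `vit_huge_patch14_518.swag_e2e` expects 518×518 inputs, while the `swin_v2_*` models expect 256×256.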

## ImageNet-1K Results

See [results-imagenet-1k.csv](https://github.com/riccardomusmeci/mlx-image/blob/main/results/results-imagenet-1k.csv) for every model converted to `mlx-image` and its ImageNet-1K performance under different settings.

> **TL;DR**: performance is comparable to that of the original PyTorch implementations.


## Similarity to PyTorch and other familiar tools

`mlx-image` tries to stay as close as possible to PyTorch:
- `DataLoader` -> you can define your own `collate_fn` and use `num_workers` to speed up data loading
- `Dataset` -> `mlx-image` already supports `LabelFolderDataset` (the good old PyTorch `ImageFolder`) and `FolderDataset` (a generic folder of images)
- `ModelCheckpoint` -> keeps track of the best model and saves it to disk (similar to PyTorch Lightning); it can also signal when to stop training early
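A custom `collate_fn` receives the list of samples that make up one batch and returns the batch. The sketch below assumes, as in PyTorch, that each sample is an `(image, label)` pair with the image stored as an `(H, W, C)` NumPy array. It also assumes that `DataLoader` accepts the function through a `collate_fn` keyword. `stack_collate` is a hypothetical name, and converting the result to `mx.array` is left to the caller.

```python
from typing import Sequence, Tuple

import numpy as np


def stack_collate(samples: Sequence[Tuple[np.ndarray, int]]) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical collate_fn: stack (image, label) pairs into batch arrays."""
    if not samples:
        raise ValueError("cannot collate an empty batch")
    images, labels = zip(*samples)
    # All images must share the same (H, W, C) shape to stack into (N, H, W, C).
    batch_images = np.stack(images).astype(np.float32)
    batch_labels = np.asarray(labels, dtype=np.int32)
    return batch_images, batch_labels


# Hypothetical usage, assuming the PyTorch-like keyword argument:
#   loader = DataLoader(dataset=train_dataset, batch_size=32, collate_fn=stack_collate)
```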

## Training

Training is similar to PyTorch. Here's an example of how to train a model:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18") # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, inputs, targets):
    logits = model(inputs)
    return mx.mean(nn.losses.cross_entropy(logits, targets))

# Build once: returns (loss, grads w.r.t. the model's trainable parameters)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        # The wrapped function keeps loss_fn's signature, so the model is passed too
        loss, grads = loss_and_grad_fn(model, x, target)
        optimizer.update(model, grads)
        # MLX is lazy: force evaluation of the updated parameters and optimizer state
        mx.eval(model.state, optimizer.state)
```
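In the training example, `class_map` appears to map each integer label to either a single folder name or a list of folder names. Under that reading, the `class_2` and `class_3` folders are merged under label `2`. The sketch below assumes this interpretation and one sub-folder per class under `root_dir`. It builds the lookup a dataset needs, which is the inverse mapping from folder name to label. `invert_class_map` is a hypothetical helper, not `mlxim` code.

```python
from typing import Dict, List, Union


def invert_class_map(class_map: Dict[int, Union[str, List[str]]]) -> Dict[str, int]:
    """Hypothetical helper: turn {label: folder | [folders]} into {folder: label}."""
    folder_to_label: Dict[str, int] = {}
    for label, folders in class_map.items():
        # Normalize a single folder name to a one-element list.
        if isinstance(folders, str):
            folders = [folders]
        for folder in folders:
            if folder in folder_to_label:
                raise ValueError(f"folder {folder!r} is assigned to more than one label")
            folder_to_label[folder] = label
    return folder_to_label


class_map = {0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
print(invert_class_map(class_map))
# {'class_0': 0, 'class_1': 1, 'class_2': 2, 'class_3': 2}
```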

## **Validation**

The `validation.py` script runs every time a `.pth` model is converted to MLX. It checks that the converted model performs similarly to the original one on ImageNet-1K.

I use the configuration file `config/validation.yaml` to set the parameters for the validation script.

You can download the ImageNet-1K validation set from the `mlx-vision` community on HuggingFace at this [link](https://huggingface.co/datasets/mlx-vision/imagenet-1k).
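The metrics `validation.py` reports are not stated here. ImageNet-1K results are conventionally given as top-1 and top-5 accuracy. The minimal NumPy sketch below computes that metric, assuming `logits` of shape `(N, num_classes)` and integer `labels` of shape `(N,)`. `topk_accuracy` is a hypothetical name.

```python
import numpy as np


def topk_accuracy(logits: np.ndarray, labels: np.ndarray, k: int = 1) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if logits.ndim != 2 or labels.shape != (logits.shape[0],):
        raise ValueError("expected logits of shape (N, C) and labels of shape (N,)")
    # argpartition collects the indices of the k largest scores per row without a full sort.
    topk = np.argpartition(logits, -k, axis=1)[:, -k:]
    hits = (topk == labels[:, None]).any(axis=1)
    return float(hits.mean())


# Hypothetical comparison of an original and a converted model on the same images:
#   gap = topk_accuracy(original_logits, labels, k=1) - topk_accuracy(converted_logits, labels, k=1)
```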

## **Contributing**

This is a work in progress, so any help is appreciated.

I am working on it in my spare time, so I can't guarantee frequent updates.

If you love coding and want to contribute, follow the instructions in [CONTRIBUTING.md](CONTRIBUTING.md).

## Additional Resources

* [mlx-vision community](https://huggingface.co/mlx-vision)
* [HuggingFace doc](https://huggingface.co/docs/hub/main/en/mlx-image)

## **To-Dos**

- [ ] inference script (similar to train/validation)
- [ ] DenseNet
- [ ] MobileNet

## Contact

If you have any questions, please email `riccardomusmeci92@gmail.com`.
