Name | mlx-image |
Version | 0.1.8 |
home_page | None |
Summary | Apple MLX image models library |
upload_time | 2024-11-24 18:08:28 |
maintainer | None |
docs_url | None |
author | Riccardo Musmeci |
requires_python | <4.0,>=3.10 |
license | None |
# **mlx-image**
Image models based on [Apple MLX framework](https://github.com/ml-explore/mlx) for Apple Silicon machines.
## **Why?**
The Apple MLX framework is a great tool for running machine learning models on Apple Silicon machines.
This repository converts image models from timm/torchvision to the Apple MLX framework. The weights are only converted from .pth to .npz/.safetensors; the models **are not trained again**.
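As a rough illustration of what such a conversion involves, the sketch below transposes convolution kernels from PyTorch's `(out, in, kH, kW)` layout to the channels-last `(out, kH, kW, in)` layout that MLX convolutions use, then writes the result in `.npz` format. It is not the conversion code used by this repository: the function name, the layer names, and the rule of treating every 4-D array as a conv kernel are assumptions made for the example. It works on plain NumPy arrays, so PyTorch is not required.

```python
from __future__ import annotations

import io

import numpy as np


def to_channels_last(weights: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Return a copy where 4-D arrays go from (O, I, H, W) to (O, H, W, I)."""
    converted = {}
    for name, array in weights.items():
        if array.ndim == 4:
            # Assumption: every 4-D tensor is a 2-D convolution kernel.
            converted[name] = np.transpose(array, (0, 2, 3, 1))
        else:
            # Linear weights, biases, and norm parameters are left untouched.
            converted[name] = array
    return converted


# Hypothetical layer names and shapes, standing in for a loaded .pth state dict.
torch_style = {
    "conv1.weight": np.zeros((64, 3, 7, 7), dtype=np.float32),
    "fc.weight": np.zeros((1000, 512), dtype=np.float32),
}
mlx_style = to_channels_last(torch_style)

# Write to an in-memory buffer; pass a file path instead to get a .npz on disk.
buffer = io.BytesIO()
np.savez(buffer, **mlx_style)
```

A real conversion would also have to handle grouped or depthwise convolutions and any parameter renaming, which this sketch ignores.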
## How to install
```bash
pip install mlx-image
```
## Models
Model weights are available on the [`mlx-vision`](https://huggingface.co/mlx-vision) community on HuggingFace.
To load a model with pre-trained weights:
```python
from mlxim.model import create_model
# loading weights from HuggingFace (https://huggingface.co/mlx-vision/resnet18-mlxim)
model = create_model("resnet18") # pretrained weights loaded from HF
# loading weights from local file
model = create_model("resnet18", weights="path/to/resnet18/model.safetensors")
```
To list all available models:
```python
from mlxim.model import list_models
list_models()
```
### Supported models
List of all models available in `mlx-image`:
* **ResNet**: resnet18, resnet34, resnet50, resnet101, resnet152, wide_resnet50_2, wide_resnet101_2
* **ViT**:
    * **supervised**: vit_base_patch16_224, vit_base_patch16_224.swag_lin, vit_base_patch16_384.swag_e2e, vit_base_patch32_224, vit_large_patch16_224, vit_large_patch16_224.swag_lin, vit_large_patch16_512.swag_e2e, vit_huge_patch14_224.swag_lin, vit_huge_patch14_518.swag_e2e
    * **DINO v1**: vit_base_patch16_224.dino, vit_small_patch16_224.dino, vit_small_patch8_224.dino, vit_base_patch8_224.dino
    * **DINO v2**: vit_small_patch14_518.dinov2, vit_base_patch14_518.dinov2, vit_large_patch14_518.dinov2
* **Swin**: swin_tiny_patch4_window7_224, swin_small_patch4_window7_224, swin_base_patch4_window7_224, swin_v2_tiny_patch4_window8_256, swin_v2_small_patch4_window8_256, swin_v2_base_patch4_window8_256
* **RegNet**: regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, regnet_x_16gf, regnet_x_32gf, regnet_y_400mf, regnet_y_800mf, regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, regnet_y_128gf
> **Warning**: The `regnet_y_128gf` model couldn't be tested due to computational limitations.
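The names above appear to follow the timm-style `architecture.tag` pattern, where the optional part after the dot identifies a particular set of pretrained weights (for example `swag_lin` or `dinov2`). A minimal sketch of splitting a name on that convention follows; `split_model_name` is a hypothetical helper written for this example and is not part of `mlx-image`.

```python
from __future__ import annotations


def split_model_name(name: str) -> tuple[str, str | None]:
    """Split 'arch.tag' into (arch, tag); tag is None when there is no dot."""
    arch, sep, tag = name.partition(".")
    return arch, (tag if sep else None)


with_tag = split_model_name("vit_base_patch16_224.swag_lin")
without_tag = split_model_name("resnet18")
```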
## ImageNet-1K Results
Go to [results-imagenet-1k.csv](https://github.com/riccardomusmeci/mlx-image/blob/main/results/results-imagenet-1k.csv) to check every model converted to `mlx-image` and its performance on ImageNet-1K with different settings.
> **TL;DR** performance is comparable to the original models from PyTorch implementations.
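For readers unfamiliar with how ImageNet-1K performance is usually scored, here is a small, framework-free sketch of top-k accuracy. It is an illustration only, not the code in `validation.py`, and whether the CSV reports exactly these metrics is an assumption; `topk_accuracy` and the toy scores are made up for the example.

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


# Toy scores for four samples over three classes (hypothetical values).
scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.1, 0.3],
]
labels = [1, 1, 0, 0]

top1 = topk_accuracy(scores, labels, k=1)
top2 = topk_accuracy(scores, labels, k=2)
```

Comparing these numbers between the original PyTorch model and its converted counterpart on the same validation images is one way to check that a conversion preserved behaviour.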
## Similarity to PyTorch and other familiar tools
`mlx-image` tries to be as close as possible to PyTorch:
- `DataLoader` -> you can define your own `collate_fn` and also use `num_workers` to speed up data loading
- `Dataset` -> `mlx-image` already supports `LabelFolderDataset` (the good old PyTorch `ImageFolder`) and `FolderDataset` (a generic folder with images in it)
- `ModelCheckpoint` -> keeps track of the best model and saves it to disk (similar to PyTorch Lightning). It can also suggest early stopping
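The checkpointing idea in the last bullet can be sketched without any framework: remember the best metric seen so far, report when a new best should be saved, and signal a stop after a number of epochs without improvement. The class below is a hypothetical stand-in written for this example; its name, arguments, and behaviour are assumptions and do not reflect the actual `ModelCheckpoint` signature in `mlx-image`.

```python
class BestMetricTracker:
    """Tracks the best metric value and counts epochs without improvement."""

    def __init__(self, patience: int = 3, mode: str = "max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.mode = mode
        self.best = None
        self.stale_epochs = 0

    def step(self, value: float) -> bool:
        """Record one epoch's metric; return True if it is a new best."""
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best)
            or (self.mode == "min" and value < self.best)
        )
        if improved:
            self.best = value
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return improved

    @property
    def should_stop(self) -> bool:
        """True once `patience` consecutive epochs brought no improvement."""
        return self.stale_epochs >= self.patience


tracker = BestMetricTracker(patience=2, mode="max")
history = []
for accuracy in [0.60, 0.65, 0.64, 0.63, 0.70]:
    is_best = tracker.step(accuracy)  # a real loop would save weights when True
    history.append((accuracy, is_best, tracker.should_stop))
    if tracker.should_stop:
        break
```

With `patience=2`, the loop stops after the second epoch in a row that fails to beat 0.65, so the final value in the list is never reached.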
## Training
Training is similar to PyTorch. Here's an example of how to train a model:
```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18") # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, targets))
    return loss

# build the loss-and-gradient function once, outside the training loop
train_step_fn = nn.value_and_grad(model, train_step)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        # arguments are forwarded to train_step as-is, so the model is passed explicitly
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        # force evaluation of the lazily computed parameter and optimizer updates
        mx.eval(model.state, optimizer.state)
```
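The `class_map` argument in the example maps each integer label to one folder name or to a list of folder names, which suggests that several folders can share a label. Assuming that reading is right, the lookup from folder to label could be built as below; `folder_to_label` is a hypothetical helper for illustration, not a function exported by `mlx-image`.

```python
from __future__ import annotations


def folder_to_label(class_map: dict[int, str | list[str]]) -> dict[str, int]:
    """Invert a label -> folder(s) mapping into a folder -> label lookup."""
    lookup = {}
    for label, folders in class_map.items():
        # Accept a single folder name as well as a list of names.
        if isinstance(folders, str):
            folders = [folders]
        for folder in folders:
            if folder in lookup:
                raise ValueError(f"folder {folder!r} is mapped to two labels")
            lookup[folder] = label
    return lookup


lookup = folder_to_label({0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]})
```

Here images under `class_2` and `class_3` would both receive label 2, which is handy for merging fine-grained folders into a coarser class.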
## **Validation**
The `validation.py` script is run every time a `.pth` model is converted to MLX; it checks whether the converted model performs similarly to the original one on ImageNet-1K.
I use the configuration file `config/validation.yaml` to set the parameters for the validation script.
You can download the ImageNet-1K validation set from the `mlx-vision` space on HuggingFace at this [link](https://huggingface.co/datasets/mlx-vision/imagenet-1k).
## **Contributing**
This is a work in progress, so any help is appreciated.
I am working on it in my spare time, so I can't guarantee frequent updates.
If you love coding and want to contribute, follow the instructions in [CONTRIBUTING.md](CONTRIBUTING.md).
## Additional Resources
* [mlx-vision community](https://huggingface.co/mlx-vision)
* [HuggingFace doc](https://huggingface.co/docs/hub/main/en/mlx-image)
## **To-Dos**
[ ] inference script (similar to train/validation)
[ ] DenseNet
[ ] MobileNet
## Contact
If you have any questions, please email `riccardomusmeci92@gmail.com`.