# Flax Image Models
• <strong>[Introduction](#introduction)</strong><br>
• <strong>[Installation](#installation)</strong><br>
• <strong>[Usage](#usage)</strong><br>
• <strong>[Available Architectures](#available-architectures)</strong><br>
• <strong>[Contributing](#contributing)</strong><br>
• <strong>[Acknowledgements](#acknowledgements)</strong><br>
## Introduction
flaim is a library of state-of-the-art pre-trained vision models, plus common deep learning modules in computer vision, for Flax.
It exposes a host of diverse image models through a straightforward interface with an emphasis on simplicity, leanness, and readability,
and offers lower-level modules for designing custom architectures.
## Installation
flaim can be installed with ```pip install flaim```. Note that pip installs the CPU version of JAX; to run your programs on a GPU or TPU, you must [install JAX manually](https://github.com/google/jax#installation).
## Usage
```flaim.get_model``` is the central function of flaim and manages model retrieval. It accepts a handful
of arguments:
* ```model_name``` (```str```): The name of the model. If it is not recognized, an exception is raised.
* ```pretrained``` (```bool```): Determines if pre-trained parameters are to be returned in lieu of randomly-initialized ones.
* ```n_classes``` (```int```): The number of output classes. Its value falls into one of three groups:
  * 0: The model outputs the raw final feature maps. For instance, a ResNet is composed of a stem and four stages, followed
  by a head consisting of global average pooling and a fully-connected layer for generating predictions. When ```n_classes = 0```, the output of
  the fourth stage is returned, and the head is discarded.
  * -1: Every part of the head, except for the linear layer, is applied and the output returned. In the ResNet example, the output of
  the pooling layer is returned.
  * Positive integers: ```n_classes``` is interpreted as the desired number of output categories.
* ```jit``` (```bool```): Whether to JIT the model's initialization function. The benefit of JITting the initialization function
is that no actual forward pass with real data is performed, unlike the default configuration. On the other hand, JIT compilation
is generally a lengthy process.
* ```prng``` (```T.Optional[jax.random.KeyArray]```): PRNG key used for initializing the model. When ```None```,
a PRNG key with a seed of 0 is created. If ```pretrained``` is ```True``` and ```n_classes``` is 0 or -1, this argument has no effect
on the returned parameters.
* ```norm_stats``` (```bool```): Whether to also return normalization statistics used to normalize the input data when the model was trained. The statistics are returned as a dictionary, with key 'mean' containing the means and key 'std' the standard deviations for each channel.
The snippet below constructs a ResNet-50 with 10 output classes.
```python
import flaim
model, vars, norm_stats = flaim.get_model(
    model_name='resnet50',
    pretrained=True,
    n_classes=10,
    jit=True,
    prng=None,
    norm_stats=True,
    )
```
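To make the three ```n_classes``` modes concrete, here is a minimal sketch of how the output shape changes for a ResNet-50-like network on 224 x 224 inputs. The ```apply_head``` function, the zero-valued weights, and the channels-last layout are illustrative assumptions, not flaim's actual implementation.

```python
from jax import numpy as jnp


def apply_head(features, weight, bias, n_classes):
    """Hypothetical head mirroring the three n_classes modes."""
    if n_classes == 0:
        # Head discarded: the raw final feature maps are returned.
        return features
    # Global average pooling over the spatial axes (NHWC layout assumed).
    pooled = features.mean(axis=(1, 2))
    if n_classes == -1:
        # Every part of the head except the linear layer is applied.
        return pooled
    # Positive n_classes: pooled features pass through the linear layer.
    return pooled @ weight + bias


# Stand-in for the fourth-stage output of a ResNet-50 on two 224 x 224 images.
features = jnp.ones((2, 7, 7, 2048))
weight = jnp.zeros((2048, 10))
bias = jnp.zeros((10,))

print(apply_head(features, weight, bias, 0).shape)   # (2, 7, 7, 2048)
print(apply_head(features, weight, bias, -1).shape)  # (2, 2048)
print(apply_head(features, weight, bias, 10).shape)  # (2, 10)
```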
Performing a forward pass with flaim is similar to any other Flax module. However, networks
that behave differently during training versus inference, e.g., due to batch normalization,
receive a ```training``` argument indicating whether the model should be in training mode or not.
```python
from jax import numpy as jnp
# input should be normalized using norm_stats beforehand
input = jnp.ones((2, 224, 224, 3))
# Training
output, batch_stats = model.apply(
    vars,
    input,
    mutable=['batch_stats'],
    training=True,
    )
# Inference
output = model.apply(
    vars,
    input,
    training=False,
    )
```
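The normalization step mentioned in the comment above could look like the following sketch. It assumes channels-last images scaled to [0, 1] and a dictionary in the 'mean'/'std' format described earlier; the statistics shown are the commonly used ImageNet values and serve only as placeholders, and ```normalize``` is a hypothetical helper rather than part of flaim.

```python
from jax import numpy as jnp

# Placeholder statistics in the 'mean'/'std' format; in practice, use the
# dictionary returned by flaim.get_model with norm_stats=True.
norm_stats = {
    'mean': jnp.array([0.485, 0.456, 0.406]),
    'std': jnp.array([0.229, 0.224, 0.225]),
}


def normalize(images, norm_stats):
    """Normalizes a batch of NHWC images channel-wise."""
    mean = jnp.asarray(norm_stats['mean'])
    std = jnp.asarray(norm_stats['std'])
    # mean and std have shape (3,) and broadcast over the trailing channel axis.
    return (images - mean) / std


images = jnp.full((2, 224, 224, 3), 0.5)
normalized = normalize(images, norm_stats)
print(normalized.shape)
```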
Finally, intermediate activations can be captured by passing the string ```intermediates``` to ```mutable```.
```python
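# apply returns the output and a single dictionary of updated collections,
# which here holds both 'batch_stats' and 'intermediates'
output, intermediates = model.apply(
    vars,
    input,
    mutable=['batch_stats', 'intermediates'],
    training=True,
    )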
```
If the model architecture is hierarchical, ```intermediates```'s items are the output of each stage and can be looked up through
```intermediates['intermediates']['stage_ind']```, where ```ind``` is the index of the stage, with 0 being reserved for the stem. For isotropic models, the output of every block is returned, accessible via ```intermediates['intermediates']['block_ind']```.
Note that Flax's ```sow``` API, which is used to store the intermediate activations, appends the data to a tuple; that is, if _n_ forward passes are performed, ```intermediates['intermediates']['stage_ind']``` or ```intermediates['intermediates']['block_ind']``` would be tuples of length _n_, with the *i*<sup>th</sup> item corresponding to the *i*<sup>th</sup> forward pass.
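Because each entry is a tuple, it is often convenient to keep only the most recent activation. The sketch below does so on a hand-built stand-in for the returned dictionary; ```latest_activations``` is a hypothetical helper, and the key names ```stage_0``` and ```stage_1``` assume that ```ind``` is substituted with the stage index.

```python
def latest_activations(state):
    """Returns the most recently stored activation under each name."""
    return {
        name: values[-1]
        for name, values in state['intermediates'].items()
    }


# Stand-in for the dictionary produced after two forward passes; strings
# take the place of arrays to keep the example dependency-free.
state = {
    'intermediates': {
        'stage_0': ('stem, pass 1', 'stem, pass 2'),
        'stage_1': ('stage 1, pass 1', 'stage 1, pass 2'),
    },
}

latest = latest_activations(state)
print(latest['stage_0'])
```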
## Available Architectures
All available architectures, accompanied by short descriptions and references, are listed [here](https://github.com/bobmcdear/flaim/blob/main/ARCHITECTURES.md). ```flaim.list_models``` also returns a list of flaim models. Its only argument, ```pattern```, is an optional regex pattern that, if not ```None```, ensures only model names containing this expression are returned, as demonstrated below.
```python
# Every model
print(flaim.list_models())
# ResNeXt-based networks
print(flaim.list_models(r'resnext'))
# ViTs of resolution 224 x 224
print(flaim.list_models(r'vit.+224'))
```
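The "containing this expression" behaviour corresponds to a regex search rather than a full match. A rough sketch of such filtering is given below; ```filter_names``` and the model names in the list are hypothetical and are not taken from flaim's source or registry.

```python
import re


def filter_names(names, pattern=None):
    """Keeps the names that contain a match for pattern, or all if None."""
    if pattern is None:
        return list(names)
    # re.search matches anywhere in the string, unlike re.fullmatch.
    return [name for name in names if re.search(pattern, name)]


# Hypothetical model names used purely for illustration.
names = [
    'resnet50',
    'resnext50_32x4d',
    'vit_base_patch16_224',
    'vit_base_patch16_384',
]

print(filter_names(names, r'resnext'))
print(filter_names(names, r'vit.+224'))
```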
## Contributing
Code contributions are currently not accepted; however, there are three alternatives for those interested in contributing to flaim:
• Bugs: If you discover any bugs, please open an issue and include your system information, a description of the problem, and a reproducible example.<br>
• Feature requests: flaim is under active development and many more models will be released in the near future. If there are particular architectures or modules you'd like to see added, please request them by opening an issue.<br>
• Questions: If you have questions regarding a model, the code, or anything else, please ask them by opening a discussion thread. Most likely somebody else has the same question, and asking it would help them too.<br>
## Acknowledgements
Many thanks to Ross Wightman for the amazing timm package, which was an inspiration for flaim and has been an indispensable guide during development. Additionally, the pre-trained parameters are stored on Hugging Face Hub; big thanks to Hugging Face for this gratis service.
References for ```flaim.models``` can be found [here](https://github.com/bobmcdear/flaim/blob/main/ARCHITECTURES.md), and ones for ```flaim.layers``` are in the source code.