vit-keras

- Name: vit-keras
- Version: 0.1.1
- Home page: https://github.com/faustomorales/vit-keras
- Summary: Keras implementation of ViT (Vision Transformer)
- Author: Fausto Morales
- Requires Python: >=3.7,<3.12
- License: Apache-2.0
- Upload time: 2023-04-24 14:35:33
# vit-keras
This is a Keras implementation of the models described in [An Image is Worth 16x16 Words:
Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf). It is based on an earlier implementation from [tuvovan](https://github.com/tuvovan/Vision_Transformer_Keras), modified to match the Flax implementation in the [official repository](https://github.com/google-research/vision_transformer).

The weights here are ported from those provided in the official repository. See `utils.load_weights_numpy` to see how this is done (it's not pretty, but it does the job).
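
For a rough idea of the mechanics, here is a minimal, hypothetical sketch of the general pattern; the real `utils.load_weights_numpy` does considerably more Flax-to-Keras name mapping and reshaping than this:

```python
import numpy as np

def load_numpy_weights_sketch(model, npz_path):
    # Hypothetical sketch only: assumes the .npz archive keys match layer
    # names and each layer holds a single weight array. The real
    # utils.load_weights_numpy handles ViT-specific mappings and reshapes.
    params = np.load(npz_path)
    for layer in model.layers:
        if layer.name in params:
            layer.set_weights([params[layer.name]])
```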

## Usage
Install this package using `pip install vit-keras`.

You can use the model out-of-the-box with ImageNet 2012 classes using
something like the following. The weights will be downloaded automatically.

```python
from vit_keras import vit, utils

image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)
url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Granny_smith_and_cross_section.jpg'
image = utils.read(url, image_size)
X = vit.preprocess_inputs(image).reshape(1, image_size, image_size, 3)
y = model.predict(X)
print(classes[y[0].argmax()]) # Granny smith
```
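
For a quick sanity check beyond the single top class, you can list the five highest-scoring predictions with plain NumPy (nothing vit-keras-specific):

```python
# Print the five highest-scoring ImageNet classes with their scores.
for i in y[0].argsort()[-5:][::-1]:
    print(f'{classes[i]}: {y[0][i]:.4f}')
```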

You can fine-tune using a model loaded as follows.

```python
image_size = 224
model = vit.vit_l32(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=False,
    classes=200
)
# Train this model on your data as desired.
```
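
Training then follows the usual Keras workflow. The data, optimizer, and loss below are placeholders rather than recommendations; adjust them for your task (note that the model above uses a sigmoid head):

```python
import tensorflow as tf

# Placeholder data: replace with your own dataset of 200 classes
# to match the classes=200 head above.
x_train = tf.random.uniform((8, image_size, image_size, 3))
y_train = tf.random.uniform((8,), maxval=200, dtype=tf.int32)

# Illustrative optimizer and loss choices, not prescriptions.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)
model.fit(x_train, y_train, epochs=1, batch_size=4)
```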

## Visualizing Attention Maps
There's some functionality for plotting attention maps for a given image and model. See the example below. I'm not sure I'm doing this correctly (the official repository didn't have example code). Feedback and corrections welcome!

```python
import numpy as np
import matplotlib.pyplot as plt
from vit_keras import vit, utils, visualize

# Load a model
image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)

# Get an image and compute the attention map
url = 'https://upload.wikimedia.org/wikipedia/commons/b/bc/Free%21_%283987584939%29.jpg'
image = utils.read(url, image_size)
attention_map = visualize.attention_map(model=model, image=image)
print('Prediction:', classes[
    model.predict(vit.preprocess_inputs(image)[np.newaxis])[0].argmax()]
)  # Prediction: Eskimo dog, husky

# Plot results
fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(image)
_ = ax2.imshow(attention_map)
```
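
To keep the comparison around, saving the figure is a one-liner (standard matplotlib; the filename is arbitrary):

```python
# Save the side-by-side original/attention-map figure to disk.
fig.savefig('attention_map_example.png', bbox_inches='tight', dpi=150)
```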

![example of attention map](https://raw.githubusercontent.com/faustomorales/vit-keras/master/docs/attention_map_example.jpg)

