gcvit


Name: gcvit
Version: 1.1.5
home_page: https://github.com/awsaf49/gcvit-tf
Summary: Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf
upload_time: 2023-10-16 08:25:02
author: Awsaf
requires_python: >=3.6
license: MIT
keywords: tensorflow computer_vision image classification transformer
requirements: No requirements were recorded.
            <h1 align="center">
<p><a href='https://arxiv.org/pdf/2206.09959v1.pdf'>GCViT: Global Context Vision Transformer</a></p>
</h1>
<div align=center><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_arch.PNG" width=800></div>
<p align="center">
<a href="https://github.com/awsaf49/gcvit-tf/blob/main/LICENSE.md">
  <img src="https://img.shields.io/badge/License-MIT-yellow.svg">
</a>
<img alt="python" src="https://img.shields.io/badge/python-%3E%3D3.6-blue?logo=python">
<img alt="tensorflow" src="https://img.shields.io/badge/tensorflow-%3E%3D2.4.1-orange?logo=tensorflow">
<div align=center><p>
<a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-yellow.svg"></a>
<a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>
</p></div>
<h2 align="center">
<p>TensorFlow 2.0 Implementation of GCViT</p>
</h2>
</p>
<p align="center">
This library implements <b>GCViT</b> in TensorFlow 2.0, written in the <code>tf.keras.Model</code> style to give it a PyTorch-like flavor.
</p>

## Update
* **15 Jan 2023** : `GCViTLarge` model added, with checkpoint.
* **3 Sept 2022** : An annotated [Kaggle notebook](https://www.kaggle.com/code/awsaf49/gcvit-global-context-vision-transformer) based on this project won the [Kaggle ML Research Spotlight: August 2022](https://www.kaggle.com/discussions/general/349817).
* **19 Aug 2022** : This project was acknowledged by the [official](https://github.com/NVlabs/GCVit) repo [here](https://github.com/NVlabs/GCVit#third-party-implementations-and-resources).

## Paper Implementation & Explanation
I have explained the GCViT paper in a Kaggle notebook **[GCViT: Global Context Vision Transformer](https://www.kaggle.com/code/awsaf49/gcvit-global-context-vision-transformer)**, which also includes a detailed implementation of the model from scratch. The notebook provides a comprehensive explanation of each part of the model, with intuition.

Do check it out, especially if you are interested in learning more about GCViT or implementing it yourself. Note that this notebook has won the **Kaggle ML Research Award 2022.**

## Model
* Architecture:

<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/arch.PNG">

* Local Vs Global Attention:

<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_msa.PNG">

## Result
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/result.PNG" width=900>

The official codebase had an issue that was fixed on 12 August 2022. Here are the results of the ported weights on the **ImageNetV2-Test** data:

| Model        | Acc@1 | Acc@5 | #Params |
|--------------|-------|-------|---------|
| GCViT-XXTiny | 0.663    | 0.873    | 12M     |
| GCViT-XTiny  | 0.685    | 0.885    | 20M     |
| GCViT-Tiny   | 0.708    | 0.899    | 28M     |
| GCViT-Small  | 0.720    | 0.901    | 51M     |
| GCViT-Base   | 0.731    | 0.907    | 90M     |
| GCViT-Large  | 0.734    | 0.913    | 202M    |
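
For reference, Acc@1 and Acc@5 in the table above are top-1 and top-5 accuracy. The sketch below shows how such a metric is typically computed from a batch of class scores. The `top_k_accuracy` helper and the toy arrays are illustrative only and are not part of `gcvit`.

```python
import numpy as np

def top_k_accuracy(scores: np.ndarray, labels: np.ndarray, k: int = 1) -> float:
    """Fraction of samples whose true label is among the k highest scores."""
    # Indices of the k largest scores per row (order within the top k is irrelevant).
    top_k = np.argpartition(scores, -k, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Toy batch: 3 samples, 4 classes (hypothetical scores).
scores = np.array([
    [0.10, 0.60, 0.20, 0.10],  # highest score: class 1
    [0.50, 0.10, 0.30, 0.10],  # two highest: classes 0 and 2
    [0.25, 0.15, 0.10, 0.50],  # two highest: classes 3 and 0
])
labels = np.array([1, 2, 1])

acc1 = top_k_accuracy(scores, labels, k=1)  # only the first sample is a hit
acc2 = top_k_accuracy(scores, labels, k=2)  # the first two samples are hits
print(acc1, acc2)
```

On ImageNet-style evaluation the same function would be called with `k=1` and `k=5` on the model's 1000-class outputs.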

## Installation
```bash
pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf
```
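
To confirm which version ended up installed, something like the following can be run. It is a small sketch using the standard library's `importlib.metadata` (Python 3.8+); `installed_version` is a hypothetical helper name, not part of `gcvit`.

```python
from importlib import metadata

def installed_version(package: str):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None

# Prints a version string if gcvit is installed, otherwise None.
print(installed_version("gcvit"))
```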

## Usage
Load a pretrained model with the following code:
```py
from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)
```

For any input size other than **224x224**, pass `input_shape` and set `resize_query=True`:
```py
from gcvit import GCViTTiny
model = GCViTTiny(input_shape=(512,512,3), pretrain=True, resize_query=True)
```
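
As a rough guide to what a larger `input_shape` means downstream: with the default 224x224 input, the final feature map in this README is 7x7, i.e. a total stride of 32. The sketch below assumes that stride comes from a 4x stem followed by a 2x downsampling between each of four stages, as described in the GCViT paper. `stage_resolutions` is a hypothetical helper, not part of `gcvit`.

```python
def stage_resolutions(input_size: int, stem_stride: int = 4, num_stages: int = 4) -> list:
    """Spatial size of the feature map at each stage for a square input."""
    sizes = []
    size = input_size // stem_stride  # stem reduces resolution by 4x (assumed)
    for _ in range(num_stages):
        sizes.append(size)
        size //= 2  # each following stage halves the resolution (assumed)
    return sizes

print(stage_resolutions(224))  # [56, 28, 14, 7]
print(stage_resolutions(512))  # [128, 64, 32, 16]
```
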
A simple check of the model's prediction:
```py
import tensorflow as tf
from skimage.data import chelsea  # requires scikit-image
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])
```
Prediction:
```py
[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623), 
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297), 
('n02883205', 'bow_tie', 0.00042479983)]
```
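
The `mode='torch'` preprocessing used above scales pixels to [0, 1] and then normalizes each channel with the ImageNet mean and standard deviation. The NumPy sketch below reproduces that arithmetic for illustration only. `torch_style_preprocess` is a hypothetical name; in practice, keep using `preprocess_input` as shown.

```python
import numpy as np

# Per-channel ImageNet statistics (RGB order).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def torch_style_preprocess(image_uint8: np.ndarray) -> np.ndarray:
    """Scale an RGB image from [0, 255] to [0, 1], then standardize per channel."""
    x = image_uint8.astype(np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD

# A tiny all-white image: every pixel maps to (1 - mean) / std per channel.
white = np.full((2, 2, 3), 255, dtype=np.uint8)
out = torch_style_preprocess(white)
print(out[0, 0])
```
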
For feature extraction:
```py
model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)
```
Feature:
```py
(None, 512)
```
For feature map:
```py
model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)
```
Feature map:
```py
(None, 7, 7, 512)
```
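
The feature map and the feature vector above share the same 512 channels. A common way to go from one to the other is global average pooling over the spatial grid, which the sketch below illustrates on a random stand-in array with NumPy. Whether `gcvit` applies extra steps in between (such as normalization) is not stated here, so treat the correspondence as an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-in for the output of model.forward_features(img).
feature_map = rng.standard_normal((1, 7, 7, 512)).astype(np.float32)

# Average over the height and width axes to get one vector per image.
pooled = feature_map.mean(axis=(1, 2))
print(pooled.shape)  # (1, 512)
```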

## Kaggle Models
These pre-trained models can also be loaded from [Kaggle Models](https://www.kaggle.com/models/awsaf49/gcvit-tf). Setting `from_kaggle=True` forces the model to load its weights from Kaggle Models without downloading them, so it can be used in Kaggle without internet access.
```py
from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True, from_kaggle=True)
```
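
One way to keep a notebook portable is to enable `from_kaggle` only when running on Kaggle. The sketch below assumes that Kaggle notebooks set the `KAGGLE_KERNEL_RUN_TYPE` environment variable and that `from_kaggle=False` is accepted by the constructor; neither is documented in this README. `pretrain_kwargs` is a hypothetical helper.

```python
import os

def pretrain_kwargs(environ=os.environ) -> dict:
    """Build keyword arguments for a GCViT constructor based on the runtime."""
    on_kaggle = "KAGGLE_KERNEL_RUN_TYPE" in environ  # assumed Kaggle marker
    return {"pretrain": True, "from_kaggle": on_kaggle}

kwargs = pretrain_kwargs()
print(kwargs)
# Intended use (requires gcvit to be installed):
# from gcvit import GCViTTiny
# model = GCViTTiny(**kwargs)
```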

## Live-Demo
* For a live demo of image classification and Grad-CAM with **ImageNet** weights, click <a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/Try%20on-Gradio-orange"></a>, powered by 🤗 Spaces and Gradio. Here's an example:

<a href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="image/gradio_demo.JPG" height=500></a>

## Example
For a working training example, check out these notebooks on **Google Colab** <a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> & **Kaggle** <a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>.

Here is the Grad-CAM result after training on the Flower Classification dataset:

<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/flower_gradcam.PNG" height=500>



## To Do
- [ ] Segmentation Pipeline
- [x] Support for `Kaggle Models`
- [x] Remove `tensorflow_addons`
- [x] New updated weights have been added.
- [x] Working training example in Colab & Kaggle.
- [x] GradCAM showcase.
- [x] Gradio Demo.
- [x] Build model with `tf.keras.Model`.
- [x] Port weights from official repo.
- [x] Support for `TPU`.

## Acknowledgement
* [GCVit](https://github.com/NVlabs/GCVit) (Official)
* [Swin-Transformer-TF](https://github.com/rishigami/Swin-Transformer-TF)
* [tfgcvit](https://github.com/shkarupa-alex/tfgcvit/tree/develop/tfgcvit)
* [keras_cv_attention_models](https://github.com/leondgarse/keras_cv_attention_model)


## Citation
```bibtex
@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}
```

            

Raw data

{
    "_id": null,
    "home_page": "https://github.com/awsaf49/gcvit-tf",
    "name": "gcvit",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.6",
    "maintainer_email": "",
    "keywords": "tensorflow computer_vision image classification transformer",
    "author": "Awsaf",
    "author_email": "awsaf49@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/77/37/908ed14fb861e2958823ed1916bff950c4f09ae1aafebfd57199cdf72dec/gcvit-1.1.5.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf",
    "version": "1.1.5",
    "project_urls": {
        "Homepage": "https://github.com/awsaf49/gcvit-tf"
    },
    "split_keywords": [
        "tensorflow",
        "computer_vision",
        "image",
        "classification",
        "transformer"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "7003dfc872f14ca58ca344fa06ffb86446723da699670d69f09aef063c2a15cc",
                "md5": "87ac21bfbb2e4859e3b1ddc59d213f13",
                "sha256": "979ba00c392a61817333cf966c4711a72d83818da3da9f6ab9fc788ce3f69fc6"
            },
            "downloads": -1,
            "filename": "gcvit-1.1.5-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "87ac21bfbb2e4859e3b1ddc59d213f13",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.6",
            "size": 21147,
            "upload_time": "2023-10-16T08:25:00",
            "upload_time_iso_8601": "2023-10-16T08:25:00.807467Z",
            "url": "https://files.pythonhosted.org/packages/70/03/dfc872f14ca58ca344fa06ffb86446723da699670d69f09aef063c2a15cc/gcvit-1.1.5-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "7737908ed14fb861e2958823ed1916bff950c4f09ae1aafebfd57199cdf72dec",
                "md5": "3d13b95edc1e7b29c8a5b89ef996edbe",
                "sha256": "afa799b7abc20298986092637bbdf93ca5412f5df0d274ee8e5681ac8e55de35"
            },
            "downloads": -1,
            "filename": "gcvit-1.1.5.tar.gz",
            "has_sig": false,
            "md5_digest": "3d13b95edc1e7b29c8a5b89ef996edbe",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.6",
            "size": 19758,
            "upload_time": "2023-10-16T08:25:02",
            "upload_time_iso_8601": "2023-10-16T08:25:02.759475Z",
            "url": "https://files.pythonhosted.org/packages/77/37/908ed14fb861e2958823ed1916bff950c4f09ae1aafebfd57199cdf72dec/gcvit-1.1.5.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-10-16 08:25:02",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "awsaf49",
    "github_project": "gcvit-tf",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [],
    "lcname": "gcvit"
}
        