maskinversion-torch

- Name: maskinversion-torch
- Version: 1.1
- Home page: https://github.com/WalBouss/MaskInversion
- Summary: MaskInversion
- Upload time: 2024-10-14 17:31:02
- Authors: Walid Bousselham, Sofian Chaybouti, Christian Rupprecht, Vittorio Ferrari, Hilde Kuehne
- Requires Python: >=3.7
- License: none recorded
- Keywords: localized embedding, mask, explainability, ViT, vision-language models, CLIP pretrained
- Requirements: none recorded

# MaskInversion
### [MaskInversion: Localized Embeddings via Optimization of Explainability Maps](https://arxiv.org/abs/2407.20034)
_[Walid Bousselham](http://walidbousselham.com/), [Sofian Chaybouti](https://scholar.google.com/citations?user=8tewdk4AAAAJ&hl), [Christian Rupprecht](https://chrirupp.github.io/), [Vittorio Ferrari](https://sites.google.com/view/vittoferrari), [Hilde Kuehne](https://hildekuehne.github.io/)_

MaskInversion learns a localized embedding: a feature vector that captures the characteristics of the object that a query mask selects in an image. The embedding should represent not only the object’s intrinsic properties but also the broader context of the whole image.

To achieve this, we build on the representations of foundation models such as CLIP. Our approach learns a token that captures the foundation model’s feature representation of the image region specified by the mask. Only this token is optimized; the foundation model itself remains fixed throughout the process.
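
To make the idea concrete, here is a minimal pure-Python sketch of the optimization loop, under strong simplifying assumptions. The "model" is just a fixed list of patch features, the relevance map `sigmoid(<patch, token>)` is a stand-in for a real explainability map (the library relies on LeGrad for that), and the binary cross-entropy objective and the name `toy_mask_inversion` are illustrative choices, not the library's implementation. It only shows the structure: the features stay fixed and a single token is updated until its map agrees with the mask.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def toy_mask_inversion(patch_features, mask, iterations=50, lr=0.5):
    """Optimize only `token` so that its toy relevance map matches `mask`.

    `patch_features` plays the role of the frozen model output and is never
    modified. The relevance map is a hypothetical stand-in, not LeGrad.
    """
    dim = len(patch_features[0])
    n = len(mask)
    token = [0.0] * dim
    losses = []
    for _ in range(iterations):
        # Toy relevance of each patch for the current token.
        probs = [sigmoid(dot(f, token)) for f in patch_features]
        # Binary cross-entropy between the relevance map and the mask.
        loss = -sum(
            m * math.log(p) + (1 - m) * math.log(1 - p)
            for m, p in zip(mask, probs)
        ) / n
        losses.append(loss)
        # Analytic gradient of the loss with respect to the token.
        grad = [
            sum((p - m) * f[d] for p, m, f in zip(probs, mask, patch_features)) / n
            for d in range(dim)
        ]
        # Plain gradient descent step on the token only.
        token = [t - lr * g for t, g in zip(token, grad)]
    return token, losses


# Four patches with 2-d features; the mask selects the first two patches.
patch_features = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
mask = [1, 1, 0, 0]

token, losses = toy_mask_inversion(patch_features, mask)
relevance = [sigmoid(dot(f, token)) for f in patch_features]
print("first loss:", losses[0], "last loss:", losses[-1])
print("relevance per patch:", relevance)
```

In this toy setting the learned token ends up scoring the masked patches above 0.5 and the others below it, while the patch features themselves are left untouched.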

This repository provides a wrapper around the [OpenCLIP](https://github.com/mlfoundations/open_clip) library that equips vision-language (VL) models with the ability to compute "localized embeddings" via the MaskInversion process.

<div align="center">
<img src="./assets/maskinversion_teaser.png" width="100%"/>
</div>

## :hammer: Installation
The `maskinversion` library can be installed with pip:
```bash
$ pip install maskinversion_torch
```
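
Note that the distribution is named `maskinversion_torch` on pip, while the snippets in this README import it as `maskinversion`. A small check like the following can confirm that the module is importable in the current environment. The helper name `is_importable` is hypothetical; it relies only on the standard library.

```python
import importlib.util


def is_importable(module_name):
    """Return True if a top-level module can be found without importing it."""
    return importlib.util.find_spec(module_name) is not None


if is_importable("maskinversion"):
    print("maskinversion is available")
else:
    print("maskinversion not found; try: pip install maskinversion_torch")
```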

## :firecracker: Usage

### Available models
MaskInversion uses the [LeGrad](https://github.com/WalBouss/LeGrad) library to compute the explainability maps, so it supports all the models from that library.
To see which pretrained models are available, use the following code snippet:
```python
import maskinversion
maskinversion.available_models()
```

### Example
Given an image and several masks covering different objects, you can run `python example_maskinversion.py` or use the following code snippet to compute the **localized embedding** for each mask:

**Note**: the wrapper does not affect the original model, hence all the functionalities of OpenCLIP models can be used seamlessly.
```python
import requests
from PIL import Image
import torch
import torch.nn.functional as F
from open_clip import get_tokenizer, create_model_and_transforms
from maskinversion import (
 MaskInversion, MaskInversionImagePreprocess, MaskInversionMaskPreprocess, overlay_image_mask)

# ------ init model ------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pretrained = 'openai'
model_name = 'ViT-B-16'
model, _, preprocess = create_model_and_transforms(model_name=model_name, pretrained=pretrained, device=device)
tokenizer = get_tokenizer(model_name=model_name)

# ------ use MaskInversion wrapper ------
model = MaskInversion(model)
preprocess = MaskInversionImagePreprocess(preprocess, image_size=448)
mask_preprocess = MaskInversionMaskPreprocess()

# ------ init inputs ------
# === image ===
# `?raw=true` makes GitHub serve the image file instead of the HTML viewer page
url_img = "https://github.com/WalBouss/MaskInversion/blob/main/assests/dress_and_flower.png?raw=true"
img_pil = Image.open(requests.get(url_img, stream=True).raw).convert('RGB')
image = preprocess(img_pil).unsqueeze(0).to(device)

# === masks ===
masks_urls = ['https://github.com/WalBouss/MaskInversion/blob/main/assests/dress_mask.png?raw=true',
              'https://github.com/WalBouss/MaskInversion/blob/main/assests/flower_mask.png?raw=true']
masks = [Image.open(requests.get(url, stream=True).raw) for url in masks_urls]
masks = torch.stack([mask_preprocess(msk) for msk in masks]).to(device)

# === text ===
prompts = ['a photo of a dress', 'a photo of a flower']
text_input = tokenizer(prompts).to(device)
text_embeddings = model.encode_text(text_input)  # [num_prompts, dim]
text_embeddings = F.normalize(text_embeddings, dim=-1)

# ------ Compute localized embedding for each mask ------
localized_embeddings = model.compute_maskinversion(image=image, masks_target=masks, verbose=True)  # [num_masks, dim]
localized_embeddings = F.normalize(localized_embeddings, dim=-1)

# ------ Region-Text matching ------
mask_text_matching = localized_embeddings @ text_embeddings.transpose(-1, -2) # [num_masks, num_prompt]
for i, mask in enumerate(masks.cpu().numpy()):
    # Row i holds the similarity of mask i to every prompt
    print(f'mask {i}: {mask_text_matching[i].softmax(dim=-1)}')
    matched_prompt_idx = mask_text_matching[i].argmax().item()

    # ___ (Optional): Visualize overlay of the image + mask ___
    overlay_image_mask(image=img_pil, mask=mask, show=True, title=prompts[matched_prompt_idx])
```
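
The region-text matching step at the end of the snippet above boils down to cosine similarities between normalized embeddings, followed by a softmax and an argmax per mask. The following dependency-free sketch reproduces that logic on toy vectors. The function names and the `scale` parameter are hypothetical and only serve the illustration.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def match_regions_to_prompts(region_embeddings, text_embeddings, scale=1.0):
    """For each region, return (best prompt index, probabilities over prompts)."""
    regions = [l2_normalize(r) for r in region_embeddings]
    texts = [l2_normalize(t) for t in text_embeddings]
    results = []
    for r in regions:
        # Cosine similarity of this region to every prompt.
        sims = [scale * sum(a * b for a, b in zip(r, t)) for t in texts]
        best = max(range(len(sims)), key=sims.__getitem__)
        results.append((best, softmax(sims)))
    return results


# Toy embeddings: region 0 points towards prompt 0, region 1 towards prompt 1.
regions = [[2.0, 0.1, 0.0], [0.0, 0.3, 1.5]]
texts = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
matches = match_regions_to_prompts(regions, texts)
for i, (best, probs) in enumerate(matches):
    print(f"region {i} -> prompt {best}, probabilities {probs}")
```

Because cosine similarities lie in [-1, 1], a softmax over them is fairly flat. CLIP-style models usually multiply the similarities by a temperature before the softmax, which is what `scale` stands for here; the argmax is unaffected by any positive scale.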
 
### Visualize the final Explainability Maps
To visualize the explainability map after the MaskInversion process you can run `python example_viz_expl_maps.py` or use the following code snippet:
```python
import requests
from PIL import Image
import torch
import torch.nn.functional as F
from open_clip import get_tokenizer, create_model_and_transforms
from maskinversion import (
 MaskInversion, MaskInversionImagePreprocess, MaskInversionMaskPreprocess, overlay_image_expl_map)

# ------ init model ------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pretrained = 'openai'
model_name = 'ViT-B-16'
model, _, preprocess = create_model_and_transforms(model_name=model_name, pretrained=pretrained, device=device)
tokenizer = get_tokenizer(model_name=model_name)

# ------ use MaskInversion wrapper ------
model = MaskInversion(model)
preprocess = MaskInversionImagePreprocess(preprocess, image_size=448)
mask_preprocess = MaskInversionMaskPreprocess()

# ------ init inputs ------
# === image ===
# `?raw=true` makes GitHub serve the image file instead of the HTML viewer page
url_img = "https://github.com/WalBouss/MaskInversion/blob/main/assests/cats-and-dogs.jpg?raw=true"
img_pil = Image.open(requests.get(url_img, stream=True).raw).convert('RGB')
image = preprocess(img_pil).unsqueeze(0).to(device)

# === masks ===
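# NOTE: these are the dress/flower masks from the previous example;
# substitute masks that match the objects in your image.
masks_urls = ['https://github.com/WalBouss/MaskInversion/blob/main/assests/dress_mask.png?raw=true',
              'https://github.com/WalBouss/MaskInversion/blob/main/assests/flower_mask.png?raw=true']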
masks = [Image.open(requests.get(url, stream=True).raw) for url in masks_urls]
masks = torch.stack([mask_preprocess(msk) for msk in masks]).to(device)

# === text ===
prompts = ['a photo of a cat', 'a photo of a dog']
text_input = tokenizer(prompts).to(device)
text_embeddings = model.encode_text(text_input)  # [num_prompts, dim]
text_embeddings = F.normalize(text_embeddings, dim=-1)

# ------ Compute localized embedding for each mask ------
localized_embeddings, expl_map = model.compute_maskinversion(
 image=image, masks_target=masks, verbose=True, return_expl_map=True)  # [num_masks, dim]
localized_embeddings = F.normalize(localized_embeddings, dim=-1)

# ------ Region-Text matching ------
mask_text_matching = localized_embeddings @ text_embeddings.transpose(-1, -2) # [num_masks, num_prompt]
for i, mask in enumerate(masks.cpu().numpy()):
    # Row i holds the similarity of mask i to every prompt
    print(f'mask {i}: {mask_text_matching[i].softmax(dim=-1)}')
    matched_prompt_idx = mask_text_matching[i].argmax().item()

    # ___ (Optional): Visualize overlay of the image + heatmap ___
    overlay_image_expl_map(image=img_pil, expl_map=expl_map[0, i], title=prompts[matched_prompt_idx], show=True)
```
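
`overlay_image_expl_map` takes care of the rendering. For intuition, a heatmap overlay usually amounts to rescaling the map to [0, 1] and alpha-blending it with the image. The sketch below shows that idea on nested lists for a grayscale image. The helpers `min_max_normalize` and `blend` are hypothetical and not part of the library.

```python
def min_max_normalize(heatmap):
    """Rescale a 2-D list of floats to the [0, 1] range."""
    flat = [v for row in heatmap for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        # A constant map carries no contrast; return all zeros.
        return [[0.0 for _ in row] for row in heatmap]
    return [[(v - lo) / span for v in row] for row in heatmap]


def blend(gray_image, heatmap, alpha=0.5):
    """Per-pixel blend: (1 - alpha) * image + alpha * normalized heatmap.

    `gray_image` is expected to hold values in [0, 1] and to have the same
    height and width as `heatmap`.
    """
    norm = min_max_normalize(heatmap)
    return [
        [(1 - alpha) * px + alpha * h for px, h in zip(img_row, heat_row)]
        for img_row, heat_row in zip(gray_image, norm)
    ]


heatmap = [[1.0, 2.0], [3.0, 5.0]]
gray_image = [[0.0, 1.0], [1.0, 0.0]]
normalized = min_max_normalize(heatmap)
overlay = blend(gray_image, heatmap, alpha=0.5)
print(normalized)
print(overlay)
```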
### MaskInversion Hyperparameters
You can manually set the hyperparameters used for the MaskInversion process,
_e.g._ the number of `iterations`, the learning rate (`lr`), the optimizer (`optimizer`), the weight decay (`wd`), or the coefficient `alpha` of the regularization loss.
```python
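import torch
from maskinversion import MaskInversion

# `model` is the OpenCLIP model returned by create_model_and_transforms (see the example above)
iterations = 10  # number of optimization steps
lr = 0.5         # learning rate
alpha = 0.       # coefficient of the regularization loss
wd = 0.          # weight decay
optimizer = torch.optim.AdamW
model = MaskInversion(model=model, iterations=iterations, lr=lr, alpha=alpha, wd=wd, optimizer=optimizer)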
```
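
As a rough illustration of what a coefficient like `alpha` does, the sketch below combines a mask-matching loss with a weighted regularization term. The form of the regularizer (one minus the cosine similarity between the learned token and a reference embedding) and the function names are assumptions made for this example; the README does not specify the regularization loss used by the library.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def total_loss(mask_loss, token, reference, alpha=0.0):
    """Mask-matching loss plus an alpha-weighted regularization term.

    The regularizer here is a hypothetical choice for illustration only.
    """
    regularization = 1.0 - cosine_similarity(token, reference)
    return mask_loss + alpha * regularization


token = [1.0, 0.0]
reference = [0.0, 1.0]  # orthogonal to the token, so the regularizer equals 1
print(total_loss(0.3, token, reference, alpha=0.0))  # regularizer disabled
print(total_loss(0.3, token, reference, alpha=0.5))  # regularizer weighted by 0.5
```

With `alpha = 0.`, as in the snippet above, the regularization term contributes nothing and only the mask-matching objective drives the optimization.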

# :star: Acknowledgement
This code is built as a wrapper around the [OpenCLIP](https://github.com/mlfoundations/open_clip) library from [LAION](https://laion.ai/) and the [LeGrad](https://github.com/WalBouss/LeGrad) library. Visit their repositories for more vision-language models.
This project also takes inspiration from [AlphaCLIP](https://github.com/SunzeY/AlphaCLIP) and the [timm library](https://github.com/huggingface/pytorch-image-models); please visit their repositories as well.

# :books: Citation
If you find this repository useful, please consider citing our work :pencil: and giving the repository a star :star2: :
```
@article{bousselham2024maskinversion,
  title={MaskInversion: Localized Embeddings via Optimization of Explainability Maps},
  author={Bousselham, Walid and Chaybouti, Sofian and Rupprecht, Christian and Ferrari, Vittorio and Kuehne, Hilde},
  journal={arXiv preprint arXiv:2407.20034},
  year={2024}
}
```
