# FocalNet: Focal Modulation Networks for TensorFlow
This repository contains a TensorFlow implementation of the paper *Focal Modulation Networks*. The paper proposes an attention-free architecture, FocalNet, in which self-attention is replaced by a focal modulation module: each token is modulated by context aggregated at several granularities from the regions around it. Focal modulation can improve the performance of various vision tasks, such as image classification, object detection, semantic segmentation and face recognition.
Focal Modulation brings several merits:
* **Translation-Invariance**: It is performed for each target token with the context centered around it.
* **Explicit input-dependency**: The *modulator* is computed by aggregating the short- and long-range context from the input and is then applied to the target token.
* **Spatial- and channel-specific**: It first aggregates the context spatial-wise and then channel-wise, followed by an element-wise modulation.
* **Decoupled feature granularity**: The query token preserves the individual information at the finest level, while coarser context is extracted around it. The two are decoupled but connected through the modulation operation.
* **Easy to implement**: Both context aggregation and interaction can be implemented in a simple, lightweight way. It does not need softmax, multiple attention heads, feature map rolling or unfolding, etc.
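To make the mechanism above concrete, here is a minimal NumPy sketch of one focal modulation step on a single (H, W, C) feature map. It is written from the paper's description, not from this package's Keras layers. A linear projection produces the query, the initial context and one gate per focal level plus one for the global context; depthwise convolutions with growing kernels build the context hierarchy; the gated sum is projected into the modulator, which multiplies the query element-wise. The `params` dictionary, the helper names and the kernel sizes 3, 5, ... are assumptions for illustration; biases, normalization and the output projection are omitted.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GeLU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def depthwise_conv2d(x, kernel):
    # "Same"-padded depthwise convolution: x is (H, W, C), kernel is (k, k, C), k odd
    k = kernel.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    h, w, _ = x.shape
    out = np.zeros_like(x)
    for i in range(k):
        for j in range(k):
            out += xp[i:i + h, j:j + w, :] * kernel[i, j, :]
    return out


def focal_modulation(x, params, focal_level):
    # Hypothetical parameter layout: plain matrices, no biases, no output projection
    q = x @ params["w_q"]        # query, (H, W, C)
    ctx = x @ params["w_z"]      # initial context, (H, W, C)
    gates = x @ params["w_g"]    # one gate per level + 1 global gate, (H, W, L + 1)

    ctx_all = np.zeros_like(ctx)
    for level in range(focal_level):
        # Each level sees a larger neighbourhood than the previous one
        ctx = gelu(depthwise_conv2d(ctx, params["dw_kernels"][level]))
        ctx_all += ctx * gates[:, :, level:level + 1]

    # Global context: average of the last level over all spatial positions
    ctx_global = gelu(ctx.mean(axis=(0, 1), keepdims=True))
    ctx_all += ctx_global * gates[:, :, focal_level:]

    # A 1x1 projection gives the modulator, applied element-wise to the query
    modulator = ctx_all @ params["w_h"]
    return q * modulator


rng = np.random.default_rng(0)
H, W, C, L = 8, 8, 16, 2
params = {
    "w_q": rng.normal(scale=0.1, size=(C, C)),
    "w_z": rng.normal(scale=0.1, size=(C, C)),
    "w_g": rng.normal(scale=0.1, size=(C, L + 1)),
    "w_h": rng.normal(scale=0.1, size=(C, C)),
    # Assumed kernel sizes 3, 5, ... (growing by 2 per focal level)
    "dw_kernels": [rng.normal(scale=0.1, size=(3 + 2 * lvl, 3 + 2 * lvl, C)) for lvl in range(L)],
}
x = rng.normal(size=(H, W, C))
y = focal_modulation(x, params, focal_level=L)
print(y.shape)  # (8, 8, 16)
```

No softmax or attention matrix appears anywhere: the interaction between a token and its surroundings is carried entirely by the gated, convolution-based context and one element-wise product.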
<p align="center">
<img src="https://github.com/Shiro-LK/focalnet-tf/blob/main/figures/focalnet-model.png" width=80% height=80%
class="center">
</p>
This repository aims to reproduce the results of the paper using TensorFlow 2.4.1 and to provide a modular, easy-to-use implementation of focal modulation networks. The code is based on the official PyTorch implementation of the paper, available in the [official repository](https://github.com/microsoft/FocalNet). Only the classification part is implemented. The pretrained checkpoints have been converted to TensorFlow.
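The README does not show how the PyTorch checkpoints were converted. As a rough illustration only (this is not the package's conversion script), porting weights from PyTorch to Keras mostly comes down to reordering axes: `torch.nn.Linear` stores its weight as `(out, in)` while a Keras `Dense` kernel is `(in, out)`, and `torch.nn.Conv2d` uses `(out, in, kh, kw)` while a Keras `Conv2D` kernel is `(kh, kw, in, out)`. The helper names below are hypothetical, and the sketch uses NumPy arrays in place of real checkpoints.

```python
import numpy as np


def convert_linear(w_torch):
    # Hypothetical helper: torch Linear (out, in) -> Keras Dense kernel (in, out)
    return w_torch.T


def convert_conv2d(w_torch):
    # Hypothetical helper: torch Conv2d (out, in, kh, kw) -> Keras Conv2D (kh, kw, in, out)
    return np.transpose(w_torch, (2, 3, 1, 0))


def convert_depthwise(w_torch):
    # Hypothetical helper: depthwise torch conv with groups=C, (C, 1, kh, kw)
    # -> Keras DepthwiseConv2D kernel (kh, kw, C, 1)
    return np.transpose(w_torch, (2, 3, 0, 1))


rng = np.random.default_rng(0)

# One output position of a 3x3 convolution, computed in both layouts
w_t = rng.normal(size=(8, 4, 3, 3))          # PyTorch layout
patch_chw = rng.normal(size=(4, 3, 3))       # channels-first input patch
w_tf = convert_conv2d(w_t)                   # Keras layout
out_torch = np.einsum("ocij,cij->o", w_t, patch_chw)
out_tf = np.einsum("ijco,ijc->o", w_tf, patch_chw.transpose(1, 2, 0))
print(np.allclose(out_torch, out_tf))  # True
```

The same check applies to dense layers: `x @ w_torch.T` in PyTorch terms equals `x @ convert_linear(w_torch)` in Keras terms.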
<p align="center">
<img src="https://github.com/Shiro-LK/focalnet-tf/blob/main/figures/modulator.JPG" width=80% height=80%
class="center">
</p>
# Installation
```
pip install focalnet-tf
```
# Example
```
import cv2
import numpy as np
from focalnet import load_focalnet, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, imagenet1k


def preprocess_image(image):
    # Scale to [0, 1], normalize with the ImageNet statistics, and add a batch axis
    image = image.astype(np.float32) / 255.0
    image = (image - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
    return np.expand_dims(image, axis=0)


def center_crop(image, output_shape):
    # Crop an (H, W, C) image around its center to (h_desired, w_desired)
    h, w = image.shape[:2]
    h_desired, w_desired = output_shape
    if h_desired > h or w_desired > w:
        raise ValueError("Output shape must be smaller than or equal to the input shape.")
    h_start = (h - h_desired) // 2
    w_start = (w - w_desired) // 2
    return image[h_start:h_start + h_desired, w_start:w_start + w_desired, :]


# OpenCV loads images as BGR; the model expects RGB
bgr = cv2.imread("tests/dog.jpg")
if bgr is None:
    raise FileNotFoundError("Could not read tests/dog.jpg")
image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
image_crop = center_crop(image, (768, 768))
# cv2.resize takes (width, height); both are 224 here
image_resized = cv2.resize(image_crop, (224, 224))
# Feed the resized 224x224 image, not the 768x768 crop
inputs = preprocess_image(image_resized)

model = load_focalnet(model_name='focalnet_tiny_srf', pretrained=True, return_model=False, act_head="softmax")
output = model.predict(inputs)
class_id = int(np.argmax(output[0]))
print(output[0, class_id])
# focalnet_tiny_srf is an ImageNet-1K checkpoint, so use the 1K label list
print(imagenet1k[class_id])
```
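The example prints only the top-1 class. To inspect a few more candidates, a small helper like the one below (hypothetical, not part of `focalnet-tf`) can rank the softmax scores of one image, i.e. the first row of the `model.predict` output. It assumes the label object can be indexed by integer class id like a list; toy values are used so the snippet runs without downloading the model.

```python
import numpy as np


def top_k_predictions(probs, labels, k=5):
    # Hypothetical helper: the k highest-scoring classes, best first
    top = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in top]


# Toy stand-ins for one row of softmax scores and a label list
probs = np.array([0.05, 0.6, 0.1, 0.25])
labels = ["cat", "dog", "fox", "wolf"]
print(top_k_predictions(probs, labels, k=2))  # [('dog', 0.6), ('wolf', 0.25)]
```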
# Acknowledgement
- Paper: https://arxiv.org/abs/2203.11926 (Jianwei Yang et al.)
- PyTorch implementation: https://github.com/microsoft/FocalNet
# Package information
- Package: `focalnet-tf` 0.0.2.3, uploaded to PyPI on 2023-06-26
- Summary: Re-implementation of FocalNet for TensorFlow 2.X
- Author: Shiro-LK (shirosaki94@gmail.com)
- License: MIT License
- Homepage: https://github.com/Shiro-LK/focalnet-tf