# spatial-grouping-attention
## PyTorch Implementation of Spatial Grouping Attention

[CI/CD workflow status](https://github.com/rhoadesScholar/spatial-grouping-attention/actions/workflows/ci-cd.yml)
[Code coverage on Codecov](https://codecov.io/github/rhoadesScholar/spatial-grouping-attention)


Inspired by the spatial grouping layer in Native Segmentation Vision Transformers (https://arxiv.org/abs/2505.16993), implemented in PyTorch with a modified rotary position embedding generalized to N-dimensions and incorporating real-world pixel spacing.
## Installation
### From PyPI
You will first need to install PyTorch separately, as it is required for building one of our dependencies (natten). We recommend installing within a virtual environment, such as `venv` or `mamba`:
```bash
# create a virtual environment
mamba create -n spatial-attention -y python=3.11 pytorch ninja cmake
# activate the virtual environment
mamba activate spatial-attention
# install the package(s)
pip install spatial-grouping-attention
pip install natten==0.17.5 # requires python 3.11
```
### From source
To install the latest development version directly from GitHub, create and activate a virtual environment as described above, then run:
```bash
pip install git+https://github.com/rhoadesScholar/spatial-grouping-attention.git
```
## Usage
The spatial grouping attention mechanism automatically computes the query grid parameters (`q_spacing` and `q_grid_shape`) from the input key grid and the convolution parameters. This makes it easy to use: you only need to specify the input spacing and grid shape, and the module handles the spatial downsampling.
### Basic 2D Dense Attention
```python
import torch
from spatial_grouping_attention import DenseSpatialGroupingAttention
# Create attention module with 7x7 grouping kernel, stride=1 (no downsampling)
attention = DenseSpatialGroupingAttention(
feature_dims=128,
spatial_dims=2,
kernel_size=7, # 7x7 spatial grouping
stride=1, # No spatial downsampling
num_heads=8,
mlp_ratio=4
)
# Input: 32x32 image with 128 features per pixel
batch_size, height, width = 2, 32, 32
x = torch.randn(batch_size, height * width, 128)
# Only specify input (key) grid - query grid computed automatically
input_spacing = (0.5, 0.5) # 0.5 microns per pixel
input_grid_shape = (32, 32) # Input resolution
output = attention(x=x, input_spacing=input_spacing, input_grid_shape=input_grid_shape)
# Auto-computed: q_spacing = (0.5, 0.5) * 1 = (0.5, 0.5)
# Auto-computed: q_grid_shape = (32, 32) (no downsampling with stride=1)
print(f"Output shape: {output['x_out'].shape}") # (2, 1024, 128)
```
### 2D Attention with Spatial Downsampling
```python
# Create attention with spatial downsampling for efficiency
downsampling_attention = DenseSpatialGroupingAttention(
feature_dims=256,
spatial_dims=2,
kernel_size=5, # 5x5 grouping kernel
stride=2, # 2x spatial downsampling
padding=2, # Maintain spatial coverage
num_heads=16
)
# High resolution input: 128x128 image
x_hires = torch.randn(1, 128*128, 256)
input_spacing_hires = (0.1, 0.1) # 0.1 mm per pixel (input)
input_grid_shape_hires = (128, 128) # High-res input grid
output_hires = downsampling_attention(
x=x_hires,
input_spacing=input_spacing_hires,
input_grid_shape=input_grid_shape_hires
)
# Auto-computed: q_spacing = (0.1, 0.1) * 2 = (0.2, 0.2)
# Auto-computed: q_grid_shape = (128+2*2-5)//2+1 = (64, 64)
print(f"Downsampled output: {output_hires['x_out'].shape}") # (1, 4096, 256)
print(f"Compression ratio: {128*128 / (64*64)}x") # 4x fewer points
```
### 3D Sparse Attention (GPU Required)
```python
# 3D sparse attention for volumetric data (requires CUDA + natten)
try:
from spatial_grouping_attention import SparseSpatialGroupingAttention
sparse_3d = SparseSpatialGroupingAttention(
feature_dims=128,
spatial_dims=3,
kernel_size=(3, 5, 5), # Anisotropic: 3x5x5 grouping
stride=(1, 2, 2), # Downsample only in x,y
num_heads=8,
neighborhood_kernel=9 # Local attention window
)
# 3D volume: 16x64x64 voxels
depth, height, width = 16, 64, 64
x_3d = torch.randn(1, depth*height*width, 128).cuda()
# Anisotropic spacing (e.g., confocal microscopy)
input_spacing_3d = (0.5, 0.1, 0.1) # z, y, x spacing in microns
input_grid_shape_3d = (16, 64, 64)
output_3d = sparse_3d(
x=x_3d,
input_spacing=input_spacing_3d,
input_grid_shape=input_grid_shape_3d
)
# Auto-computed: q_spacing = (0.5*1, 0.1*2, 0.1*2) = (0.5, 0.2, 0.2)
# Auto-computed: q_grid_shape = (16, 32, 32) - downsampled in x,y only
print(f"3D sparse output: {output_3d['x_out'].shape}") # (1, 16384, 128)
except ImportError:
print("SparseSpatialGroupingAttention requires CUDA and natten package")
```
### Multi-Scale Processing
```python
# Process same input at multiple scales efficiently
multiscale_attention = DenseSpatialGroupingAttention(
feature_dims=64,
spatial_dims=2,
kernel_size=9,
stride=4, # 4x downsampling for global context
num_heads=4,
iters=3 # Multiple attention iterations
)
# Input image
x_input = torch.randn(1, 64*64, 64)
input_spacing = (1.0, 1.0) # 1 micron per pixel
input_grid_shape = (64, 64)
# Global context via 4x downsampling
global_output = multiscale_attention(
x=x_input,
input_spacing=input_spacing,
input_grid_shape=input_grid_shape
)
# Auto-computed: q_spacing = (4.0, 4.0), q_grid_shape = (16, 16)
print(f"Global context: {global_output['x_out'].shape}") # (1, 256, 64)
# Fine-scale processing with stride=1
fine_attention = DenseSpatialGroupingAttention(
feature_dims=64,
spatial_dims=2,
kernel_size=5,
stride=1, # Full resolution
num_heads=4
)
fine_output = fine_attention(
x=x_input,
input_spacing=input_spacing,
input_grid_shape=input_grid_shape
)
# Auto-computed: q_spacing = (1.0, 1.0), q_grid_shape = (64, 64)
print(f"Fine details: {fine_output['x_out'].shape}") # (1, 4096, 64)
```
### Integration with Neural Networks
```python
class HierarchicalSpatialNet(torch.nn.Module):
"""Multi-scale spatial processing network"""
def __init__(self, input_channels=3, num_classes=10):
super().__init__()
# Input embedding
self.embed = torch.nn.Linear(input_channels, 128)
# Coarse-scale attention (4x downsampling)
self.coarse_attention = DenseSpatialGroupingAttention(
feature_dims=128,
spatial_dims=2,
kernel_size=7,
stride=4, # 4x spatial compression
num_heads=8
)
# Fine-scale attention (2x downsampling)
self.fine_attention = DenseSpatialGroupingAttention(
feature_dims=128,
spatial_dims=2,
kernel_size=5,
stride=2, # 2x spatial compression
num_heads=8
)
# Cross-scale fusion
self.fusion = torch.nn.Linear(256, 128)
self.classifier = torch.nn.Linear(128, num_classes)
def forward(self, images, pixel_spacing=(1.0, 1.0)):
B, C, H, W = images.shape
# Flatten and embed
x = images.permute(0, 2, 3, 1).reshape(B, H*W, C)
x = self.embed(x)
# Multi-scale attention
coarse_out = self.coarse_attention(
x=x,
input_spacing=pixel_spacing,
input_grid_shape=(H, W)
)['x_out'] # (B, H*W/16, 128) - 4x downsampling
fine_out = self.fine_attention(
x=x,
input_spacing=pixel_spacing,
input_grid_shape=(H, W)
)['x_out'] # (B, H*W/4, 128) - 2x downsampling
# Upsample coarse to match fine resolution for fusion
coarse_upsampled = torch.nn.functional.interpolate(
coarse_out.transpose(1, 2).reshape(B, 128, H//4, W//4),
size=(H//2, W//2),
mode='bilinear',
align_corners=False
).reshape(B, 128, -1).transpose(1, 2)
# Fuse multi-scale features
fused = self.fusion(torch.cat([fine_out, coarse_upsampled], dim=-1))
# Global pooling and classification
global_features = fused.mean(dim=1)
return self.classifier(global_features)
# Usage example
net = HierarchicalSpatialNet(input_channels=3, num_classes=1000)
sample_images = torch.randn(4, 3, 128, 128) # ImageNet-style input
pixel_spacing = (0.1, 0.1) # 0.1 mm per pixel
predictions = net(sample_images, pixel_spacing)
print(f"Predictions: {predictions.shape}") # (4, 1000)
```
### Key Principles
1. **Automatic Grid Calculation**: You only specify the input grid (`input_spacing`, `input_grid_shape`) - the query grid is computed per spatial dimension as:
- `q_spacing = input_spacing * stride`
- `q_grid_shape = (input_grid_shape + 2*padding - kernel_size) // stride + 1`
2. **Spatial Grouping**: The `kernel_size` determines how many neighboring points are grouped together for attention computation.
3. **Multi-Scale Processing**: Use different `stride` values to process the same input at multiple spatial scales efficiently.
4. **Memory Efficiency**: Larger strides reduce the number of query points, making attention computation more efficient for large inputs.
## Contributing
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Make your changes
4. Run the test suite (`make test`)
5. Commit your changes (`git commit -m 'Add some amazing feature'`)
6. Push to the branch (`git push origin feature/amazing-feature`)
7. Open a Pull Request
## License
BSD 3-Clause License. See [LICENSE](LICENSE) for details.
## Citation
If you use this software in your research, please cite it using the information in [CITATION.cff](CITATION.cff).
Raw data
{
"_id": null,
"home_page": null,
"name": "spatial-grouping-attention",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.10",
"maintainer_email": "Jeff Rhoades <rhoadesj@hhmi.org>",
"keywords": "attention, rotary spatial embedding, segmentation, spatial grouping, transformer",
"author": null,
"author_email": "Jeff Rhoades <rhoadesj@hhmi.org>",
"download_url": "https://files.pythonhosted.org/packages/0a/15/c61ad99b27394642064587c2dc49627a54978997f6a41cb454d46810acc5/spatial_grouping_attention-2025.8.21.1739.tar.gz",
"platform": null,
"description": "# spatial-grouping-attention\n\n## PyTorch Implementation of Spatial Grouping Attention\n\n\n[](https://github.com/rhoadesScholar/spatial-grouping-attention/actions/workflows/ci-cd.yml)\n[](https://codecov.io/github/rhoadesScholar/spatial-grouping-attention)\n\n\n\nInspired by the spatial grouping layer in Native Segmentation Vision Transformers (https://arxiv.org/abs/2505.16993), implemented in PyTorch with a modified rotary position embedding generalized to N-dimensions and incorporating real-world pixel spacing.\n\n## Installation\n\n### From PyPI\n\nYou will first need to install PyTorch separately, as it is required for building one of our dependencies (natten). We recommend installing within a virtual environment, such as `venv` or `mamba`:\n\n```bash\n# create a virtual environment\nmamba create -n spatial-attention -y python=3.11 pytorch ninja cmake\n\n# activate the virtual environment\nmamba activate spatial-attention\n\n# install the package(s)\npip install spatial-grouping-attention\npip install natten==0.17.5 # requires python 3.11\n```\n\n### From source\n\nTo install the latest development version directly from GitHub, follow the creation and activation of a virtual environment, as above, then run:\n\n```bash\npip install git+https://github.com/rhoadesScholar/spatial-grouping-attention.git\n```\n\n## Usage\n\nThe spatial grouping attention mechanism automatically computes query grid parameters (`q_spacing` and `q_grid_shape`) from the input key grid and convolution parameters. 
This makes it easy to use - you only need to specify the input resolution and the algorithm handles the spatial downsampling.\n\n### Basic 2D Dense Attention\n\n```python\nimport torch\nfrom spatial_grouping_attention import DenseSpatialGroupingAttention\n\n# Create attention module with 7x7 grouping kernel, stride=1 (no downsampling)\nattention = DenseSpatialGroupingAttention(\n feature_dims=128,\n spatial_dims=2,\n kernel_size=7, # 7x7 spatial grouping\n stride=1, # No spatial downsampling\n num_heads=8,\n mlp_ratio=4\n)\n\n# Input: 32x32 image with 128 features per pixel\nbatch_size, height, width = 2, 32, 32\nx = torch.randn(batch_size, height * width, 128)\n\n# Only specify input (key) grid - query grid computed automatically\ninput_spacing = (0.5, 0.5) # 0.5 microns per pixel\ninput_grid_shape = (32, 32) # Input resolution\n\noutput = attention(x=x, input_spacing=input_spacing, input_grid_shape=input_grid_shape)\n# Auto-computed: q_spacing = (0.5, 0.5) * 1 = (0.5, 0.5)\n# Auto-computed: q_grid_shape = (32, 32) (no downsampling with stride=1)\nprint(f\"Output shape: {output['x_out'].shape}\") # (2, 1024, 128)\n```\n\n### 2D Attention with Spatial Downsampling\n\n```python\n# Create attention with spatial downsampling for efficiency\ndownsampling_attention = DenseSpatialGroupingAttention(\n feature_dims=256,\n spatial_dims=2,\n kernel_size=5, # 5x5 grouping kernel\n stride=2, # 2x spatial downsampling\n padding=2, # Maintain spatial coverage\n num_heads=16\n)\n\n# High resolution input: 128x128 image\nx_hires = torch.randn(1, 128*128, 256)\ninput_spacing_hires = (0.1, 0.1) # 0.1 mm per pixel (input)\ninput_grid_shape_hires = (128, 128) # High-res input grid\n\noutput_hires = downsampling_attention(\n x=x_hires,\n input_spacing=input_spacing_hires,\n input_grid_shape=input_grid_shape_hires\n)\n# Auto-computed: q_spacing = (0.1, 0.1) * 2 = (0.2, 0.2)\n# Auto-computed: q_grid_shape = (128+2*2-5)//2+1 = (64, 64)\nprint(f\"Downsampled output: 
{output_hires['x_out'].shape}\") # (1, 4096, 256)\nprint(f\"Compression ratio: {128*128 / (64*64)}x\") # 4x fewer points\n```\n\n### 3D Sparse Attention (GPU Required)\n\n```python\n# 3D sparse attention for volumetric data (requires CUDA + natten)\ntry:\n from spatial_grouping_attention import SparseSpatialGroupingAttention\n\n sparse_3d = SparseSpatialGroupingAttention(\n feature_dims=128,\n spatial_dims=3,\n kernel_size=(3, 5, 5), # Anisotropic: 3x5x5 grouping\n stride=(1, 2, 2), # Downsample only in x,y\n num_heads=8,\n neighborhood_kernel=9 # Local attention window\n )\n\n # 3D volume: 16x64x64 voxels\n depth, height, width = 16, 64, 64\n x_3d = torch.randn(1, depth*height*width, 128).cuda()\n\n # Anisotropic spacing (e.g., confocal microscopy)\n input_spacing_3d = (0.5, 0.1, 0.1) # z, y, x spacing in microns\n input_grid_shape_3d = (16, 64, 64)\n\n output_3d = sparse_3d(\n x=x_3d,\n input_spacing=input_spacing_3d,\n input_grid_shape=input_grid_shape_3d\n )\n # Auto-computed: q_spacing = (0.5*1, 0.1*2, 0.1*2) = (0.5, 0.2, 0.2)\n # Auto-computed: q_grid_shape = (16, 32, 32) - downsampled in x,y only\n print(f\"3D sparse output: {output_3d['x_out'].shape}\") # (1, 16384, 128)\n\nexcept ImportError:\n print(\"SparseSpatialGroupingAttention requires CUDA and natten package\")\n```\n\n### Multi-Scale Processing\n\n```python\n# Process same input at multiple scales efficiently\nmultiscale_attention = DenseSpatialGroupingAttention(\n feature_dims=64,\n spatial_dims=2,\n kernel_size=9,\n stride=4, # 4x downsampling for global context\n num_heads=4,\n iters=3 # Multiple attention iterations\n)\n\n# Input image\nx_input = torch.randn(1, 64*64, 64)\ninput_spacing = (1.0, 1.0) # 1 micron per pixel\ninput_grid_shape = (64, 64)\n\n# Global context via 4x downsampling\nglobal_output = multiscale_attention(\n x=x_input,\n input_spacing=input_spacing,\n input_grid_shape=input_grid_shape\n)\n# Auto-computed: q_spacing = (4.0, 4.0), q_grid_shape = (16, 16)\nprint(f\"Global 
context: {global_output['x_out'].shape}\") # (1, 256, 64)\n\n# Fine-scale processing with stride=1\nfine_attention = DenseSpatialGroupingAttention(\n feature_dims=64,\n spatial_dims=2,\n kernel_size=5,\n stride=1, # Full resolution\n num_heads=4\n)\n\nfine_output = fine_attention(\n x=x_input,\n input_spacing=input_spacing,\n input_grid_shape=input_grid_shape\n)\n# Auto-computed: q_spacing = (1.0, 1.0), q_grid_shape = (64, 64)\nprint(f\"Fine details: {fine_output['x_out'].shape}\") # (1, 4096, 64)\n```\n\n### Integration with Neural Networks\n\n```python\nclass HierarchicalSpatialNet(torch.nn.Module):\n \"\"\"Multi-scale spatial processing network\"\"\"\n\n def __init__(self, input_channels=3, num_classes=10):\n super().__init__()\n\n # Input embedding\n self.embed = torch.nn.Linear(input_channels, 128)\n\n # Coarse-scale attention (4x downsampling)\n self.coarse_attention = DenseSpatialGroupingAttention(\n feature_dims=128,\n spatial_dims=2,\n kernel_size=7,\n stride=4, # 4x spatial compression\n num_heads=8\n )\n\n # Fine-scale attention (2x downsampling)\n self.fine_attention = DenseSpatialGroupingAttention(\n feature_dims=128,\n spatial_dims=2,\n kernel_size=5,\n stride=2, # 2x spatial compression\n num_heads=8\n )\n\n # Cross-scale fusion\n self.fusion = torch.nn.Linear(256, 128)\n self.classifier = torch.nn.Linear(128, num_classes)\n\n def forward(self, images, pixel_spacing=(1.0, 1.0)):\n B, C, H, W = images.shape\n\n # Flatten and embed\n x = images.permute(0, 2, 3, 1).reshape(B, H*W, C)\n x = self.embed(x)\n\n # Multi-scale attention\n coarse_out = self.coarse_attention(\n x=x,\n input_spacing=pixel_spacing,\n input_grid_shape=(H, W)\n )['x_out'] # (B, H*W/16, 128) - 4x downsampling\n\n fine_out = self.fine_attention(\n x=x,\n input_spacing=pixel_spacing,\n input_grid_shape=(H, W)\n )['x_out'] # (B, H*W/4, 128) - 2x downsampling\n\n # Upsample coarse to match fine resolution for fusion\n coarse_upsampled = torch.nn.functional.interpolate(\n 
coarse_out.transpose(1, 2).reshape(B, 128, H//4, W//4),\n size=(H//2, W//2),\n mode='bilinear',\n align_corners=False\n ).reshape(B, 128, -1).transpose(1, 2)\n\n # Fuse multi-scale features\n fused = self.fusion(torch.cat([fine_out, coarse_upsampled], dim=-1))\n\n # Global pooling and classification\n global_features = fused.mean(dim=1)\n return self.classifier(global_features)\n\n# Usage example\nnet = HierarchicalSpatialNet(input_channels=3, num_classes=1000)\nsample_images = torch.randn(4, 3, 128, 128) # ImageNet-style input\npixel_spacing = (0.1, 0.1) # 0.1 mm per pixel\n\npredictions = net(sample_images, pixel_spacing)\nprint(f\"Predictions: {predictions.shape}\") # (4, 1000)\n```\n\n### Key Principles\n\n1. **Automatic Grid Calculation**: You only specify input (`input_spacing`, `input_grid_shape`) - the query grid is computed as:\n - `q_spacing = input_spacing * stride`\n - `q_grid_shape = (k_grid + 2*padding - kernel) // stride + 1`\n\n2. **Spatial Grouping**: The `kernel_size` determines how many neighboring points are grouped together for attention computation.\n\n3. **Multi-Scale Processing**: Use different `stride` values to process the same input at multiple spatial scales efficiently.\n\n4. **Memory Efficiency**: Larger strides reduce the number of query points, making attention computation more efficient for large inputs.\n\n## Contributing\n\n1. Fork the repository\n2. Create a feature branch (`git checkout -b feature/amazing-feature`)\n3. Make your changes\n4. Run the test suite (`make test`)\n5. Commit your changes (`git commit -m 'Add some amazing feature'`)\n6. Push to the branch (`git push origin feature/amazing-feature`)\n7. Open a Pull Request\n\n## License\n\nBSD 3-Clause License. See [LICENSE](LICENSE) for details.\n\n## Citation\n\nIf you use this software in your research, please cite it using the information in [CITATION.cff](CITATION.cff).\n",
"bugtrack_url": null,
"license": "BSD 3-Clause License",
"summary": "PyTorch Implementation of Spatial Grouping Attention Layer",
"version": "2025.8.21.1739",
"project_urls": {
"Documentation": "https://github.com/rhoadesScholar/spatial-grouping-attention",
"Homepage": "https://github.com/rhoadesScholar/spatial-grouping-attention",
"Issues": "https://github.com/rhoadesScholar/spatial-grouping-attention/issues",
"Repository": "https://github.com/rhoadesScholar/spatial-grouping-attention"
},
"split_keywords": [
"attention",
" rotary spatial embedding",
" segmentation",
" spatial grouping",
" transformer"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "ce0535d6dc30a4c898c18ca6941bd5f34d10d77193ae911b6111a5915a8adadf",
"md5": "10a915842ae564bdfec8b4680a22a888",
"sha256": "9e7350dd44ba41ade4f2a5997d8f1c8257904890a2c803acccb61be87c5f8ca6"
},
"downloads": -1,
"filename": "spatial_grouping_attention-2025.8.21.1739-py3-none-any.whl",
"has_sig": false,
"md5_digest": "10a915842ae564bdfec8b4680a22a888",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.10",
"size": 12352,
"upload_time": "2025-08-21T17:52:27",
"upload_time_iso_8601": "2025-08-21T17:52:27.596068Z",
"url": "https://files.pythonhosted.org/packages/ce/05/35d6dc30a4c898c18ca6941bd5f34d10d77193ae911b6111a5915a8adadf/spatial_grouping_attention-2025.8.21.1739-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "0a15c61ad99b27394642064587c2dc49627a54978997f6a41cb454d46810acc5",
"md5": "3168d99de51a38022f2fa1784116a27f",
"sha256": "9a2d3c75dabf4ff1b95961afb221cf4e3437ec31bc409e60bd3fd9ec334c27a4"
},
"downloads": -1,
"filename": "spatial_grouping_attention-2025.8.21.1739.tar.gz",
"has_sig": false,
"md5_digest": "3168d99de51a38022f2fa1784116a27f",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.10",
"size": 16481,
"upload_time": "2025-08-21T17:52:29",
"upload_time_iso_8601": "2025-08-21T17:52:29.711305Z",
"url": "https://files.pythonhosted.org/packages/0a/15/c61ad99b27394642064587c2dc49627a54978997f6a41cb454d46810acc5/spatial_grouping_attention-2025.8.21.1739.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-08-21 17:52:29",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "rhoadesScholar",
"github_project": "spatial-grouping-attention",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"lcname": "spatial-grouping-attention"
}