cka-pytorch

- Name: cka-pytorch
- Version: 1.1.2
- Summary: A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support.
- Upload time: 2025-09-17 13:38:21
- Requires Python: >=3.11
- Keywords: pytorch, cka, centered kernel alignment

# Centered Kernel Alignment (CKA) - PyTorch Implementation

A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support for fast and efficient computation.

> [!WARNING]
> This project is for educational and academic purposes (and for fun 🤷🏻).
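
As a point of reference, linear CKA (Kornblith et al., 2019) compares two feature matrices `X` (shape `n × p1`) and `Y` (shape `n × p2`) computed on the same `n` examples. After centering every feature column, `CKA(X, Y) = ||Yᵀ X||_F² / (||Xᵀ X||_F · ||Yᵀ Y||_F)`, a similarity in [0, 1] that is invariant to orthogonal transformations and isotropic scaling of the features. The minimal, full-batch sketch below implements that formula in plain PyTorch; `linear_cka` is a hypothetical helper for illustration and is not part of the cka-pytorch API.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between features of the same n examples (hypothetical helper).

    x: (n, ...) and y: (n, ...); trailing dimensions are flattened.
    """
    # Flatten per-example features and use float64 for numerical stability.
    x = x.reshape(x.shape[0], -1).double()
    y = y.reshape(y.shape[0], -1).double()
    # Center every feature column across the n examples.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F); matrix_norm defaults to Frobenius.
    cross = torch.linalg.matrix_norm(y.T @ x) ** 2
    norm_x = torch.linalg.matrix_norm(x.T @ x)
    norm_y = torch.linalg.matrix_norm(y.T @ y)
    return cross / (norm_x * norm_y)


if __name__ == "__main__":
    torch.manual_seed(0)
    feats = torch.randn(64, 32)
    print(linear_cka(feats, feats).item())                # approximately 1.0
    print(linear_cka(feats, 3.0 * feats + 1.0).item())    # also ~1.0: scale and shift invariant
    print(linear_cka(feats, torch.randn(64, 32)).item())  # lower for unrelated features
```

This feature-space form builds `p × p` matrices, so it is cheap when the feature dimension is small. For very wide layers (for example flattened convolutional maps), the equivalent Gram-matrix form on `n × n` matrices is usually preferable, since `||Yᵀ X||_F² = tr(K L)` with `K = X Xᵀ` and `L = Y Yᵀ` built from the centered features.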

## Features

- **GPU Accelerated:** Runs the CKA computations as PyTorch tensor operations on the GPU, which can be significantly faster than NumPy-based implementations.
- **On-the-Fly Calculation:** Accumulates CKA over mini-batches as they are processed, so large intermediate feature representations never have to be cached.
- **Easy to Use:** Simple and intuitive API for calculating the CKA matrix between two models.
- **Flexible:** Can be used with any PyTorch models and dataloaders.
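
Mini-batch ("on-the-fly") CKA is commonly computed as in Nguyen et al. (2021): each batch contributes an unbiased HSIC estimate, the estimates are summed over batches, and CKA is the normalized ratio `Σ HSIC(K_i, L_i) / sqrt(Σ HSIC(K_i, K_i) · Σ HSIC(L_i, L_i))`, where `K_i` and `L_i` are the Gram matrices of the two models' features on batch `i`. Only `batch × batch` Gram matrices are ever held in memory. The sketch below assumes a linear kernel and the unbiased HSIC estimator of Song et al. (2012); `MinibatchLinearCKA` and `_hsic1` are hypothetical names, and whether cka-pytorch uses exactly this estimator is an assumption.

```python
import torch


def _hsic1(k: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
    """Unbiased HSIC estimator (Song et al., 2012) for two n x n Gram matrices, n > 3."""
    n = k.shape[0]
    # The estimator works on Gram matrices whose diagonals are set to zero.
    k = k.clone()
    l = l.clone()
    k.fill_diagonal_(0.0)
    l.fill_diagonal_(0.0)
    trace_kl = torch.trace(k @ l)
    # 1^T K L 1 = (column sums of K) . (row sums of L)
    ones_kl_ones = k.sum(dim=0) @ l.sum(dim=1)
    value = (
        trace_kl
        + k.sum() * l.sum() / ((n - 1) * (n - 2))
        - 2.0 * ones_kl_ones / (n - 2)
    )
    return value / (n * (n - 3))


class MinibatchLinearCKA:
    """Accumulates mini-batch CKA with a linear kernel (hypothetical helper)."""

    def __init__(self) -> None:
        self.hsic_xy = 0.0
        self.hsic_xx = 0.0
        self.hsic_yy = 0.0

    @torch.no_grad()
    def update(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Add one batch of paired features; only (batch x batch) Gram matrices are kept."""
        if x.shape[0] <= 3 or x.shape[0] != y.shape[0]:
            raise ValueError("need the same number (> 3) of examples in x and y")
        x = x.reshape(x.shape[0], -1).double()
        y = y.reshape(y.shape[0], -1).double()
        k = x @ x.T  # linear-kernel Gram matrix of model 1's features
        l = y @ y.T  # linear-kernel Gram matrix of model 2's features
        self.hsic_xy += _hsic1(k, l).item()
        self.hsic_xx += _hsic1(k, k).item()
        self.hsic_yy += _hsic1(l, l).item()

    def compute(self) -> float:
        """CKA = sum HSIC(K, L) / sqrt(sum HSIC(K, K) * sum HSIC(L, L))."""
        return self.hsic_xy / (self.hsic_xx * self.hsic_yy) ** 0.5


if __name__ == "__main__":
    torch.manual_seed(0)
    feats = torch.randn(256, 128)
    cka = MinibatchLinearCKA()
    for batch in feats.split(32):  # stands in for iterating over a DataLoader
        cka.update(batch, 2.0 * batch)
    print(cka.compute())  # approximately 1.0: CKA ignores isotropic scaling
```

The unbiased estimator matters here: the biased full-batch HSIC depends on the batch size, so simply averaging per-batch CKA values would not approximate the full-dataset result, whereas summing unbiased HSIC terms does.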

## Installation
```bash
pip install cka-pytorch
```

## Usage

```python
import torch

from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import ResNet18_Weights, resnet18

from cka_pytorch.cka import CKACalculator


# 1. Define your models and dataloader
# (`weights=` replaces the deprecated `pretrained=True` argument in torchvision >= 0.13;
# `.eval()` makes BatchNorm use its running statistics)
model1 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()
model2 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()  # Or a different model

# Create a dummy dataloader for demonstration
dummy_data = torch.randn(100, 3, 224, 224)
dummy_labels = torch.randint(0, 10, (100,))
dummy_dataset = TensorDataset(dummy_data, dummy_labels)
dataloader = DataLoader(dummy_dataset, batch_size=32)

# 2. Initialize the CKACalculator
# By default, we will calculate CKA across all layers of the two models
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    model1_name="ResNet18",
    model2_name="ResNet18",
    batched_feature_size=256,
    verbose=True,
)

# 3. Calculate the CKA matrix
cka_matrix = calculator.calculate_cka_matrix(dataloader)

# 4. Plot the CKA matrix as a heatmap
calculator.plot_cka_matrix(title="CKA between ResNet18 and ResNet18")
```
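
Building a CKA *matrix* between two models amounts to capturing each layer's output and evaluating CKA for every (layer of model 1, layer of model 2) pair. One common way to capture outputs in PyTorch is `nn.Module.register_forward_hook`, sketched below on toy models. The helper names (`collect_features`, `cka_matrix`) and the choice of `nn.Linear`/`nn.Conv2d` as the compared layers are assumptions for illustration, not how `CKACalculator` selects layers. Unlike the mini-batch approach described under Features, this sketch keeps every activation for the whole input in memory, so it only suits small inputs.

```python
import torch
from torch import nn


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Full-batch linear CKA between two (n, ...) feature tensors."""
    x = x.reshape(x.shape[0], -1).double()
    y = y.reshape(y.shape[0], -1).double()
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    cross = torch.linalg.matrix_norm(y.T @ x) ** 2
    return cross / (torch.linalg.matrix_norm(x.T @ x) * torch.linalg.matrix_norm(y.T @ y))


@torch.no_grad()
def collect_features(
    model: nn.Module, layer_types: tuple[type, ...], inputs: torch.Tensor
) -> list[torch.Tensor]:
    """Return the outputs of every submodule of `layer_types`, in forward order."""
    feats: list[torch.Tensor] = []
    hooks = [
        module.register_forward_hook(lambda _mod, _args, out: feats.append(out.detach()))
        for module in model.modules()
        if isinstance(module, layer_types)
    ]
    try:
        model(inputs)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for hook in hooks:
            hook.remove()
    return feats


def cka_matrix(
    model1: nn.Module,
    model2: nn.Module,
    inputs: torch.Tensor,
    layer_types: tuple[type, ...] = (nn.Linear, nn.Conv2d),
) -> torch.Tensor:
    """Entry (i, j) is the CKA between layer i of model1 and layer j of model2."""
    feats1 = collect_features(model1.eval(), layer_types, inputs)
    feats2 = collect_features(model2.eval(), layer_types, inputs)
    result = torch.zeros(len(feats1), len(feats2), dtype=torch.float64)
    for i, f1 in enumerate(feats1):
        for j, f2 in enumerate(feats2):
            result[i, j] = linear_cka(f1, f2)
    return result


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(
        nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10)
    )
    x = torch.randn(128, 20)
    matrix = cka_matrix(net, net, x)
    print(matrix.shape)       # torch.Size([3, 3]): one row/column per nn.Linear
    print(matrix.diagonal())  # approximately 1.0 everywhere when a model is compared with itself
```

Comparing a model with itself should give ones on the diagonal, which is a handy sanity check before comparing two different networks.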

## Contributing

- If you find this repository helpful, please give it a :star:.
- If you encounter any bugs or have suggestions for improvements, feel free to open an issue.
- This implementation has been primarily tested with ResNet architectures.

## Acknowledgement
This project is based on:
- [CKA.pytorch](https://github.com/numpee/CKA.pytorch)
- [centered-kernel-alignment](https://github.com/RistoAle97/centered-kernel-alignment)

            

Raw data

    {
    "_id": null,
    "home_page": null,
    "name": "cka-pytorch",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.11",
    "maintainer_email": null,
    "keywords": "pytorch, cka, centered kernel alignment",
    "author": null,
    "author_email": "Dat-Thinh Nguyen <datthinh1801@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/94/fa/da73b1c93025a88c4779756a2967dad4270284fde1c22679e978cb5bf34b/cka_pytorch-1.1.2.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": null,
    "summary": "A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support.",
    "version": "1.1.2",
    "project_urls": {
        "Homepage": "https://github.com/datthinh1801/CKA.pytorch",
        "Issues": "https://github.com/datthinh1801/CKA.pytorch/issues",
        "Repository": "https://github.com/datthinh1801/CKA.pytorch.git"
    },
    "split_keywords": [
        "pytorch",
        " cka",
        " centered kernel alignment"
    ],
    "urls": [
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "c50fcabf2db9a66e66ffffe050004bd49c450d3fb7f2e82462a5481a211e5f4b",
                "md5": "51bf039e2cc05ef8e5d48ab948dc6ae8",
                "sha256": "d2ad1fb911ae293bf9279e153417e9c7b9846e09376d38b7139a58566b284b19"
            },
            "downloads": -1,
            "filename": "cka_pytorch-1.1.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "51bf039e2cc05ef8e5d48ab948dc6ae8",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.11",
            "size": 18206,
            "upload_time": "2025-09-17T13:38:19",
            "upload_time_iso_8601": "2025-09-17T13:38:19.752372Z",
            "url": "https://files.pythonhosted.org/packages/c5/0f/cabf2db9a66e66ffffe050004bd49c450d3fb7f2e82462a5481a211e5f4b/cka_pytorch-1.1.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "94fada73b1c93025a88c4779756a2967dad4270284fde1c22679e978cb5bf34b",
                "md5": "329f6726861210ff7073498a6399771b",
                "sha256": "87e267b6d0fb417db3fb6261f0cfafe3cb7c75b370fecac84154efd4931e9af6"
            },
            "downloads": -1,
            "filename": "cka_pytorch-1.1.2.tar.gz",
            "has_sig": false,
            "md5_digest": "329f6726861210ff7073498a6399771b",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.11",
            "size": 16873,
            "upload_time": "2025-09-17T13:38:21",
            "upload_time_iso_8601": "2025-09-17T13:38:21.024242Z",
            "url": "https://files.pythonhosted.org/packages/94/fa/da73b1c93025a88c4779756a2967dad4270284fde1c22679e978cb5bf34b/cka_pytorch-1.1.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-09-17 13:38:21",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "datthinh1801",
    "github_project": "CKA.pytorch",
    "github_not_found": true,
    "lcname": "cka-pytorch"
    }