# simtorch
[![Downloads](https://pepy.tech/badge/simtorch)](https://pepy.tech/project/simtorch)
A PyTorch library for measuring the similarity between the representations of two neural networks. The library currently supports the following (dis)similarity measures:
* Centered Kernel Alignment (CKA) - [Kornblith et al., ICML 2019](http://proceedings.mlr.press/v97/kornblith19a.html)
* Deconfounded CKA - [Cui et al., NeurIPS 2022](https://openreview.net/pdf?id=mMdRZipvld2)
* Procrustes [WIP]
* CCA [WIP]
## Design
The package consists of two components:
* `SimilarityModel` - a thin wrapper around `torch.nn.Module` that adds forward hooks to store the layer-wise activations (aka representations) in a dictionary.
* `BaseSimilarity` - the base class that sets the interface for classes that compute similarity between network representations.
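The hook-based design described above can be sketched roughly as follows. This is only an illustration of the general forward-hook technique in PyTorch, not simtorch's actual implementation: the class `ActivationRecorder`, its `layer_names` argument, and the toy model are hypothetical names introduced for this example.

```python
import torch
from torch import nn


class ActivationRecorder:
    """Hypothetical helper (not simtorch's API): records layer outputs via forward hooks."""

    def __init__(self, model, layer_names):
        self.model = model
        self.activations = {}  # layer name -> most recent output tensor
        self._handles = []
        for name, module in model.named_modules():
            if name in layer_names:
                handle = module.register_forward_hook(self._make_hook(name))
                self._handles.append(handle)

    def _make_hook(self, name):
        def hook(module, inputs, output):
            # Detach so the stored activation does not keep the autograd graph alive.
            self.activations[name] = output.detach()

        return hook

    def remove(self):
        # Unregister all hooks so the wrapped model behaves as before.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
# In an nn.Sequential, submodules are named "0", "1", "2", ...
recorder = ActivationRecorder(toy_model, layer_names=["0", "2"])

with torch.no_grad():
    toy_model(torch.randn(5, 8))  # the forward pass triggers the hooks

shapes = {name: tuple(act.shape) for name, act in recorder.activations.items()}
print(shapes)  # {'0': (5, 16), '2': (5, 4)}
recorder.remove()
```

Storing detached outputs keeps memory use predictable, and removing the hooks afterwards leaves the wrapped model unchanged.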
## Installation
The package is available on PyPI:
```
pip install simtorch
```
## Usage
The torch model objects need to be wrapped with `SimilarityModel`. A list of names of the layers whose representations we wish to compare is passed to this class through the `layers_to_include` argument.
```
import torchvision

# SimilarityModel is provided by simtorch and must be imported before running this snippet
model1 = torchvision.models.densenet121()
model2 = torchvision.models.resnet101()

sim_model1 = SimilarityModel(
    model1,
    model_name="DenseNet 121",
    layers_to_include=["conv", "classifier"],
)

sim_model2 = SimilarityModel(
    model2,
    model_name="ResNet 101",
    layers_to_include=["conv", "fc"],
)
```
An instance of a similarity metric can then be initialized with these `SimilarityModel`s. The `compute()` method can be used to obtain a similarity matrix $S$ for these two models where $S[i, j]$ is the similarity metric for the $i^{th}$ layer of the first model and the $j^{th}$ layer of the second model.
```
sim_cka = CKA(sim_model1, sim_model2, device="cuda")
cka_matrix = sim_cka.compute(torch_dataloader)
```
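To make the entries of $S$ concrete, here is a minimal sketch of linear CKA as defined by Kornblith et al.: for column-centered activation matrices $X$ and $Y$ computed on the same $n$ inputs, $\mathrm{CKA}(X, Y) = \|Y^\top X\|_F^2 / (\|X^\top X\|_F \, \|Y^\top Y\|_F)$. This is the textbook full-batch formula and may differ from the estimator that simtorch's `CKA` class uses internally; the functions `linear_cka` and `similarity_matrix` are hypothetical names introduced for this example.

```python
import torch


def linear_cka(x, y):
    """Linear CKA between activations x (n, p1) and y (n, p2) from the same n inputs."""
    # Center each feature (column) across the n examples.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # torch.linalg.norm on a matrix with default arguments is the Frobenius norm.
    cross = torch.linalg.norm(y.T @ x) ** 2
    norm_x = torch.linalg.norm(x.T @ x)
    norm_y = torch.linalg.norm(y.T @ y)
    return (cross / (norm_x * norm_y)).item()


def similarity_matrix(acts1, acts2):
    """S[i, j] = linear CKA between layer i of model 1 and layer j of model 2."""
    s = torch.zeros(len(acts1), len(acts2), dtype=torch.float64)
    for i, a in enumerate(acts1):
        for j, b in enumerate(acts2):
            # Flatten everything except the batch dimension (e.g. conv feature maps).
            s[i, j] = linear_cka(a.flatten(1), b.flatten(1))
    return s


torch.manual_seed(0)
x = torch.randn(50, 10, dtype=torch.float64)
q, _ = torch.linalg.qr(torch.randn(10, 10, dtype=torch.float64))  # random orthogonal matrix

# Stand-ins for the layer-wise activations of two models on the same 50 inputs.
acts1 = [x, torch.randn(50, 6, dtype=torch.float64)]
acts2 = [3.0 * x @ q, torch.randn(50, 4, dtype=torch.float64)]

s = similarity_matrix(acts1, acts2)
print(s)
```

Because linear CKA is invariant to orthogonal transformations and isotropic scaling, `s[0, 0]` is 1 up to floating-point error even though the second "layer" is a rotated and rescaled copy of the first.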
The similarity matrix can be visualized with the `sim_cka.plot_similarity()` method, which produces a CKA similarity plot such as the following:
<img title="Centered Kernel Alignment Matrix" alt="Centered Kernel Alignment Matrix" src="assets/img/cka_dense121_res101.png">
## Citations
If you use Deconfounded Centered Kernel Alignment (dCKA) for your research, please cite:
```
@article{cui2022deconfounded,
title={Deconfounded Representation Similarity for Comparison of Neural Networks},
author={Cui, Tianyu and Kumar, Yogesh and Marttinen, Pekka and Kaski, Samuel},
journal={Neural Information Processing Systems (NeurIPS)},
year={2022}
}
```
## Credits
This package was built using the following awesome repos as references:
* [anatome](https://github.com/moskomule/anatome), maintained by [@moskomule](https://github.com/moskomule)
* [PyTorch-Model-Compare](https://github.com/AntixK/PyTorch-Model-Compare), maintained by [@AntixK](https://github.com/AntixK)
* [centered-kernel-alignment](https://github.com/Kennethborup/centered_kernel_alignment), maintained by [@Kennethborup](https://github.com/Kennethborup)
Raw data
{
"_id": null,
"home_page": "https://github.com/ykumards/simtorch",
"name": "simtorch",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.7",
"maintainer_email": "",
"keywords": "",
"author": "Yogesh Kumar",
"author_email": "ykumards@gmail.com",
"download_url": "https://files.pythonhosted.org/packages/47/2a/d512a9cd232b4d417a5ef61e4503035a9c3a49446687e0e42063c2a00f5a/simtorch-0.2.1.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "Apache",
"summary": "",
"version": "0.2.1",
"split_keywords": [],
"urls": [
{
"comment_text": "",
"digests": {
"md5": "38eb57fd38fd56fb703f3a19094e7b9b",
"sha256": "ea870e1fea7ee23c0f86503c985c07417f94ff18074da2e56d7b65b2a50b94a1"
},
"downloads": -1,
"filename": "simtorch-0.2.1-py3-none-any.whl",
"has_sig": false,
"md5_digest": "38eb57fd38fd56fb703f3a19094e7b9b",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.7",
"size": 13382,
"upload_time": "2022-12-04T11:34:17",
"upload_time_iso_8601": "2022-12-04T11:34:17.381182Z",
"url": "https://files.pythonhosted.org/packages/db/25/6d02095bcdd1b3483b8e9053114bbdd02f4e46ed40c5acb0e20408ff8db7/simtorch-0.2.1-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"md5": "179e4f0e0d198e3bd272b0ad435c53fc",
"sha256": "bc3846313d888e256ca78e166c602370d2648d1294acb881bf8736a1281c58c9"
},
"downloads": -1,
"filename": "simtorch-0.2.1.tar.gz",
"has_sig": false,
"md5_digest": "179e4f0e0d198e3bd272b0ad435c53fc",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.7",
"size": 12411,
"upload_time": "2022-12-04T11:34:18",
"upload_time_iso_8601": "2022-12-04T11:34:18.738186Z",
"url": "https://files.pythonhosted.org/packages/47/2a/d512a9cd232b4d417a5ef61e4503035a9c3a49446687e0e42063c2a00f5a/simtorch-0.2.1.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2022-12-04 11:34:18",
"github": true,
"gitlab": false,
"bitbucket": false,
"github_user": "ykumards",
"github_project": "simtorch",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"requirements": [
{
"name": "torch",
"specs": [
[
">=",
"1.13"
]
]
},
{
"name": "torchvision",
"specs": []
},
{
"name": "matplotlib",
"specs": []
},
{
"name": "seaborn",
"specs": []
},
{
"name": "tqdm",
"specs": []
}
],
"lcname": "simtorch"
}