# rebasin
![PyPI Version](https://img.shields.io/pypi/v/rebasin)
![Wheel](https://img.shields.io/pypi/wheel/rebasin)
[![Python 3.8+](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/)
![License](https://img.shields.io/github/license/snimu/rebasin)
An implementation of methods described in
the ["Git Re-basin" paper by Ainsworth et al.](https://arxiv.org/abs/2209.04836).
Can be applied to **arbitrary models**, without modification.
(Well, *almost* arbitrary models, see [Limitations](#limitations)).
---
**Table of Contents**
- [Installation](#installation)
- [Usage](#usage)
- [Limitations](#limitations)
- [Results](#results)
- [Acknowledgements](#acknowledgements)
## Installation
The requirements should be installed automatically, but one of them is graphviz,
which you may need to install yourself via apt / brew / ... on your system.
The following install instructions are taken directly from
[torchview's installation instructions](https://github.com/mert-kurttutan/torchview#installation).
Debian-based Linux distro (e.g. Ubuntu):

```bash
apt-get install graphviz
```

Windows:

```bash
choco install graphviz
```

macOS:

```bash
brew install graphviz
```
See more details [here](https://graphviz.readthedocs.io/en/stable/manual.html).
Then, install rebasin via pip:
```bash
pip install rebasin
```
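To quickly check that the package is importable and see which version you got (a sanity check only; it does not verify the system-level graphviz binary), you can run:

```bash
python -c "from importlib.metadata import version; import rebasin; print(version('rebasin'))"
```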
## Usage
Currently, weight-matching is the only implemented rebasin method,
and only a simplified form of linear interpolation is implemented.

The following is a minimal example. For now, the documentation lives in the docstrings,
though I intend to create a proper one.
`PermutationCoordinateDescent` and `interpolation.LerpSimple`
are the main classes, besides `MergeMany` (see below).
```python
from rebasin import PermutationCoordinateDescent
from rebasin import interpolation
model_a, model_b, train_dl = ...
input_data = next(iter(train_dl))[0]
# Rebasin
pcd = PermutationCoordinateDescent(model_a, model_b, input_data) # weight-matching
pcd.rebasin() # Rebasin model_b towards model_a. Automatically updates model_b
# Interpolate
lerp = interpolation.LerpSimple(
    models=[model_a, model_b],
    devices=["cuda:0", "cuda:1"],  # Optional, defaults to cpu
    device_interp="cuda:2",  # Optional, defaults to cpu
    savedir="/path/to/save/interpolation"  # Optional, save all interpolated models
)
lerp.interpolate(steps=99) # Interpolate 99 models between model_a and model_b
```
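Conceptually, linear interpolation in weight space just blends each pair of corresponding parameters. The following is a minimal sketch of that idea only, not of `LerpSimple`'s internals, and it assumes `model_a` and `model_b` have identical architectures:

```python
import copy

import torch


def lerp_models(model_a, model_b, t: float):
    """Return a new model whose weights are (1 - t) * model_a + t * model_b."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    blended = {
        key: torch.lerp(val, sd_b[key], t) if val.is_floating_point() else val
        for key, val in sd_a.items()
    }
    model_t = copy.deepcopy(model_a)
    model_t.load_state_dict(blended)
    return model_t
```

Rebasining `model_b` before interpolating is what makes these intermediate models interesting; without the permutation, they typically sit behind a much higher loss barrier.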
The `MergeMany` algorithm is also implemented
(though the device-related parts of its interface will change in the future):
```python
from rebasin import MergeMany
from torch import nn
class ExampleModel(nn.Module):
    ...
model_a, model_b, model_c = ExampleModel(), ExampleModel(), ExampleModel()
train_dl = ...
# Merge
merge = MergeMany(
    models=[model_a, model_b, model_c],
    working_model=ExampleModel(),
    input_data=next(iter(train_dl))[0],
)
merged_model = merge.run()
# The merged model is also accessible through merge.working_model,
# but only after merge.run() has been called.
```
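As I understand the paper, the idea behind `MergeMany` is to permute all models into a common basin and then average them. Below is a rough conceptual sketch of that final averaging step only; it is not how the library implements it:

```python
import copy

import torch


def average_models(models):
    """Average the floating-point parameters of architecturally identical models."""
    state_dicts = [m.state_dict() for m in models]
    averaged = {}
    for key, value in state_dicts[0].items():
        if value.is_floating_point():
            averaged[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
        else:
            averaged[key] = value  # e.g. integer buffers are taken from the first model
    merged = copy.deepcopy(models[0])
    merged.load_state_dict(averaged)
    return merged
```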
## Terminology
In this document, I will use the following terminology:
- **To rebasin**: To apply one of the methods described in the paper to a model,
  permuting the rows and columns of its weights (and biases); see the sketch after this list
- `model_a`: The model that stays unchanged
- `model_b`: The model that is changed by rebasining it towards `model_a`
  - `model_b (original)` for the unchanged, original `model_b`
  - `model_b (rebasin)` for the changed, rebasined `model_b`
- **Path**: A sequence of modules in a model
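To illustrate what such a permutation means, here is a minimal, self-contained sketch (independent of this library) showing that permuting the rows of one linear layer and the columns of the next leaves the network's function unchanged:

```python
import torch
from torch import nn

torch.manual_seed(0)
lin1, lin2 = nn.Linear(4, 6), nn.Linear(6, 3)
x = torch.randn(2, 4)

perm = torch.randperm(6)  # a permutation of the 6 hidden units

# Permute the rows of lin1 (weight and bias) and the columns of lin2's weight.
lin1_p, lin2_p = nn.Linear(4, 6), nn.Linear(6, 3)
with torch.no_grad():
    lin1_p.weight.copy_(lin1.weight[perm])
    lin1_p.bias.copy_(lin1.bias[perm])
    lin2_p.weight.copy_(lin2.weight[:, perm])
    lin2_p.bias.copy_(lin2.bias)

out = lin2(torch.relu(lin1(x)))
out_p = lin2_p(torch.relu(lin1_p(x)))
print(torch.allclose(out, out_p, atol=1e-6))  # True: both networks compute the same function
```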
## Limitations
### Only some methods are implemented
For rebasin, only weight-matching is implemented via `rebasin.PermutationCoordinateDescent`.
For interpolation, only a simplified method of linear interpolation is implemented
via `rebasin.interpolation.LerpSimple`.
### Limitations of the `PermutationCoordinateDescent` class
The `PermutationCoordinateDescent` class only permutes some modules.
Most modules should work, but others may behave unexpectedly. In that case,
you need to add the module to [rebasin/modules.py](rebasin/modules.py);
make sure it is included in the `initialize_module` function
(preferably by putting it into the `SPECIAL_MODULES` dict).
Additionally, the `PermutationCoordinateDescent` class only works with
`nn.Module`s, not functions. It also requires the permuted model to
produce the same output as the un-permuted `Module`, which is a fairly
tight constraint. For some models this isn't a problem at all, but for
models with many short residual blocks it can be.
Where it is a problem, few to no parameters get permuted, which defeats the purpose of rebasin.
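You can verify this invariant for your own model by comparing outputs before and after rebasining. A small sanity-check sketch, reusing `model_a`, `model_b`, and `input_data` from the usage example above:

```python
import torch

from rebasin import PermutationCoordinateDescent

model_b.eval()
with torch.no_grad():
    out_before = model_b(input_data)

pcd = PermutationCoordinateDescent(model_a, model_b, input_data)
pcd.rebasin()  # permutes model_b in place

with torch.no_grad():
    out_after = model_b(input_data)

# The permuted model should produce (numerically) the same output as before.
print(torch.allclose(out_before, out_after, atol=1e-5))
```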
As a concrete example, @tysam-code's [hlb-gpt](https://github.com/tysam-code/hlb-gpt), a small but fast
language model implementation, isn't permuted at all.
Vision transformers like `torchvision.models.vit_b_16` have only a few permutations
applied to them. In general, **transformer models don't work well**, because they
reshape the input tensor and directly follow that up with residual blocks.
This means that almost nothing in the model can be permuted
(a single Linear layer between the reshaping and the first residual block would fix that,
but this usually isn't done...).
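Here is a minimal illustration (my own, not the library's analysis) of why short residual blocks constrain permutations: in `x + f(x)`, permuting the output of `f` alone changes the function, so any permutation has to be propagated through the skip connection as well, which couples the whole residual chain together:

```python
import torch
from torch import nn


class ShortResidualBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.f = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.f(x)


torch.manual_seed(0)
block = ShortResidualBlock()
x = torch.randn(2, 8)
out_before = block(x)

perm = torch.roll(torch.arange(8), 1)  # a non-trivial permutation of the 8 features
with torch.no_grad():
    block.f.weight.copy_(block.f.weight[perm])
    block.f.bias.copy_(block.f.bias[perm])

out_after = block(x)
# False: permuting f alone changes the result, because the skip
# connection still adds the *un*-permuted x.
print(torch.allclose(out_before, out_after))
```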
On the other hand, **CNNs usually work very well**.
If you are unsure, you can always print the model graph! To do so, write:
```python
from rebasin import PermutationCoordinateDescent
pcd = PermutationCoordinateDescent(...)
print(pcd.pinit.model_graph) # pinit stands for "PermutationInitialization"
```
## Results
For the full results, see [rebasin-results](https://github.com/snimu/rebasin-results)
(I don't want to upload a bunch of images to this repo, so the results are in their own repo).
The clearest results were produced on [hlb-CIFAR10](https://github.com/tysam-code/hlb-CIFAR10).
For results on that model, see
[here](https://github.com/snimu/rebasin-results/blob/main/hlb-CIFAR10/RESULTS.md).
Here is a little taste of the results for that model:
<p align="center">
<img
src="https://github.com/snimu/rebasin-results/blob/main/hlb-CIFAR10/3x3-plot.png"
alt="hlb-CIFAR10: losses and accuracies of the model"
width="600"
/>
</p>
While `PermutationCoordinateDescent` doesn't fully eliminate the loss barrier,
it does reduce it significantly, and, surprisingly, even more so for the accuracy barrier.
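For reference, the loss barrier is usually taken to be the largest gap between the loss along the interpolation path and the straight line connecting the two endpoint losses. A rough sketch of measuring it, assuming a hypothetical `evaluate_loss(model, dataloader)` helper that you supply yourself:

```python
def loss_barrier(interpolated_models, model_a, model_b, dataloader, evaluate_loss):
    """interpolated_models: models along the path from model_a to model_b, in order."""
    loss_a = evaluate_loss(model_a, dataloader)
    loss_b = evaluate_loss(model_b, dataloader)
    n = len(interpolated_models) + 1
    gaps = []
    for i, model in enumerate(interpolated_models, start=1):
        t = i / n
        baseline = (1 - t) * loss_a + t * loss_b  # straight line between the endpoint losses
        gaps.append(evaluate_loss(model, dataloader) - baseline)
    return max(gaps)
```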
You can also find results for the `MergeMany` algorithm in the same repository.
## Acknowledgements
**Git Re-Basin:**
```
Ainsworth, Samuel K., Jonathan Hayase, and Siddhartha Srinivasa.
"Git re-basin: Merging models modulo permutation symmetries."
arXiv preprint arXiv:2209.04836 (2022).
```
Link: https://arxiv.org/abs/2209.04836 (accessed on April 9th, 2023)
**ImageNet:**
I've used the ImageNet data from the 2012 ILSVRC competition to evaluate
the algorithms in rebasin on the `torchvision.models`.
```
Olga Russakovsky*, Jia Deng*, Hao Su, Jonathan Krause, Sanjeev Satheesh,
Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein,
Alexander C. Berg and Li Fei-Fei. (* = equal contribution)
ImageNet Large Scale Visual Recognition Challenge. arXiv:1409.0575, 2014
```
[Paper (link)](https://arxiv.org/abs/1409.0575) (Accessed on April 12th, 2023)
**Torchvision models**
For testing, I've used (or will use) the torchvision models (v0.15):
https://pytorch.org/vision/0.15/models.html
**HLB-CIFAR10**

For testing, I forked [hlb-CIFAR10](https://github.com/tysam-code/hlb-CIFAR10)
by [@tysam-code](https://github.com/tysam-code):

    authors:
      - family-names: "Balsam"
        given-names: "Tysam&"
    title: "hlb-CIFAR10"
    version: 0.4.0
    date-released: 2023-02-12
    url: "https://github.com/tysam-code/hlb-CIFAR10"
**HLB-GPT**

For testing, I also used [hlb-gpt](https://github.com/tysam-code/hlb-gpt) by @tysam-code:

    authors:
      - family-names: "Balsam"
        given-names: "Tysam&"
    title: "hlb-gpt"
    version: 0.0.0
    date-released: 2023-03-05
    url: "https://github.com/tysam-code/hlb-gpt"
**Other**
My code took inspiration from the following sources:
- https://github.com/themrzmaster/git-re-basin-pytorch
I used the amazing library `torchview` to visualize the models:
- https://github.com/mert-kurttutan/torchview