mrl-pytorch


Name: mrl-pytorch
Version: 0.1.1
Home page: https://github.com/filipbasara0/matryoshka-representation-learning
Summary: An unofficial implementation of Matryoshka Representation Learning for contrastive self-supervised learning
Upload time: 2024-02-06 21:37:21
Author: Filip Basara
License: MIT
Keywords: machine learning, pytorch, self-supervised learning, representation learning, contrastive learning
Requirements: No requirements were recorded.

# Matryoshka Representation Learning

An unofficial PyTorch implementation of [Matryoshka Representation Learning](https://arxiv.org/abs/2205.13147) (MRL) for contrastive self-supervised learning, specifically the [ReLIC](https://arxiv.org/abs/2010.07922) method. MRL encodes information at different granularities, so that a single feature vector (embedding) yields flexible representations of different dimensions that can be adapted to multiple downstream tasks. MRL can also be used with other tasks and modalities, such as classification, retrieval, or language modeling. For example, ResNet50 returns a 2048-dimensional feature vector; we can use a small part of that vector (e.g. the first 64 values) for retrieval and a larger part (e.g. the first 1024 values) for reranking. This can substantially reduce the computational resources required for retrieval.
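
The retrieval-then-reranking idea described above can be sketched as follows. This is a minimal illustration, not code from this package: the function names, the shortlist size, and the random vectors standing in for ResNet50 features are all assumptions, and NumPy is used only to keep the sketch dependency-light.

```python
import numpy as np


def l2_normalize(x):
    """Scale each row to unit length so dot products equal cosine similarity."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def retrieve_then_rerank(query, database, shortlist_dim=64, rerank_dim=1024,
                         shortlist_size=100, top_k=10):
    """Shortlist with a short embedding prefix, then rerank with a longer one.

    All parameter names and defaults are hypothetical, chosen to mirror the
    64 / 1024 example in the text.
    """
    # Stage 1: cheap search over the whole database using the first 64 values.
    q_small = l2_normalize(query[None, :shortlist_dim])
    db_small = l2_normalize(database[:, :shortlist_dim])
    coarse_scores = (db_small @ q_small.T).ravel()
    shortlist = np.argsort(-coarse_scores)[:shortlist_size]

    # Stage 2: rescore only the shortlisted items using the first 1024 values.
    q_large = l2_normalize(query[None, :rerank_dim])
    db_large = l2_normalize(database[shortlist, :rerank_dim])
    fine_scores = (db_large @ q_large.T).ravel()
    order = np.argsort(-fine_scores)[:top_k]
    return shortlist[order]


# Random vectors stand in for 2048-dimensional encoder features (assumption).
rng = np.random.default_rng(0)
database = rng.standard_normal((1000, 2048))
query = database[42] + 0.01 * rng.standard_normal(2048)  # near-duplicate of item 42

top = retrieve_then_rerank(query, database)
print(top)
```

The first stage touches every database row but only 64 values per row; the second stage uses 1024 values but only for the shortlisted rows, which is where the savings come from.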

The implementation is minimal and easy to extend with custom datasets. It shows that MRL combines well with the ReLIC framework and can learn strong representations (see the results below). This repo doesn't depend on a specific self-supervised approach and can be extended to approaches such as [BYOL](https://arxiv.org/abs/2006.07733) or [SimCLR](https://arxiv.org/abs/2002.05709).
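
To illustrate why MRL is largely independent of the underlying self-supervised objective, here is a hedged sketch that wraps an arbitrary pairwise loss and averages it over nested prefixes of the embeddings. It is not this repo's ReLIC loss: the SimCLR-style InfoNCE base loss, the nested dimensions `(8, 16, 32, 64)`, the temperature, and the plain averaging are all assumptions made for the example.

```python
import numpy as np


def info_nce(z1, z2, temperature=0.1):
    """SimCLR-style InfoNCE: row i of z1 should match row i of z2."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # The positives sit on the diagonal of the similarity matrix.
    return -np.mean(np.diag(log_probs))


def matryoshka_loss(z1, z2, dims=(8, 16, 32, 64), base_loss=info_nce):
    """Average a base loss over nested prefixes of two embedding batches.

    `dims` and `base_loss` are hypothetical; any loss that takes two batches
    of embeddings could be plugged in.
    """
    per_dim = [base_loss(z1[:, :d], z2[:, :d]) for d in dims]
    return float(np.mean(per_dim)), per_dim


# Two "views" of the same batch: the second is a noisy copy of the first.
rng = np.random.default_rng(0)
z1 = rng.standard_normal((32, 64))
z2 = z1 + 0.1 * rng.standard_normal((32, 64))

total, per_dim = matryoshka_loss(z1, z2)
print(total, per_dim)
```

Because each term only sees the first `d` values, the optimizer is pushed to place the most useful information in the leading dimensions, which is what makes truncated embeddings usable later.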

# Results

Models are pretrained on the training subsets: 50,000 images for `CIFAR10` and 100,000 images for `STL10`. For evaluation, I trained and tested LogisticRegression on frozen features from:
1. `CIFAR10` - 50,000 train images on ReLIC
2. `STL10` - features were learned on 100k unlabeled images. LogReg was trained on 5k train images and evaluated on 8k test images.

Linear probing was used for evaluation: a scikit-learn LogisticRegression model was trained on features extracted from the frozen encoders. The table below shows training configurations and results when using the full feature dimension. The plots below show results across dimensions.

More detailed evaluation steps and results for [CIFAR10](https://github.com/filipbasara0/matryoshka-representation-learning/blob/main/notebooks/linear-probing-cifar.ipynb) and [STL10](https://github.com/filipbasara0/matryoshka-representation-learning/blob/main/notebooks/linear-probing-stl.ipynb) can be found in the notebooks directory. 

| Evaluation model    | Dataset | Feature Extractor| Encoder   | Feature dim | Projection Head dim | Epochs | Top1 % |
|---------------------|---------|------------------|-----------|-------------|---------------------|--------|--------|
| LogisticRegression  | CIFAR10 | ReLIC            | ResNet-18 | 512         | 64                  | 400    | 84.19  |
| LogisticRegression  | STL10   | ReLIC            | ResNet-18 | 512         | 64                  | 400    | 81.55  |
| LogisticRegression  | STL10   | ReLIC            | ResNet-50 | 2048        | 64                  | 100    | 77.10  |

Below is the performance across dimensions for the ResNet18 model on the CIFAR10 dataset:

![image](https://github.com/filipbasara0/matryoshka-representation-learning/assets/29043871/0a25f5f4-b474-48e1-8314-6eafa90a942a)

Below is the performance across dimensions for the ResNet18 model on the STL10 dataset:

![image](https://github.com/filipbasara0/matryoshka-representation-learning/assets/29043871/e6aa56cc-00df-4bf6-b3c4-7dd7327a63db)
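
An evaluation across dimensions, like the one plotted above, can be sketched by fitting one linear probe per embedding prefix. This is only an illustration under stated assumptions: the synthetic features, the list of dimensions, and the helper name `probe_across_dims` are hypothetical, and the notebooks linked above remain the reference for the actual evaluation.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def probe_across_dims(train_x, train_y, test_x, test_y, dims):
    """Fit one LogisticRegression per embedding prefix and return test accuracy."""
    results = {}
    for d in dims:
        clf = LogisticRegression(max_iter=1000)
        clf.fit(train_x[:, :d], train_y)        # train on the first d features only
        results[d] = clf.score(test_x[:, :d], test_y)
    return results


# Synthetic stand-in for frozen 512-dimensional encoder features (assumption).
rng = np.random.default_rng(0)
num_classes, feat_dim = 10, 512
centers = rng.standard_normal((num_classes, feat_dim))


def make_split(n):
    """Sample noisy points around one random center per class."""
    y = rng.integers(0, num_classes, size=n)
    x = centers[y] + 2.0 * rng.standard_normal((n, feat_dim))
    return x, y


train_x, train_y = make_split(500)
test_x, test_y = make_split(200)

dims = (8, 16, 32, 64, 128, 256, 512)
results = probe_across_dims(train_x, train_y, test_x, test_y, dims)
for d, acc in results.items():
    print(f"dim={d:4d}  top-1 accuracy={acc:.3f}")
```

With real features, `train_x` and `test_x` would be the frozen encoder outputs for the labeled train and test images.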

# Usage

### Installation

```bash
$ pip install mrl-pytorch
```

Code currently supports ResNet18 and ResNet50. Supported datasets are STL10 and CIFAR10.

All training is done from scratch.

### Running Examples
The `CIFAR10` ResNet-18 model was trained with this command:

`mrl_train --dataset_name "cifar10" --encoder_model_name resnet18 --fp16_precision`

The `STL10` ResNet-50 model was trained with this command:

`mrl_train --dataset_name "stl10" --encoder_model_name resnet50 --fp16_precision`
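
If you prefer to launch these runs from a script, the commands above can be assembled programmatically. The sketch below uses only flags shown in this README; the helper name `build_mrl_train_command` and the `DRY_RUN` switch are hypothetical, and nothing is executed unless `DRY_RUN` is set to `False` with the package installed.

```python
import shlex
import subprocess

DRY_RUN = True  # set to False to actually start training


def build_mrl_train_command(dataset_name, encoder_model_name, fp16=True):
    """Assemble an `mrl_train` invocation from flags documented in this README."""
    cmd = [
        "mrl_train",
        "--dataset_name", dataset_name,
        "--encoder_model_name", encoder_model_name,
    ]
    if fp16:
        cmd.append("--fp16_precision")
    return cmd


# The two configurations shown above.
runs = [("cifar10", "resnet18"), ("stl10", "resnet50")]
commands = [build_mrl_train_command(dataset, model) for dataset, model in runs]

for cmd in commands:
    print(shlex.join(cmd))
    if not DRY_RUN:
        subprocess.run(cmd, check=True)  # raises if training exits with an error
```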

### Detailed options
Once the code is set up, run the following command with the options listed below:
`mrl_train [args...]`

```
ReLIC

options:
  -h, --help            show this help message and exit
  --dataset_path DATASET_PATH
                        Path where datasets will be saved
  --dataset_name {stl10,cifar10}
                        Dataset name
  -m {resnet18,resnet50}, --encoder_model_name {resnet18,resnet50}
                        model architecture: resnet18, resnet50 (default: resnet18)
  -save_model_dir SAVE_MODEL_DIR
                        Path where models
  --num_epochs NUM_EPOCHS
                        Number of epochs for training
  -b BATCH_SIZE, --batch_size BATCH_SIZE
                        Batch size
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
  -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
  --fp16_precision      Whether to use 16-bit precision GPU training.
  --proj_out_dim PROJ_OUT_DIM
                        Projector MLP out dimension
  --log_every_n_steps LOG_EVERY_N_STEPS
                        Log every n steps
  --gamma GAMMA         Initial EMA coefficient
  --alpha ALPHA         Regularization loss factor
  --update_gamma_after_step UPDATE_GAMMA_AFTER_STEP
                        Update EMA gamma after this step
  --update_gamma_every_n_steps UPDATE_GAMMA_EVERY_N_STEPS
                        Update EMA gamma after this many steps
```
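
The `--gamma` option is described as an EMA coefficient. In methods of this family, a target network is commonly kept as an exponential moving average of the online network, and the sketch below shows one such update. It is an assumption-laden illustration, not this package's code: the convention that `gamma` weights the old target parameters, the value `0.99`, and the parameter dictionaries are all hypothetical, and the schedule controlled by the `--update_gamma_*` options is not reproduced because the help text does not specify it.

```python
import numpy as np


def ema_update(target_params, online_params, gamma):
    """Return new target parameters: gamma * target + (1 - gamma) * online."""
    return {
        name: gamma * target_params[name] + (1.0 - gamma) * online_params[name]
        for name in target_params
    }


# Toy parameter dictionaries standing in for network weights (assumption).
online = {"w": np.array([1.0, 2.0]), "b": np.array([0.5])}
target = {"w": np.zeros(2), "b": np.zeros(1)}
gamma = 0.99  # hypothetical value, not the package default

for _ in range(3):
    target = ema_update(target, online, gamma)

print(target["w"], target["b"])
```

A `gamma` close to 1 makes the target change slowly, while a smaller value makes it track the online parameters more closely.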

# Citation

```
@misc{kusupati2022matryoshka,
      title={Matryoshka Representation Learning}, 
      author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
      year={2022},
      eprint={2205.13147},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

@misc{mitrovic2020representation,
      title={Representation Learning via Invariant Causal Mechanisms}, 
      author={Jovana Mitrovic and Brian McWilliams and Jacob Walker and Lars Buesing and Charles Blundell},
      year={2020},
      eprint={2010.07922},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```



            
