relic-pytorch


| Field       | Value |
|-------------|-------|
| Name        | relic-pytorch |
| Version     | 0.4.1 |
| Home page   | https://github.com/filipbasara0/relic |
| Summary     | Simple self-supervised contrastive learning based on the ReLIC method |
| Upload time | 2024-02-22 07:36:15 |
| Author      | Filip Basara |
| License     | MIT |
| Keywords    | machine learning, pytorch, self-supervised learning, representation learning, contrastive learning |

![image](https://github.com/filipbasara0/relic/assets/29043871/130c3459-08fc-49c1-a922-43576f2a255c)

# ReLIC

A PyTorch implementation of a computer vision self-supervised learning method based on [Representation Learning via Invariant Causal Mechanisms (ReLIC)](https://arxiv.org/abs/2010.07922).

This simple approach is very similar to [BYOL](https://arxiv.org/abs/2006.07733) and [SimCLR](https://arxiv.org/abs/2002.05709). Training uses an online encoder and a target encoder (an EMA of the online encoder) with a simple critic MLP projector, while the instance discrimination loss function resembles the contrastive loss used in SimCLR. The other half of the loss function acts as a regularizer - it includes an invariance penalty, which forces the representations to stay invariant under data augmentations and amplifies intra-class distances.
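
As a rough illustration of the two-part loss described above, the following pure-Python sketch combines an InfoNCE-style instance discrimination term with a KL-based invariance penalty. It is not the repository's implementation: the function names, the temperature `tau` and the exact form of the penalty are assumptions made for this example, and `alpha` only mirrors the role of the regularization factor.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of a list of floats."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def similarity_log_probs(queries, keys, tau):
    """Row i holds log p(j | i) over all keys, from scaled dot products."""
    return [
        log_softmax([sum(q * k for q, k in zip(qi, kj)) / tau for kj in keys])
        for qi in queries
    ]


def relic_style_loss(online, target, tau=1.0, alpha=1.0):
    """online[i] and target[i] embed two augmentations of image i.

    Embeddings are assumed to be L2-normalised lists of floats.
    """
    n = len(online)
    logp_a = similarity_log_probs(online, target, tau)
    logp_b = similarity_log_probs(target, online, tau)
    # Instance discrimination: each sample should match its own positive.
    contrastive = -sum(logp_a[i][i] for i in range(n)) / n
    # Invariance penalty (assumed form): KL(p_a || p_b), averaged over samples.
    kl = sum(
        math.exp(pa) * (pa - pb)
        for row_a, row_b in zip(logp_a, logp_b)
        for pa, pb in zip(row_a, row_b)
    ) / n
    return contrastive + alpha * kl
```

With `alpha=0` the sketch reduces to a plain contrastive loss. The penalty term is zero when both views induce the same similarity distribution.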

![image](https://github.com/filipbasara0/relic/assets/29043871/70ccdb40-3343-4ea7-946b-80bdc1e7b85d)

The repo includes the multi-crop augmentation used in the follow-up paper [Pushing the limits of self-supervised ResNets: Can we outperform supervised learning without labels on ImageNet? (ReLICv2)](https://arxiv.org/pdf/2201.05119.pdf). The loss function is extended to support an arbitrary number of small (local) and large (global) views. Using this technique generally results in more robust and higher quality representations.
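
One way to picture a loss over an arbitrary number of views is to enumerate the view pairs it is averaged over. The sketch below assumes a common multi-crop convention, in which the target encoder only sees large crops and a view is never paired with itself. The helper names `view_pairs` and `multi_view_loss` are hypothetical, and the pairing rule may differ from what the repository does.

```python
def view_pairs(num_global, num_local):
    """Return (online_view, target_view) index pairs for the loss.

    Views are indexed with globals first: 0..num_global-1 are large crops,
    the rest are small crops. Assumed convention: the target encoder only
    sees large crops, and a view is never paired with itself.
    """
    num_views = num_global + num_local
    return [
        (online_idx, target_idx)
        for target_idx in range(num_global)
        for online_idx in range(num_views)
        if online_idx != target_idx
    ]


def multi_view_loss(pair_loss, num_global, num_local):
    """Average a per-pair loss callable over all selected view pairs."""
    pairs = view_pairs(num_global, num_local)
    return sum(pair_loss(o, t) for o, t in pairs) / len(pairs)
```

Under this convention, 2 global and 4 local views give 2 × (6 − 1) = 10 pairs.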

It also has experimental support for the sigmoid pairwise loss from the [SigLIP](https://arxiv.org/abs/2303.15343) paper. This loss is generally less stable and gives slightly worse metrics, but still yields very good representations.
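
For reference, the sigmoid pairwise loss treats every (i, j) pair as an independent binary classification: matching pairs get label +1 and all other pairs get label -1. The following is a minimal pure-Python sketch of that formulation. The function names and the default `temperature` and `bias` values are placeholders chosen for this example, not the repository's settings.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sigmoid_pairwise_loss(online, target, temperature=1.0, bias=0.0):
    """online[i] and target[i] are embeddings of two views of image i."""
    n = len(online)
    total = 0.0
    for i, zi in enumerate(online):
        for j, zj in enumerate(target):
            logit = temperature * sum(a * b for a, b in zip(zi, zj)) + bias
            label = 1.0 if i == j else -1.0  # positive only on the diagonal
            total += log_sigmoid(label * logit)
    return -total / n
```

Unlike a softmax-based contrastive loss, no normalisation across the batch is needed, since each pair contributes its own term.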


# Results

Models are pretrained on the training subsets: 50,000 images for `CIFAR10` and 100,000 unlabeled images for `STL10`. For evaluation, I trained and tested LogisticRegression on frozen features from the pretrained encoders:
1. `CIFAR10` - features were learned with ReLIC on the 50,000 train images.
2. `STL10` - features were learned on 100k unlabeled images. LogReg was trained on 5k train images and evaluated on 8k test images.

Linear probing was used to evaluate the features extracted from the encoders, using the scikit-learn LogisticRegression model.
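
The linear probing step can be sketched roughly as follows. The synthetic features stand in for frozen encoder outputs, and the feature standardisation and `max_iter` setting are assumptions of this example rather than settings confirmed by the notebooks.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler


def linear_probe(train_x, train_y, test_x, test_y):
    """Fit LogisticRegression on frozen features and return top-1 accuracy in %."""
    scaler = StandardScaler().fit(train_x)
    clf = LogisticRegression(max_iter=1000)
    clf.fit(scaler.transform(train_x), train_y)
    return 100.0 * clf.score(scaler.transform(test_x), test_y)


# Synthetic stand-in for frozen encoder features: 10 well-separated classes.
rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=600)
centers = rng.normal(size=(10, 32)) * 5.0
features = centers[labels] + rng.normal(size=(600, 32))

top1 = linear_probe(features[:500], labels[:500], features[500:], labels[500:])
print(f"top-1: {top1:.2f}%")
```

In a real evaluation, `features` would come from running the frozen encoder over the labeled train and test splits.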

More detailed evaluation steps and results for [CIFAR10](https://github.com/filipbasara0/relic/blob/main/notebooks/linear-probing-cifar.ipynb) and [STL10](https://github.com/filipbasara0/relic/blob/main/notebooks/linear-probing-stl.ipynb) can be found in the notebooks directory. 

| Evaluation model    | Dataset | Feature Extractor| Encoder   | Feature dim | Projection Head dim | Epochs | Top1 % |
|---------------------|---------|------------------|-----------|-------------|---------------------|--------|--------|
| LogisticRegression  | CIFAR10 | ReLIC            | ResNet-18 | 512         | 64                  | 100    | 82.53  |
| LogisticRegression  | STL10   | ReLIC            | ResNet-18 | 512         | 64                  | 100    | 77.12  |
| LogisticRegression  | STL10   | ReLIC            | ResNet-50 | 2048        | 64                  | 100    | 81.95  |

[Here](https://drive.google.com/file/d/1XaZBdvPGPh2nQzzHAJ_oL41c1f8Lc_FN/view?usp=sharing) is a link to a resnet18 encoder trained on the ImageNet-1k subset. This model performs better on both CIFAR10 and STL10.

# Usage

### Installation

```bash
$ pip install relic-pytorch
```

Code currently supports ResNet18, ResNet50 and an experimental version of the EfficientNet model. Supported datasets are STL10 and CIFAR10.

All training is done from scratch.

### Examples
`CIFAR10` ResNet-18 model was trained with this command:

`relic_train --dataset_name "cifar10" --encoder_model_name resnet18 --fp16_precision --gamma 0.99 --alpha 1.0`

`STL10` ResNet-50 model was trained with this command:

`relic_train --dataset_name "stl10" --encoder_model_name resnet50 --fp16_precision`
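
When launching several runs from a script, the command lines above can be assembled programmatically. This is only a convenience sketch: `build_relic_command` is a hypothetical helper, not part of the package, and it uses only flags that appear in this README.

```python
import shlex


def build_relic_command(dataset_name, encoder_model_name, fp16=True, **options):
    """Assemble a relic_train command line from keyword options."""
    args = ["relic_train", "--dataset_name", dataset_name,
            "--encoder_model_name", encoder_model_name]
    if fp16:
        args.append("--fp16_precision")
    # Remaining options are passed through as "--name value" pairs.
    for name, value in options.items():
        args += [f"--{name}", str(value)]
    return args


cmd = build_relic_command("cifar10", "resnet18", gamma=0.99, alpha=1.0)
print(shlex.join(cmd))
# To actually launch training: subprocess.run(cmd, check=True)
```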

### Detailed options
Once the code is set up, run the following command with the options listed below:
`relic_train [args...]`

```
ReLIC

options:
  -h, --help            show this help message and exit
  --dataset_path DATASET_PATH
                        Path where datasets will be saved
  --dataset_name {stl10,cifar10}
                        Dataset name
  -m {resnet18,resnet50,efficientnet}, --encoder_model_name {resnet18,resnet50,efficientnet}
                        model architecture: resnet18, resnet50 or efficientnet (default: resnet18)
  -save_model_dir SAVE_MODEL_DIR
                        Path where models
  --num_epochs NUM_EPOCHS
                        Number of epochs for training
  -b BATCH_SIZE, --batch_size BATCH_SIZE
                        Batch size
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
  -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
  --fp16_precision      Whether to use 16-bit precision GPU training.
  --proj_out_dim PROJ_OUT_DIM
                        Projector MLP out dimension
  --proj_hidden_dim PROJ_HIDDEN_DIM
                        Projector MLP hidden dimension
  --log_every_n_steps LOG_EVERY_N_STEPS
                        Log every n steps
  --gamma GAMMA         Initial EMA coefficient
  --use_siglip          Whether to use siglip loss.
  --alpha ALPHA         Regularization loss factor
  --update_gamma_after_step UPDATE_GAMMA_AFTER_STEP
                        Update EMA gamma after this step
  --update_gamma_every_n_steps UPDATE_GAMMA_EVERY_N_STEPS
                        Update EMA gamma after this many steps
```
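
To make the role of `--gamma` more concrete, here is a minimal sketch of an EMA target update on plain lists of floats. The cosine schedule shown is the BYOL-style one and is only an assumption; the README does not state how the package actually updates gamma, and both function names are hypothetical.

```python
import math


def ema_update(target_params, online_params, gamma):
    """Move each target parameter towards its online counterpart."""
    return [gamma * t + (1.0 - gamma) * o
            for t, o in zip(target_params, online_params)]


def cosine_gamma(step, total_steps, initial_gamma=0.99):
    """BYOL-style schedule (assumed): rises from initial_gamma to 1.0."""
    progress = step / total_steps
    return 1.0 - (1.0 - initial_gamma) * (math.cos(math.pi * progress) + 1.0) / 2.0
```

A gamma close to 1.0 makes the target encoder change slowly, and a gamma of exactly 1.0 freezes it.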

# Citation

```
@misc{mitrovic2020representation,
      title={Representation Learning via Invariant Causal Mechanisms}, 
      author={Jovana Mitrovic and Brian McWilliams and Jacob Walker and Lars Buesing and Charles Blundell},
      year={2020},
      eprint={2010.07922},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

@misc{zhai2023sigmoid,
      title={Sigmoid Loss for Language Image Pre-Training}, 
      author={Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
      year={2023},
      eprint={2303.15343},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```



            
