<div align="center">
<picture>
<img alt="Heron Logo" src="logo.svg" height="25%" width="25%">
</picture>
<br>
<h2>Heron</h2>
[Tests](https://github.com/habedi/heron/actions/workflows/tests.yml)
[Code Coverage](https://codecov.io/gh/habedi/heron)
[CodeFactor](https://www.codefactor.io/repository/github/habedi/heron)
[PyPI](https://pypi.org/project/heron-ssl/)
[Repository](https://github.com/habedi/heron)
[Documentation](https://github.com/habedi/heron/blob/main/docs)
[License](https://github.com/habedi/heron/blob/main/LICENSE)
</div>
---
Heron is a high-level self-supervised learning (SSL) library built on JAX, Flax, and Optax.
It provides abstractions that simplify training and experimenting with SSL algorithms,
so you can focus on the model and data rather than the boilerplate.
Heron aims to be modular and extensible, with clear separation between models, loss functions, data augmentations,
and training strategies.
### Core Features
- **High-Level Trainer API**: A simple `Trainer` class that abstracts away the complexities of JAX's functional
paradigm, including `jit`, `pmap`, and state management.
- **Modular Design**: Easily mix and match backbones, heads, loss functions, and augmentation pipelines.
- **SSL Strategies**: Pre-packaged implementations of popular SSL algorithms, starting with SimCLR.
- **Performance**: Built on JAX and Flax to leverage hardware acceleration on GPUs and TPUs.
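
To give a sense of the boilerplate a high-level trainer hides, here is a minimal sketch of a hand-written training step in plain JAX. It is not Heron's implementation: the linear model, the squared-error loss, and the names `loss_fn` and `train_step` are assumptions chosen only to show `jit` compilation and explicit state passing.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a linear model (a stand-in for an SSL loss)."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # compile the step once; later calls reuse the compiled code
def train_step(params, x, y, lr):
    """Pure function: takes the current parameters and returns new ones."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Parameters are immutable pytrees, so the update builds a new pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Hypothetical toy data: a noiseless linear target.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 4))
y = x @ jnp.array([1.0, -2.0, 0.5, 3.0]) + 0.1

params = {"w": jnp.zeros(4), "b": jnp.zeros(())}
losses = []
for _ in range(200):
    # State is threaded through the loop explicitly instead of being mutated.
    params, loss = train_step(params, x, y, 0.1)
    losses.append(float(loss))
```
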
---
### Feature Roadmap
- [x] Establish core abstractions (`Trainer`, `Backbone`, `ProjectionHead`).
- [x] Implement **SimCLR** as the first end-to-end strategy.
- [ ] Implement a robust data augmentation pipeline for contrastive learning.
- [ ] Initial PyPI release.
- [ ] Implement **BYOL** and **SimSiam** (non-contrastive methods).
- [ ] Add logic for momentum encoders (teacher-student models).
- [ ] Refine `TrainState` management and checkpointing.
- [ ] Implement **DINO** and **MoCo v3**.
- [ ] Add Vision Transformer (ViT) backbones.
- [ ] Implement teacher-student centering and sharpening.
- [ ] Implement masked image modeling strategies (e.g., **MAE**).
- [ ] Add integration with Hugging Face models and datasets.
- [ ] Write comprehensive documentation and tutorials.
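
The momentum encoder item on the roadmap refers to keeping a teacher network whose weights are an exponential moving average (EMA) of the student's weights. The sketch below shows that update over a JAX pytree. The function name `ema_update`, the parameter layout, and the momentum value are illustrative assumptions, not part of Heron's API.

```python
import jax
import jax.numpy as jnp


def ema_update(teacher_params, student_params, momentum):
    """Return new teacher params: momentum * teacher + (1 - momentum) * student."""
    return jax.tree_util.tree_map(
        lambda t, s: momentum * t + (1.0 - momentum) * s,
        teacher_params,
        student_params,
    )


# Hypothetical parameter trees with the same structure for both networks.
student = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones(2)}}
teacher = jax.tree_util.tree_map(jnp.zeros_like, student)

# One update with momentum 0.9 moves the teacher 10% of the way to the student.
teacher = ema_update(teacher, student, momentum=0.9)
```
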
---
### Installation
```shell
pip install heron-ssl
```
### Quick Start
Here is a conceptual example of how to use the `Trainer` API.
```python
import heron_ssl as ssl
import tensorflow_datasets as tfds
import optax
# 1. Load a dataset
dataset = tfds.load('cifar10', split='train')
# 2. Define the model and SSL strategy
strategy = ssl.strategies.SimCLR(
    backbone=ssl.models.ResNet50(),
    projector=ssl.models.ProjectionHead(hidden_dims=[2048], output_dim=128),
)
# 3. Configure and run the trainer
trainer = ssl.Trainer(
    strategy=strategy,
    optimizer=optax.adam(1e-3),
)
# 4. Start training
trained_backbone_params = trainer.fit(dataset)
```
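For context on what a SimCLR strategy optimizes, the following is a minimal sketch of the NT-Xent contrastive loss in plain JAX. It is an illustration under assumptions, not Heron's own code: the function name `nt_xent_loss`, the default temperature of 0.5, and the toy embeddings are all hypothetical.

```python
import jax
import jax.numpy as jnp


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent loss for two batches of embeddings, each of shape (N, D)."""
    n = z1.shape[0]
    z = jnp.concatenate([z1, z2], axis=0)
    z = z / jnp.linalg.norm(z, axis=1, keepdims=True)  # unit norm -> cosine similarity
    sim = (z @ z.T) / temperature  # pairwise similarities, shape (2N, 2N)
    # Mask self-similarity so a sample never competes with itself.
    sim = jnp.where(jnp.eye(2 * n, dtype=bool), -1e9, sim)
    # The positive for row i is the other augmented view of the same sample.
    pos_index = jnp.concatenate([jnp.arange(n, 2 * n), jnp.arange(0, n)])
    log_prob = jax.nn.log_softmax(sim, axis=1)
    return -jnp.mean(log_prob[jnp.arange(2 * n), pos_index])


# Toy embeddings standing in for projection-head outputs of two views.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
z1 = jax.random.normal(key1, (8, 16))
z2_aligned = z1 + 0.01 * jax.random.normal(key2, (8, 16))  # views that agree
z2_random = jax.random.normal(key2, (8, 16))  # unrelated views

loss_aligned = nt_xent_loss(z1, z2_aligned)
loss_random = nt_xent_loss(z1, z2_random)
```

When the two views of each sample map to nearby embeddings, the loss is lower than when the pairing is arbitrary, which is the signal that pulls the backbone toward augmentation-invariant features.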
### Contributing
Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to get started.
### License
Heron is licensed under the MIT License ([LICENSE](LICENSE)).
### Acknowledgements
* The logo is from [SVG Repo](https://www.svgrepo.com/svg/452899/heron).
Raw data
{
"_id": null,
"home_page": null,
"name": "heron-ssl",
"maintainer": "Hassan Abedi",
"docs_url": null,
"requires_python": "<4.0,>=3.10",
"maintainer_email": "hassan.abedi.t+heron@gmail.com",
"keywords": "self-supervised-learning, pytorch, deep-learning, ssl",
"author": "Hassan Abedi",
"author_email": "hassan.abedi.t+heron@gmail.com",
"download_url": "https://files.pythonhosted.org/packages/33/76/e720e70cc26bd1c91d47480ac1579ec090488476e5052f3d323123724861/heron_ssl-0.1.0a1.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "A high-level self-supervised learning library on top of JAX",
"version": "0.1.0a1",
"project_urls": {
"Documentation": "https://github.com/habedi/heron/blob/main/docs/index.md",
"Repository": "https://github.com/habedi/heron"
},
"split_keywords": [
"self-supervised-learning",
" pytorch",
" deep-learning",
" ssl"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "f535427bcf1f467a1fff8679c504b633d1c97516044c85ee31399bf578cfb037",
"md5": "fc7e0833796e5aa3e4679c365652c760",
"sha256": "b540864bbcdabdae8a49db8528b2463f925061c7159ec935a816a73bbed122ae"
},
"downloads": -1,
"filename": "heron_ssl-0.1.0a1-py3-none-any.whl",
"has_sig": false,
"md5_digest": "fc7e0833796e5aa3e4679c365652c760",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": "<4.0,>=3.10",
"size": 7545,
"upload_time": "2025-08-24T15:24:22",
"upload_time_iso_8601": "2025-08-24T15:24:22.666469Z",
"url": "https://files.pythonhosted.org/packages/f5/35/427bcf1f467a1fff8679c504b633d1c97516044c85ee31399bf578cfb037/heron_ssl-0.1.0a1-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "3376e720e70cc26bd1c91d47480ac1579ec090488476e5052f3d323123724861",
"md5": "34d13993de7bcf10c0370597b7af497e",
"sha256": "a13e40fcee3adac4deb3c6ea9d0a165651139ce7aaf17dc0fb0dd659337534bb"
},
"downloads": -1,
"filename": "heron_ssl-0.1.0a1.tar.gz",
"has_sig": false,
"md5_digest": "34d13993de7bcf10c0370597b7af497e",
"packagetype": "sdist",
"python_version": "source",
"requires_python": "<4.0,>=3.10",
"size": 5662,
"upload_time": "2025-08-24T15:24:24",
"upload_time_iso_8601": "2025-08-24T15:24:24.161705Z",
"url": "https://files.pythonhosted.org/packages/33/76/e720e70cc26bd1c91d47480ac1579ec090488476e5052f3d323123724861/heron_ssl-0.1.0a1.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-08-24 15:24:24",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "habedi",
"github_project": "heron",
"github_not_found": true,
"lcname": "heron-ssl"
}