<div align="center">
<a href="https://badge.fury.io/py/balanced-loss"><img src="https://badge.fury.io/py/balanced-loss.svg" alt="pypi version"></a>
</div>
<p align="center">
<img src="https://user-images.githubusercontent.com/34196005/180311379-1003da44-cdf9-46e8-af83-e65fbc3710cd.png" width="350">
</p>
<p align="center">
Easy-to-use class-balanced cross-entropy and focal loss implementation for PyTorch.
</p>
## Theory
When the labels in a training dataset are imbalanced, one remedy is to reweight the loss per class so that frequent classes do not dominate training.
- First, the effective number of samples is calculated for each class as:
![alt-text](https://user-images.githubusercontent.com/34196005/180266195-aa2e8696-cdeb-48ed-a85f-7ffb353942a4.png)
- Then the class-balanced loss function is defined as:
![alt-text](https://user-images.githubusercontent.com/34196005/180266198-e27d8cba-f5e1-49ca-9f82-d8656333e3c4.png)
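In case the images above do not load, the two formulas can be written out as below. This is a transcription from the referenced paper (arXiv:1901.05555), not from the images themselves, so the notation may differ slightly: `n` is the number of training samples of a class, `n_y` the count for the class of label `y`, `\beta` a hyperparameter in `[0, 1)`, and `\mathcal{L}` the base loss (cross-entropy, focal loss, and so on).

```latex
% Effective number of samples for a class with n training samples, 0 <= beta < 1
\[
E_{n} = \frac{1 - \beta^{n}}{1 - \beta}
\]
% Class-balanced loss for a prediction p with label y: the base loss scaled by 1 / E_{n_y}
\[
\mathrm{CB}(\mathbf{p}, y) = \frac{1}{E_{n_y}} \, \mathcal{L}(\mathbf{p}, y)
= \frac{1 - \beta}{1 - \beta^{n_y}} \, \mathcal{L}(\mathbf{p}, y)
\]
```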
## Installation
```bash
pip install balanced-loss
```
## Usage
- Standard losses:
```python
import torch
from balanced_loss import Loss
# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # one label per sample
# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)
```
```python
# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)
```
```python
# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
```
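To make the difference between the first two loss types concrete, here is a minimal pure-Python sketch of the general focal loss form `-(1 - p_t)^gamma * log(p_t)` applied to softmax probabilities. The helpers `reference_cross_entropy` and `reference_focal_loss` are hypothetical names used only for illustration; the formulation inside `balanced_loss` may differ (for example, it may be sigmoid-based), so the numbers printed here are not expected to match the package output.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reference_cross_entropy(logits, label):
    """Plain cross-entropy: -log(p_t), where p_t is the probability of the true class."""
    return -math.log(softmax(logits)[label])


def reference_focal_loss(logits, label, gamma=2.0):
    """Focal loss: cross-entropy scaled by (1 - p_t)^gamma (illustrative, not the package's code)."""
    p_t = softmax(logits)[label]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


logits = [0.78, 0.1, 0.05]  # same values as the README example
label = 0

print("cross-entropy:", reference_cross_entropy(logits, label))
print("focal (gamma=2):", reference_focal_loss(logits, label, gamma=2.0))
```

With `gamma=0` the modulating factor is 1 and the focal loss reduces to cross-entropy; larger `gamma` shrinks the loss of examples that are already classified with high confidence.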
- Class-balanced losses:
```python
import torch
from balanced_loss import Loss
# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # one label per sample
# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively
# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
```
```python
# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = ce_loss(logits, labels)
```
```python
# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = bce_loss(logits, labels)
```
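The examples above hard-code `samples_per_class`. In practice it is usually derived from the training labels; a minimal sketch of one way to do that is shown below. The helper `count_samples_per_class` and the toy `train_labels` list are hypothetical and not part of `balanced_loss`; the sketch assumes integer labels in the range `0 .. num_classes - 1`.

```python
from collections import Counter


def count_samples_per_class(labels, num_classes):
    """Return a list where entry c is the number of occurrences of label c."""
    counts = Counter(labels)
    # Counter returns 0 for labels that never occur, so every class gets an entry.
    return [counts[c] for c in range(num_classes)]


# Toy training labels (illustrative only)
train_labels = [0, 1, 1, 2, 1, 0, 1]
samples_per_class = count_samples_per_class(train_labels, num_classes=3)
print(samples_per_class)  # [2, 4, 1]
```

Note that a class with zero samples would make the class-balanced factor `(1 - beta) / (1 - beta**n)` divide by zero, so such classes likely need to be handled before the counts are passed on.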
- Customize parameters:
```python
import torch
from balanced_loss import Loss
# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # one label per sample
# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively
# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999, # class-balanced loss beta
    fl_gamma=2, # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
```
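To get a feel for what `beta` does, the sketch below computes the per-class factor `(1 - beta) / (1 - beta**n)` from the Theory section for several values of `beta`. The helper `class_balanced_weights` is a hypothetical name, and rescaling the weights so they sum to the number of classes is an assumption made here for readability, not a confirmed detail of how `balanced_loss` normalizes them.

```python
def class_balanced_weights(samples_per_class, beta):
    """Per-class weights (1 - beta) / (1 - beta**n), rescaled to sum to the number of classes.

    The rescaling step is an assumption for this illustration.
    """
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


samples_per_class = [30, 100, 25]  # same counts as the README example

for beta in (0.0, 0.9, 0.99, 0.999, 0.9999):
    weights = class_balanced_weights(samples_per_class, beta)
    print(f"beta={beta}: {[round(w, 3) for w in weights]}")
```

With `beta=0` every class gets the same weight (no rebalancing), and as `beta` approaches 1 the weights approach inverse class frequency, so the rarest class (25 samples) is weighted most and the most frequent class (100 samples) least.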
## Improvements
How does this repo differ from vandit15's?
- This repo is a PyPI-installable package
- This repo implements the loss functions as `torch.nn.Module`
- In addition to class-balanced losses, this repo also supports the standard (unbalanced) versions of cross-entropy, focal loss, etc. through the same API
- All typos and errors in vandit15's source are fixed
## References
- Paper: https://arxiv.org/abs/1901.05555
- richardaecn/class-balanced-loss: https://github.com/richardaecn/class-balanced-loss
- vandit15/Class-balanced-loss-pytorch: https://github.com/vandit15/Class-balanced-loss-pytorch
Raw data
{
"_id": null,
"home_page": "https://github.com/fcakyon/balanced-loss",
"name": "balanced-loss",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.7",
"maintainer_email": "",
"keywords": "machine-learning,deep-learning,ml,pytorch,vision,loss,image-classification,video-classification",
"author": "",
"author_email": "",
"download_url": "https://files.pythonhosted.org/packages/26/4a/7fbab9ae35b9c490fbfe574c2247dfb1af32bba438c59443bb26ae983403/balanced-loss-0.1.0.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "Easy to use class-balanced cross-entropy and focal loss implementation for Pytorch.",
"version": "0.1.0",
"split_keywords": [
"machine-learning",
"deep-learning",
"ml",
"pytorch",
"vision",
"loss",
"image-classification",
"video-classification"
],
"urls": [
{
"comment_text": "",
"digests": {
"md5": "1d191e256902fab2fadbe25a3962a7cb",
"sha256": "9504e5d52dc3773d701f0af07090e470b155eb77060fd00c1b0ac6fbff68f10c"
},
"downloads": -1,
"filename": "balanced_loss-0.1.0-py3-none-any.whl",
"has_sig": false,
"md5_digest": "1d191e256902fab2fadbe25a3962a7cb",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.7",
"size": 5229,
"upload_time": "2022-07-21T21:04:10",
"upload_time_iso_8601": "2022-07-21T21:04:10.624004Z",
"url": "https://files.pythonhosted.org/packages/7e/a7/171d43fae753004d156b008d9db32458c487203df888841c5b2bc4f3f310/balanced_loss-0.1.0-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"md5": "9b030f64c65a8cf2e789768ac4890b9d",
"sha256": "e55957cda998ed84963b8aa9c4f32456a7edb4fd94a5938d17604bb7763dff07"
},
"downloads": -1,
"filename": "balanced-loss-0.1.0.tar.gz",
"has_sig": false,
"md5_digest": "9b030f64c65a8cf2e789768ac4890b9d",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.7",
"size": 5402,
"upload_time": "2022-07-21T21:04:12",
"upload_time_iso_8601": "2022-07-21T21:04:12.854175Z",
"url": "https://files.pythonhosted.org/packages/26/4a/7fbab9ae35b9c490fbfe574c2247dfb1af32bba438c59443bb26ae983403/balanced-loss-0.1.0.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2022-07-21 21:04:12",
"github": true,
"gitlab": false,
"bitbucket": false,
"github_user": "fcakyon",
"github_project": "balanced-loss",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"requirements": [
{
"name": "torch",
"specs": []
},
{
"name": "numpy",
"specs": []
},
{
"name": "click",
"specs": [
[
"==",
"8.0.4"
]
]
}
],
"lcname": "balanced-loss"
}