unified-focal-loss-pytorch

Name: unified-focal-loss-pytorch
Version: 0.1.2
Summary: An implementation of loss functions from "Unified Focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation"
Upload time: 2023-08-18 23:11:37
Author: Taylor Denouden
Requires Python: >=3.9,<4.0
License: MIT
Requirements: none recorded

# Unified Focal Loss PyTorch

An implementation of the loss functions
from [“Unified Focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation”][1].

Extended to support multiclass classification and an optional ignore index.

*Note: This implementation is not tested against the original implementation. It varies
from the original implementation based on my own interpretation of the paper.*
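
For orientation, here is a minimal PyTorch sketch of an asymmetric unified focal loss, assuming the formulation from the paper and the reference code linked above: a `weight`-mixed sum of an asymmetric focal loss, which applies the focal modulating factor only to the background class, and an asymmetric focal Tversky loss, which applies focal enhancement only to the foreground classes. It is not this package's implementation and does not use its API. The function name, the illustrative default values, treating class 0 as background and every other class as foreground, and the lack of ignore-index handling are all assumptions of this sketch.

```python
import torch
import torch.nn.functional as F


def asymmetric_unified_focal_loss_sketch(
    probs: torch.Tensor,
    targets: torch.Tensor,
    weight: float = 0.5,
    delta: float = 0.6,
    gamma: float = 0.5,
    eps: float = 1e-7,
) -> torch.Tensor:
    """Hypothetical sketch, not this package's API.

    probs:   softmax probabilities, shape (N, C, *spatial) with at least one spatial dim.
    targets: integer class labels, shape (N, *spatial).
    Assumption: class 0 is background; classes 1..C-1 are all treated as foreground.
    """
    assert probs.ndim >= 3, "expects at least one spatial dimension"
    num_classes = probs.shape[1]
    # One-hot encode the labels and move the class axis to dim 1: (N, C, *spatial).
    one_hot = F.one_hot(targets.long(), num_classes).movedim(-1, 1).to(probs.dtype)
    probs = probs.clamp(eps, 1.0 - eps)  # keep log() and pow() finite

    # Asymmetric focal loss (distribution-based term).
    ce = -one_hot * torch.log(probs)  # per-class cross entropy, nonzero only at the true class
    back_ce = (1 - delta) * (1 - probs[:, :1]) ** gamma * ce[:, :1]  # background: focal suppression
    fore_ce = delta * ce[:, 1:]  # foreground: no suppression
    focal = torch.cat([back_ce, fore_ce], dim=1).sum(dim=1).mean()

    # Asymmetric focal Tversky loss (region-based term), one score per sample and class.
    spatial = tuple(range(2, probs.ndim))
    tp = (one_hot * probs).sum(dim=spatial)
    fn = (one_hot * (1 - probs)).sum(dim=spatial)
    fp = ((1 - one_hot) * probs).sum(dim=spatial)
    tversky = (tp + eps) / (tp + delta * fn + (1 - delta) * fp + eps)  # (N, C)
    back_tl = 1 - tversky[:, :1]  # background: plain Tversky loss
    fore_tl = (1 - tversky[:, 1:]) ** (1 - gamma)  # foreground: focal enhancement
    focal_tversky = torch.cat([back_tl, fore_tl], dim=1).mean()

    # weight mixes the region-based and distribution-based terms.
    return weight * focal_tversky + (1 - weight) * focal
```

In this sketch, a class that is absent from a sample gets a Tversky score near zero (so a loss term near one) for that sample; multiclass extensions may reasonably handle that case differently.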

[1]: https://github.com/mlyg/unified-focal-loss

## Installation

```bash
pip install unified-focal-loss-pytorch
```

## Usage

```python
import torch
import torch.nn.functional as F
from unified_focal_loss import AsymmetricUnifiedFocalLoss

loss_fn = AsymmetricUnifiedFocalLoss(
    delta=0.7,
    gamma=0.5,
    ignore_index=2,
)

logits = torch.tensor([
    [[0.1000, 0.4000],
     [0.2000, 0.5000],
     [0.3000, 0.6000]],

    [[0.7000, 0.0000],
     [0.8000, 0.1000],
     [0.9000, 0.2000]]
])

# Shape should be (batch_size, num_classes, ...)
probs = F.softmax(logits, dim=1)
# Shape should be (batch_size, ...). Not one-hot encoded.
targets = torch.tensor([
    [0, 1],
    [2, 0],
])

loss = loss_fn(probs, targets)
print(loss)
# >>> tensor(0.6737)
```
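
The usage example passes integer targets rather than one-hot labels, together with `ignore_index=2`. One common way for a segmentation loss to handle both is to one-hot encode the labels and keep a mask of valid pixels. The sketch below shows that idea with a hypothetical helper, `one_hot_with_ignore`, which is not part of this package; the package's own handling may differ.

```python
import torch
import torch.nn.functional as F


def one_hot_with_ignore(targets: torch.Tensor, num_classes: int, ignore_index: int):
    """Hypothetical helper: labels (N, ...) -> one-hot (N, C, ...) and a valid-pixel mask (N, 1, ...).

    Pixels labelled `ignore_index` get an all-zero one-hot vector and a mask value of 0.
    """
    valid = targets != ignore_index  # (N, ...) bool
    # Swap ignored labels for a legal class id so F.one_hot does not fail
    # when ignore_index lies outside [0, num_classes), e.g. 255.
    safe = torch.where(valid, targets, torch.zeros_like(targets))
    one_hot = F.one_hot(safe.long(), num_classes).movedim(-1, 1)  # (N, C, ...)
    mask = valid.unsqueeze(1)  # (N, 1, ...), broadcasts over the class axis
    return (one_hot * mask).to(torch.float32), mask.to(torch.float32)


targets = torch.tensor([[0, 1], [2, 0]])  # same targets as the usage example
one_hot, mask = one_hot_with_ignore(targets, num_classes=3, ignore_index=2)
print(mask.squeeze(1).tolist())  # [[1.0, 1.0], [0.0, 1.0]]
```

Multiplying `probs` by `mask` before computing region-based sums (true positives, false positives, false negatives) keeps ignored pixels from counting toward any class, and the same mask can zero out a per-pixel cross-entropy term before averaging over the valid pixels.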

## Detailed API Reference
See [API docs](docs/api.md).

## License
See [LICENSE](LICENSE).

            

Raw data

{
    "_id": null,
    "home_page": "",
    "name": "unified-focal-loss-pytorch",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.9,<4.0",
    "maintainer_email": "",
    "keywords": "",
    "author": "Taylor Denouden",
    "author_email": "taylordenouden@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/6a/dc/581e25141b9068933bf5edad4c790ebe34868fd82fb19ff2a81173b6a5a8/unified_focal_loss_pytorch-0.1.2.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "An implementation of loss functions from \"Unified Focal loss: Generalising Dice and cross entropy-based losses to handle class imbalanced medical image segmentation\"",
    "version": "0.1.2",
    "project_urls": null,
    "split_keywords": [],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "fe15756875ed41bd147d7e7a9eee5bb76429b83ce3867c0cfff4d1b0cda04766",
                "md5": "0892020a484a22a59ab6094dcc147cbc",
                "sha256": "0631b618d10c9537a0385f618c78ad6f1ecec8d68489f9812d2f76a54c5b73b6"
            },
            "downloads": -1,
            "filename": "unified_focal_loss_pytorch-0.1.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "0892020a484a22a59ab6094dcc147cbc",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9,<4.0",
            "size": 6368,
            "upload_time": "2023-08-18T23:11:35",
            "upload_time_iso_8601": "2023-08-18T23:11:35.595132Z",
            "url": "https://files.pythonhosted.org/packages/fe/15/756875ed41bd147d7e7a9eee5bb76429b83ce3867c0cfff4d1b0cda04766/unified_focal_loss_pytorch-0.1.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "6adc581e25141b9068933bf5edad4c790ebe34868fd82fb19ff2a81173b6a5a8",
                "md5": "d8ba46ff13474867cddd6f4cbd2a159f",
                "sha256": "d43be8e91943bd951ee7353e6a64b07a296373827de2374cbb0126ed8d30ab01"
            },
            "downloads": -1,
            "filename": "unified_focal_loss_pytorch-0.1.2.tar.gz",
            "has_sig": false,
            "md5_digest": "d8ba46ff13474867cddd6f4cbd2a159f",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9,<4.0",
            "size": 6024,
            "upload_time": "2023-08-18T23:11:37",
            "upload_time_iso_8601": "2023-08-18T23:11:37.501713Z",
            "url": "https://files.pythonhosted.org/packages/6a/dc/581e25141b9068933bf5edad4c790ebe34868fd82fb19ff2a81173b6a5a8/unified_focal_loss_pytorch-0.1.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-08-18 23:11:37",
    "github": false,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "lcname": "unified-focal-loss-pytorch"
}
        