# fastmtl
> Multi-task learning utilities for fastai
## Install
`pip install fastmtl`
## Usage
### Loss
Apply a separate loss function to each model output and combine the results as a weighted sum. For instance, if the first model output is for classification and the second is for regression:
```py
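from fastai.losses import CrossEntropyLossFlat, MSELossFlat
from fastmtl.loss import CombinedLoss

# one loss function per model output, with the weights listed in the same order
loss_func = CombinedLoss(CrossEntropyLossFlat(), MSELossFlat(), weight=[1.0, 3.0])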
```
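The idea of a weighted combined loss can be sketched without any framework. This is an illustrative sketch only, not fastmtl's actual implementation: `weighted_sum_loss` and the two toy loss functions are hypothetical names, and the sketch assumes that model outputs and targets arrive as sequences with one entry per task.

```python
from typing import Callable, Sequence


def weighted_sum_loss(loss_funcs: Sequence[Callable], weights: Sequence[float]) -> Callable:
    """Return a loss that applies loss_funcs[i] to (outputs[i], targets[i])
    and sums the results, each scaled by weights[i]."""
    if len(loss_funcs) != len(weights):
        raise ValueError("need exactly one weight per loss function")

    def combined(outputs: Sequence, targets: Sequence):
        if not (len(outputs) == len(targets) == len(loss_funcs)):
            raise ValueError("need exactly one output and one target per loss function")
        return sum(
            w * f(o, t)
            for f, w, o, t in zip(loss_funcs, weights, outputs, targets)
        )

    return combined


# Toy stand-ins for real loss functions (hypothetical, for illustration only)
def zero_one_loss(pred, target):
    return 0.0 if pred == target else 1.0


def squared_error(pred, target):
    return (pred - target) ** 2


loss_func = weighted_sum_loss([zero_one_loss, squared_error], weights=[1.0, 3.0])
total = loss_func(outputs=("cat", 2.0), targets=("dog", 4.0))
print(total)  # 1.0 * 1.0 + 3.0 * (2.0 - 4.0) ** 2 = 13.0
```

Because the sketch only multiplies and adds, the same structure would also work when the individual losses return tensors instead of floats.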
### Metric
Apply metrics to each model output separately. For instance, if a model performs both classification and regression, each output can be evaluated with the metrics relevant to it. Assuming the model outputs a tuple of tensors for classification and regression, respectively:
```py
from fastai.learner import Learner
from fastai.metrics import F1Score, R2Score
from fastmtl.metric import mtl_metrics

# give each metric a distinct name so the columns are distinguishable in the training log
clf_f1_macro = F1Score(average='macro')
clf_f1_macro.name = 'clf_f1(macro)'
clf_f1_micro = F1Score(average='micro')
clf_f1_micro.name = 'clf_f1(micro)'

reg_r2 = R2Score()
reg_r2.name = 'reg_r2'

# metrics for classification in the first list
# metrics for regression in the second list
metrics = mtl_metrics([clf_f1_macro, clf_f1_micro], [reg_r2])

learn = Learner(
    ...
    metrics=metrics,
)
```
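The routing of metrics to model outputs can be pictured with a small framework-free sketch. This is not how `mtl_metrics` is implemented: `route_metrics` and the toy metrics below are hypothetical, and the sketch assumes that outputs and targets are tuples with one entry per task, in the same order as the metric lists.

```python
from typing import Callable, List, Sequence


def route_metrics(*metric_groups: Sequence[Callable]) -> List[Callable]:
    """Flatten per-task metric groups into a single list of metrics.
    A metric from group i only ever sees (outputs[i], targets[i])."""

    def bind(metric: Callable, index: int) -> Callable:
        def routed(outputs, targets):
            return metric(outputs[index], targets[index])

        # keep the original name so results stay readable
        routed.__name__ = getattr(metric, "__name__", f"metric_{index}")
        return routed

    return [bind(m, i) for i, group in enumerate(metric_groups) for m in group]


# Toy metrics (hypothetical, for illustration only)
def accuracy(preds, targets):
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def mean_abs_error(preds, targets):
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


# classification metrics in the first list, regression metrics in the second
metrics = route_metrics([accuracy], [mean_abs_error])

outputs = (["a", "b", "b", "a"], [1.0, 2.0, 3.0, 5.0])
targets = (["a", "b", "a", "a"], [1.0, 2.0, 3.0, 4.0])
scores = {m.__name__: m(outputs, targets) for m in metrics}
print(scores)  # {'accuracy': 0.75, 'mean_abs_error': 0.25}
```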
## Tutorials
[Video distortion detection](https://bdsaglam.github.io/fastmtl/tutorial.vqa)
## TODO
- [ ] Support tabular learner
- [ ] Support fastai>=2.7
Raw data
{
"_id": null,
"home_page": "https://github.com/bdsaglam/fastmtl/tree/master/",
"name": "fastmtl",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.7",
"maintainer_email": "",
"keywords": "deep learning,multi-task learning,fastai,pytorch",
"author": "Bar\u0131\u015f Deniz Sa\u011flam",
"author_email": "bdsaglam@gmail.com",
"download_url": "https://files.pythonhosted.org/packages/f3/86/8c7cbdef0c43e1e72af32f15ad22f83ac8e834627fec029c773f5d67ddd2/fastmtl-1.1.0.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "Apache Software License 2.0",
"summary": "Multi-task learning utilities for fastai",
"version": "1.1.0",
"split_keywords": [
"deep learning",
"multi-task learning",
"fastai",
"pytorch"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "f3868c7cbdef0c43e1e72af32f15ad22f83ac8e834627fec029c773f5d67ddd2",
"md5": "26054e2728635c8334bd0a56226e5fab",
"sha256": "a223f1c57c89bd173e2e68a2af1acc98b0977843c12f46b1c5ea66d560fd783b"
},
"downloads": -1,
"filename": "fastmtl-1.1.0.tar.gz",
"has_sig": false,
"md5_digest": "26054e2728635c8334bd0a56226e5fab",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.7",
"size": 10034,
"upload_time": "2023-01-22T20:15:06",
"upload_time_iso_8601": "2023-01-22T20:15:06.594228Z",
"url": "https://files.pythonhosted.org/packages/f3/86/8c7cbdef0c43e1e72af32f15ad22f83ac8e834627fec029c773f5d67ddd2/fastmtl-1.1.0.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-01-22 20:15:06",
"github": false,
"gitlab": false,
"bitbucket": false,
"lcname": "fastmtl"
}