auraloss

Name: auraloss
Version: 0.4.0
Home page: https://github.com/csteinmetz1/auraloss
Summary: Collection of audio-focused loss functions in PyTorch.
Author: Christian Steinmetz
Requires Python: >=3.6.0
License: Apache License 2.0
Upload time: 2023-04-21 09:21:46

<div  align="center">

# auraloss

<img width="200px" src="docs/auraloss-logo.svg">

A collection of audio-focused loss functions in PyTorch. 

[[PDF](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf)]

</div>

## Setup

```
pip install auraloss
```

If you want to use `MelSTFTLoss()` or `FIRFilter()`, you will need to install the extras (librosa and scipy).

```
pip install auraloss[all]
```

## Usage

```python
import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss()

input = torch.rand(8, 1, 44100)   # shape: (batch, channels, samples)
target = torch.rand(8, 1, 44100)

loss = mrstft(input, target)
```

**NEW**: Perceptual weighting with mel scaled spectrograms.

```python
import torch
import auraloss

bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100

# some audio you want to compare
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)

# define the loss function
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    scale="mel",
    n_bins=128,
    sample_rate=sample_rate,
    perceptual_weighting=True,
)

# compute
loss = loss_fn(pred, target)
```

## Citation
If you use this code in your work, please consider citing us.
```bibtex
@inproceedings{steinmetz2020auraloss,
    title={auraloss: {A}udio focused loss functions in {PyTorch}},
    author={Steinmetz, Christian J. and Reiss, Joshua D.},
    booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
    year={2020}
}
```


# Loss functions

We categorize the loss functions as either time-domain or frequency-domain approaches. 
Additionally, we include perceptual transforms.

<table>
    <tr>
        <th>Loss function</th>
        <th>Interface</th>
        <th>Reference</th>
    </tr>
    <tr>
        <td colspan="3" align="center"><b>Time domain</b></td>
    </tr>
    <tr>
        <td>Error-to-signal ratio (ESR)</td>
        <td><code>auraloss.time.ESRLoss()</code></td>
        <td><a href=https://arxiv.org/abs/1911.08922>Wright & Välimäki, 2019</a></td>
    </tr>
    <tr>
        <td>DC error (DC)</td>
        <td><code>auraloss.time.DCLoss()</code></td>
        <td><a href=https://arxiv.org/abs/1911.08922>Wright & Välimäki, 2019</a></td>
    </tr>
    <tr>
        <td>Log hyperbolic cosine (Log-cosh)</td>
        <td><code>auraloss.time.LogCoshLoss()</code></td>
        <td><a href=https://openreview.net/forum?id=rkglvsC9Ym>Chen et al., 2019</a></td>
    </tr>
    <tr>
        <td>Signal-to-noise ratio (SNR)</td>
        <td><code>auraloss.time.SNRLoss()</code></td>
        <td></td>
    </tr>
    <tr>
        <td>Scale-invariant signal-to-distortion <br>  ratio (SI-SDR)</td>
        <td><code>auraloss.time.SISDRLoss()</code></td>
        <td><a href=https://arxiv.org/abs/1811.02508>Le Roux et al., 2018</a></td>
    </tr>
    <tr>
        <td>Scale-dependent signal-to-distortion <br>  ratio (SD-SDR)</td>
        <td><code>auraloss.time.SDSDRLoss()</code></td>
        <td><a href=https://arxiv.org/abs/1811.02508>Le Roux et al., 2018</a></td>
    </tr>
    <tr>
        <td colspan="3" align="center"><b>Frequency domain</b></td>
    </tr>
    <tr>
        <td>Aggregate STFT</td>
        <td><code>auraloss.freq.STFTLoss()</code></td>
        <td><a href=https://arxiv.org/abs/1808.06719>Arik et al., 2018</a></td>
    </tr>
    <tr>
        <td>Aggregate Mel-scaled STFT</td>
        <td><code>auraloss.freq.MelSTFTLoss(sample_rate)</code></td>
        <td></td>
    </tr>
    <tr>
        <td>Multi-resolution STFT</td>
        <td><code>auraloss.freq.MultiResolutionSTFTLoss()</code></td>
        <td><a href=https://arxiv.org/abs/1910.11480>Yamamoto et al., 2019*</a></td>
    </tr>
    <tr>
        <td>Random-resolution STFT</td>
        <td><code>auraloss.freq.RandomResolutionSTFTLoss()</code></td>
        <td><a href=https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf>Steinmetz & Reiss, 2020</a></td>
    </tr>
    <tr>
        <td>Sum and difference STFT loss</td>
        <td><code>auraloss.freq.SumAndDifferenceSTFTLoss()</code></td>
        <td><a href=https://arxiv.org/abs/2010.10291>Steinmetz et al., 2020</a></td>
    </tr>
    <tr>
        <td colspan="3" align="center"><b>Perceptual transforms</b></td>
    </tr>
    <tr>
        <td>Sum and difference signal transform</td>
        <td><code>auraloss.perceptual.SumAndDifference()</code></td>
        <td><a href=#></a></td>
    </tr>
    <tr>
        <td>FIR pre-emphasis filters</td>
        <td><code>auraloss.perceptual.FIRFilter()</code></td>
        <td><a href=https://arxiv.org/abs/1911.08922>Wright & Välimäki, 2019</a></td>
    </tr>
</table>

\* [Wang et al., 2019](https://arxiv.org/abs/1904.12088) also propose a multi-resolution spectral loss (that [Engel et al., 2020](https://arxiv.org/abs/2001.04643) follow), 
but they do not include both the log magnitude (L1 distance) and spectral convergence terms, introduced in [Arik et al., 2018](https://arxiv.org/abs/1808.06719) and then extended to the multi-resolution case in [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480).
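As a rough illustration of these two terms, the sketch below implements a single-resolution STFT loss from scratch in plain PyTorch (assuming a Hann window and unit weights; auraloss's actual implementation differs in its details and configurable weighting):

```python
import torch

def stft_loss(pred, target, fft_size=1024, hop_length=256, win_length=1024):
    """Sketch of an STFT loss: spectral convergence + log-magnitude terms."""
    window = torch.hann_window(win_length)
    # Magnitude spectrograms of the prediction and the target.
    P = torch.stft(pred, n_fft=fft_size, hop_length=hop_length,
                   win_length=win_length, window=window,
                   return_complex=True).abs()
    T = torch.stft(target, n_fft=fft_size, hop_length=hop_length,
                   win_length=win_length, window=window,
                   return_complex=True).abs()
    # Spectral convergence: relative Frobenius-norm error of the magnitudes.
    sc = torch.norm(T - P, p="fro") / torch.norm(T, p="fro")
    # Log-magnitude term: L1 distance between log spectrograms.
    eps = 1e-8
    log_mag = torch.nn.functional.l1_loss(torch.log(P + eps),
                                          torch.log(T + eps))
    return sc + log_mag
```

A multi-resolution version simply averages this quantity over several `(fft_size, hop_length, win_length)` settings.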

## Examples

Currently we include an example that uses a set of the loss functions to train a TCN for modeling an analog dynamic range compressor. 
For details, see [`examples/compressor`](examples/compressor). 
We provide pre-trained models, evaluation scripts to compute the metrics in the [paper](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf), as well as scripts to retrain the models. 

There are some more advanced things you can do with the `STFTLoss` class. 
For example, you can compute both linear- and log-scaled STFT errors as in [Engel et al., 2020](https://arxiv.org/abs/2001.04643).
In this case we do not include the spectral convergence term. 
```python
stft_loss = auraloss.freq.STFTLoss(
    w_log_mag=1.0, 
    w_lin_mag=1.0, 
    w_sc=0.0,
)
```

There is also a Mel-scaled STFT loss, which has some special requirements. 
This loss requires that you set the sample rate and specify the correct device. 
```python
sample_rate = 44100
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
```

You can also easily build a multi-resolution Mel-scaled STFT loss with 64 bins. 
Make sure you pass the device on which the tensors you are comparing will reside. 
```python
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    scale="mel", 
    n_bins=64,
    sample_rate=sample_rate,
    device="cuda"
)
```

If you are computing a loss on stereo audio, you may want to consider the sum and difference (mid/side) loss. 
Below is an example of using this loss function with perceptual weighting and mel scaling for further perceptual relevance. 

```python
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    perceptual_weighting=True,
    sample_rate=44100,
    scale="mel",
    n_bins=128,
)

loss = loss_fn(pred, target)
```
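For reference, the transform underlying this loss is just a mid/side decomposition of the stereo channels. A minimal sketch (auraloss's `SumAndDifference` may apply an additional scaling factor):

```python
import torch

def sum_and_difference(stereo):
    """Split a (batch, 2, samples) stereo tensor into mid and side signals."""
    left, right = stereo[:, 0, :], stereo[:, 1, :]
    mid = left + right    # sum signal: content shared by both channels
    side = left - right   # difference signal: stereo-only content
    return mid, side
```

The STFT loss is then applied to the mid and side signals of both the prediction and the target, which penalizes errors in the stereo image as well as in the overall content.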

# Development

Run the tests locally with pytest.

```
python -m pytest
```

            
