tyxe


| Field | Value |
|:---|:---|
| Name | tyxe |
| Version | 0.0.2 |
| home_page | https://github.com/emaballarin/TyXe |
| Summary | BNNs for PyTorch using Pyro |
| upload_time | 2023-01-30 23:04:42 |
| maintainer | |
| docs_url | None |
| author | Hippolyt Ritter, Theofanis Karaletsos |
| requires_python | >=3.8 |
| license | MIT |
| keywords | deep learning, machine learning, bayesian deep learning |
| VCS | |
| bugtrack_url | |
| requirements | No requirements were recorded. |
| Travis-CI | No Travis. |
| coveralls test coverage | No coveralls. |

# TyXe: Pyro-based BNNs for Pytorch users

TyXe aims to simplify the process of turning [PyTorch](https://www.pytorch.org) neural networks into Bayesian neural networks (BNNs) by
leveraging the model definition and inference capabilities of [Pyro](https://www.pyro.ai).
Our core design principle is to cleanly separate the construction of neural architecture, prior, inference distribution
and likelihood, enabling a flexible workflow where each component can be exchanged independently.
Defining a BNN in TyXe takes as little as 5 lines of code:
```
import torch.nn as nn
import pyro.distributions as dist
import tyxe

net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
likelihood = tyxe.likelihoods.HomoskedasticGaussian(scale=0.1)
inference = tyxe.guides.AutoNormal
bnn = tyxe.VariationalBNN(net, prior, likelihood, inference)
```

In the following, we assume that you (roughly) know what a BNN is mathematically.


## Motivating example
Standard neural networks give us a single function that fits the data, but many different ones are typically plausible.
With only a single fit, we don't know for what inputs the model is 'certain' (because there is training data nearby) and
where it is uncertain.

| ![ML](blob/regression_ml.png) | ![Samples](blob/regression_samples.png) |
|:---:|:---:|
| Maximum likelihood fit | Posterior samples |

The former takes only a few lines of PyTorch code, but training a BNN that gives a
distribution over different fits is typically more complicated, and that is specifically what we aim to simplify.

## Training

Constructing a BNN object is shown in the example above.
To fit the posterior approximation, we provide a high-level `.fit` method similar to those of libraries such as scikit-learn
or Keras:

```
import pyro

optim = pyro.optim.Adam({"lr": 1e-3})
bnn.fit(data_loader, optim, num_epochs)
```

## Prediction & evaluation

We also provide `.predict` and `.evaluate` methods, which make predictions based on multiple samples from the approximate posterior, average them according to the observation model, and return log likelihoods and an error measure:
```
predictions = bnn.predict(x_test, num_samples)
error, log_likelihood = bnn.evaluate(x_test, y_test, num_samples)
```
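
To make the averaging step more concrete, here is a minimal plain-Python sketch for a single test point, assuming a Gaussian observation model with a fixed noise scale. The function names `gaussian_log_prob` and `summarize_samples` are hypothetical, and this is not TyXe's internal implementation; it only illustrates one common way of turning several posterior-sample predictions into a mean prediction, a squared error, and a log predictive likelihood.

```python
import math

def gaussian_log_prob(y, mean, scale):
    """Log density of N(mean, scale^2) evaluated at y."""
    return -0.5 * math.log(2 * math.pi * scale ** 2) - 0.5 * ((y - mean) / scale) ** 2

def summarize_samples(sample_preds, y_true, scale):
    """Combine per-sample predictions for one test point (illustrative only).

    sample_preds holds one network output per posterior weight sample.
    Returns (mean prediction, squared error, log predictive likelihood).
    """
    n = len(sample_preds)
    mean_pred = sum(sample_preds) / n
    squared_error = (mean_pred - y_true) ** 2
    # log of the sample-averaged likelihood, computed stably via log-sum-exp
    log_probs = [gaussian_log_prob(y_true, p, scale) for p in sample_preds]
    max_lp = max(log_probs)
    log_lik = max_lp + math.log(sum(math.exp(lp - max_lp) for lp in log_probs)) - math.log(n)
    return mean_pred, squared_error, log_lik

mean_pred, sq_err, log_lik = summarize_samples([0.9, 1.1, 1.0], y_true=1.0, scale=0.1)
```

Averaging the likelihood inside the logarithm, rather than averaging the log likelihoods, is what makes this a predictive likelihood under the sampled posterior.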

## Local reparameterization

We implement [local reparameterization](https://arxiv.org/abs/1506.02557) for factorized Gaussians as a poutine, which reduces gradient noise during training.
Because it is a poutine, it can be enabled or disabled during both training and prediction with a context manager:
```
with tyxe.poutine.local_reparameterization():
    bnn.fit(data_loader, optim, num_epochs)
    bnn.predict(x_test, num_predictions)
```
At the moment, this poutine does not work with Pyro's `AutoNormal` and `AutoDiagonalNormal` guides, since those draw the weights from a Delta distribution, so you need to use `tyxe.guides.ParameterwiseDiagonalNormal` as your guide.
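
The idea behind local reparameterization can be sketched in plain Python for a single linear layer. This is only an illustration of the technique from the linked paper, not TyXe's poutine, and the names `preactivation_moments` and `sample_local_reparam` are hypothetical. With independent Gaussian weights, each pre-activation is itself Gaussian, so it can be sampled directly instead of sampling every weight.

```python
import math
import random

def preactivation_moments(x, w_mean, w_std):
    """Mean and variance of each output a_j = sum_i x[i] * W[i][j],
    where W[i][j] ~ N(w_mean[i][j], w_std[i][j]^2) independently."""
    n_in, n_out = len(x), len(w_mean[0])
    means = [sum(x[i] * w_mean[i][j] for i in range(n_in)) for j in range(n_out)]
    variances = [sum((x[i] * w_std[i][j]) ** 2 for i in range(n_in)) for j in range(n_out)]
    return means, variances

def sample_local_reparam(x, w_mean, w_std, rng):
    """Draw the pre-activations directly rather than drawing the weights."""
    means, variances = preactivation_moments(x, w_mean, w_std)
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(means, variances)]

x = [1.0, 2.0]
w_mean = [[0.5, -1.0], [0.25, 0.0]]
w_std = [[0.1, 0.2], [0.3, 0.4]]
means, variances = preactivation_moments(x, w_mean, w_std)
sample = sample_local_reparam(x, w_mean, w_std, random.Random(0))
```

In a minibatch, each input gets its own noise draw this way, which is the source of the reduced gradient noise.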

## MCMC

We provide a unified interface to Pyro's MCMC implementations. Simply use the `tyxe.MCMC_BNN` class and pass a kernel in place of the guide:
```
kernel = pyro.infer.mcmc.NUTS
bnn = tyxe.MCMC_BNN(net, prior, likelihood, kernel)
```
Any parameters that pyro's `MCMC` class accepts can be passed through the keyword arguments of the `.fit` method.

## Continual learning

Because our design cleanly separates the prior from the guide, architecture and likelihood, it is easy to update the prior in a continual setting.
For example, to implement [Variational Continual Learning](https://arxiv.org/abs/1710.10628) in a few lines of code, you can extract the distributions over all weights and biases from a `ParameterwiseDiagonalNormal` instance using the `get_detached_distributions` method, construct a `tyxe.priors.DictPrior` from them, and pass it to `bnn.update_prior`.
See `examples/vcl.py` for a basic example on split-MNIST and split-CIFAR.
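
The principle of reusing a posterior as the next prior can be shown with a toy model in plain Python. This sketch deliberately swaps the variational posterior for an exact conjugate Gaussian update over a single unknown mean, so it is not VCL itself and does not use the TyXe API; `update_gaussian` is a hypothetical helper. It shows that feeding the posterior from one task in as the prior for the next recovers the same result as seeing all the data at once.

```python
def update_gaussian(prior_mean, prior_var, observations, noise_var):
    """Exact posterior over an unknown mean, given a Gaussian prior
    and Gaussian observation noise with known variance."""
    precision = 1.0 / prior_var + len(observations) / noise_var
    mean = (prior_mean / prior_var + sum(observations) / noise_var) / precision
    return mean, 1.0 / precision

task_a = [0.8, 1.2, 1.0]
task_b = [1.4, 1.6]

# Continual setting: the posterior after task A becomes the prior for task B.
mean_a, var_a = update_gaussian(0.0, 1.0, task_a, noise_var=0.25)
mean_ab, var_ab = update_gaussian(mean_a, var_a, task_b, noise_var=0.25)

# Reference: one update on the data from both tasks together.
mean_joint, var_joint = update_gaussian(0.0, 1.0, task_a + task_b, noise_var=0.25)
```

In the exact case the two routes agree; with a variational approximation, as in VCL, the sequential result is only approximate.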

## Network architectures

We don't implement any layer classes.
You construct your network in PyTorch and then turn it into a BNN, which makes it easy to apply the same prior and inference strategies to different neural networks.
  
## Inference

For inference, we mainly provide an equivalent to pyro's `AutoDiagonalNormal` that is compatible with local reparameterization in `tyxe.guides`.
This module also contains a few helper functions for initialization of Gaussian mean parameters, e.g. to the values of a pre-trained network.
It should be possible to use any of pyro's autoguides for variational inference.
See `examples/resnet.py` for a few options as well as initializing to pre-trained weights.   

## Priors

The priors can be found in `tyxe.priors`.
We currently only support placing priors on the parameters.
Through the expose and hide arguments of the `__init__` method, you can specify the layers, types of layers and specific parameters over which you want to place a prior.
This helps, for example, in learning the parameters of BatchNorm layers deterministically.
 

## Likelihoods

`tyxe.likelihoods` contains classes that wrap the most common `torch.distributions` for specifying noise models of the data.


# Installation

We recommend installing TyXe using conda with the provided `environment.yml`, which also installs all the dependencies for the examples except for PyTorch3D, which needs to be added manually.
The environment assumes that you are using CUDA 11.0. If this is not the case, change the `cudatoolkit` and `dgl-cuda` versions before running:
```
conda env create -f environment.yml
conda activate tyxe
pip install -e .
```

## Citation
If you use TyXe, please consider citing:
```
@inproceedings{ritter2022tyxe,
  title={Ty{X}e: Pyro-based {B}ayesian neural nets for {P}ytorch},
  author={Ritter, Hippolyt and Karaletsos, Theofanis},
  booktitle={Proceedings of Machine Learning and Systems},
  year={2022}
}
```

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/emaballarin/TyXe",
    "name": "tyxe",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.8",
    "maintainer_email": "",
    "keywords": "Deep Learning,Machine Learning,Bayesian Deep Learning",
    "author": "['Hippolyt Ritter', 'Theofanis Karaletsos']",
    "author_email": "j.ritter@cs.ucl.ac.uk",
    "download_url": "https://files.pythonhosted.org/packages/a8/82/173ada67f6e25f3eef1713bb66a103216915bde4edefdd19098a7ffdb501/tyxe-0.0.2.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "BNNs for PyTorch using Pyro",
    "version": "0.0.2",
    "split_keywords": [
        "deep learning",
        "machine learning",
        "bayesian deep learning"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "a0e046170ff5fb44fc46535355ee485ec9387cd319da2d840cde0b3b9113b7bf",
                "md5": "2dd2a367a9d96d2ded5955454cd85fb0",
                "sha256": "9b589c8a971bbcebbc983e08b206c5ea0375a7b2f8cb97c84137b64a81cadaf0"
            },
            "downloads": -1,
            "filename": "tyxe-0.0.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "2dd2a367a9d96d2ded5955454cd85fb0",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.8",
            "size": 22974,
            "upload_time": "2023-01-30T23:04:40",
            "upload_time_iso_8601": "2023-01-30T23:04:40.721675Z",
            "url": "https://files.pythonhosted.org/packages/a0/e0/46170ff5fb44fc46535355ee485ec9387cd319da2d840cde0b3b9113b7bf/tyxe-0.0.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "a882173ada67f6e25f3eef1713bb66a103216915bde4edefdd19098a7ffdb501",
                "md5": "1cc44c1443bce7bc08ea2862d7fb1929",
                "sha256": "73680e40c17f2a7010907a46e9fef7e90b74b4d2e57e6b3856f06acf655d4070"
            },
            "downloads": -1,
            "filename": "tyxe-0.0.2.tar.gz",
            "has_sig": false,
            "md5_digest": "1cc44c1443bce7bc08ea2862d7fb1929",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.8",
            "size": 22519,
            "upload_time": "2023-01-30T23:04:42",
            "upload_time_iso_8601": "2023-01-30T23:04:42.094362Z",
            "url": "https://files.pythonhosted.org/packages/a8/82/173ada67f6e25f3eef1713bb66a103216915bde4edefdd19098a7ffdb501/tyxe-0.0.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-01-30 23:04:42",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "github_user": "emaballarin",
    "github_project": "TyXe",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [],
    "lcname": "tyxe"
}
        