# Poisson Identifiable VAE (pi-VAE)
This is a PyTorch implementation of [Poisson Identifiable VAE (pi-VAE)](https://arxiv.org/abs/2011.04798), used to construct latent variable models of neural activity while simultaneously modeling the relation between the latent and task variables (non-neural variables, e.g. sensory, motor, and other externally observable states).

The original implementation by [Dr. Ding Zhou](https://zhd96.github.io/) and [Dr. Xue-Xin Wei](https://sites.google.com/view/xxweineuraltheory/) in TensorFlow 1.13 is available [here](https://github.com/zhd96/pi-vae).

Another PyTorch implementation by [Dr. Lyndon Duong](http://lyndonduong.com/) is available [here](https://github.com/lyndond/lyndond.github.io/blob/0865902edb4648a8690ed8d449573d9236a72406/code/2021-11-25-pivae.ipynb).
## Install
```
pip install pi-vae-pytorch
```
## Usage
```
import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim=100,
    u_dim=3,
    z_dim=2,
    discrete_labels=False,
)

# Non-negative spike counts, as expected by the default Poisson observation model
x = torch.poisson(torch.rand(1, 100))  # Size([n_samples, x_dim])

u = torch.randn(1, 3)  # Size([n_samples, u_dim]), continuous labels

outputs = model(x, u)  # dict
```
### Parameters
- `x_dim`: int

  Dimension of observation `x`.
- `u_dim`: int

  Dimension of label `u`.
- `z_dim`: int

  Dimension of latent `z`.
- `discrete_labels`: bool
  - Default: `True`

  Flag denoting the type of label `u`: discrete (`True`) or continuous (`False`).
- `encoder_n_hidden_layers`: int
  - Default: `2`

  Number of hidden layers in the MLP of the model's encoder.
- `encoder_hidden_layer_dim`: int
  - Default: `120`

  Dimensionality of each hidden layer in the MLP of the model's encoder.
- `encoder_hidden_layer_activation`: nn.Module
  - Default: `nn.Tanh`

  Activation function applied to the outputs of each hidden layer in the MLP of the model's encoder.
- `decoder_n_gin_blocks`: int
  - Default: `2`

  Number of GIN blocks used within the model's decoder.
- `decoder_gin_block_depth`: int
  - Default: `2`

  Number of AffineCouplingLayers that make up each GIN block.
- `decoder_affine_input_layer_slice_dim`: int
  - Default: `None` (corresponds to `x_dim / 2`)

  Index at which each AffineCouplingLayer splits its n-dimensional input.
- `decoder_affine_n_hidden_layers`: int
  - Default: `2`

  Number of hidden layers in the MLP of each AffineCouplingLayer.
- `decoder_affine_hidden_layer_dim`: int
  - Default: `None` (corresponds to `x_dim / 4`)

  Dimensionality of each hidden layer in the MLP of each AffineCouplingLayer.
- `decoder_affine_hidden_layer_activation`: nn.Module
  - Default: `nn.ReLU`

  Activation function applied to the outputs of each hidden layer in the MLP of each AffineCouplingLayer.
- `decoder_nflow_n_hidden_layers`: int
  - Default: `2`

  Number of hidden layers in the MLP of the decoder's NFlowLayer.
- `decoder_nflow_hidden_layer_dim`: int
  - Default: `None` (corresponds to `x_dim / 4`)

  Dimensionality of each hidden layer in the MLP of the decoder's NFlowLayer.
- `decoder_nflow_hidden_layer_activation`: nn.Module
  - Default: `nn.ReLU`

  Activation function applied to the outputs of each hidden layer in the MLP of the decoder's NFlowLayer.
- `decoder_observation_model`: str
  - Default: `poisson`
  - One of `gaussian` or `poisson`

  Observation model used by the model's decoder.
- `decoder_fr_clamp_min`: float
  - Default: `1E-7`
  - Only applied when `decoder_observation_model="poisson"`

  Minimum threshold used when clamping decoded firing rates.
- `decoder_fr_clamp_max`: float
  - Default: `1E7`
  - Only applied when `decoder_observation_model="poisson"`

  Maximum threshold used when clamping decoded firing rates.
- `z_prior_n_hidden_layers`: int
  - Default: `2`
  - Only applied when `discrete_labels=False`

  Number of hidden layers in the MLP of the ZPriorContinuous module.
- `z_prior_hidden_layer_dim`: int
  - Default: `20`
  - Only applied when `discrete_labels=False`

  Dimensionality of each hidden layer in the MLP of the ZPriorContinuous module.
- `z_prior_hidden_layer_activation`: nn.Module
  - Default: `nn.Tanh`
  - Only applied when `discrete_labels=False`

  Activation function applied to the outputs of each hidden layer in the MLP of the ZPriorContinuous module.
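
To make the coupling-layer parameters above more concrete, here is a toy, pure-Python sketch of a single volume-preserving affine coupling step in the spirit of GIN (general incompressible-flow networks). It is not this package's `AffineCouplingLayer`: the split index `slice_dim` mirrors `decoder_affine_input_layer_slice_dim` (defaulting to half the input width), and `toy_conditioner` is a hypothetical fixed function standing in for the learned MLP. Here the log-scales are forced to sum to zero by subtracting their mean, which is one simple way to keep the Jacobian determinant at 1; other implementations may enforce the constraint differently (e.g., by setting the last log-scale to the negative sum of the others).

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
# A conditioner maps the untouched half x1 to (raw log-scales, shifts) for the other half.
Conditioner = Callable[[Vector, int], Tuple[Vector, Vector]]


def toy_conditioner(x1: Vector, out_dim: int) -> Tuple[Vector, Vector]:
    """Hypothetical stand-in for the learned MLP inside a coupling layer."""
    total = sum(x1)
    raw_log_scale = [math.tanh(total + i) for i in range(out_dim)]
    shift = [0.5 * total - i for i in range(out_dim)]
    return raw_log_scale, shift


def zero_sum(raw_log_scale: Vector) -> Vector:
    """Volume-preserving constraint: shift log-scales so they sum to zero (Jacobian det = 1)."""
    mean = sum(raw_log_scale) / len(raw_log_scale)
    return [s - mean for s in raw_log_scale]


def coupling_forward(x: Vector, slice_dim: int, conditioner: Conditioner = toy_conditioner) -> Vector:
    x1, x2 = x[:slice_dim], x[slice_dim:]
    raw_log_scale, shift = conditioner(x1, len(x2))
    log_scale = zero_sum(raw_log_scale)
    # x1 passes through unchanged; x2 is scaled and shifted conditioned on x1.
    y2 = [xi * math.exp(s) + t for xi, s, t in zip(x2, log_scale, shift)]
    return x1 + y2


def coupling_inverse(y: Vector, slice_dim: int, conditioner: Conditioner = toy_conditioner) -> Vector:
    y1, y2 = y[:slice_dim], y[slice_dim:]
    # y1 equals x1, so the same log-scales and shifts can be recomputed exactly.
    raw_log_scale, shift = conditioner(y1, len(y2))
    log_scale = zero_sum(raw_log_scale)
    x2 = [(yi - t) * math.exp(-s) for yi, s, t in zip(y2, log_scale, shift)]
    return y1 + x2


x = [0.3, -1.2, 0.8, 2.0, -0.5, 1.1]
slice_dim = len(x) // 2  # mirrors the default split at x_dim / 2
y = coupling_forward(x, slice_dim)
x_rec = coupling_inverse(y, slice_dim)
print(all(abs(a - b) < 1e-9 for a, b in zip(x, x_rec)))  # True
```

Because the first half passes through untouched, the inverse can recompute the same scales and shifts, so the layer is exactly invertible whatever the conditioner is. In a full flow, several such layers are stacked and the halves are permuted between them so that every dimension eventually gets transformed.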
### Returns
A dictionary with the following items.

- `firing_rate`: Tensor
  - Size([n_samples, x_dim])

  Firing rates predicted by the decoder from `z_sample`.
- `lambda_mean`: Tensor
  - Size([n_samples, z_dim])

  Mean for each sample under the label prior p(z \| u).
- `lambda_log_variance`: Tensor
  - Size([n_samples, z_dim])

  Log of the variance for each sample under the label prior p(z \| u).
- `posterior_mean`: Tensor
  - Size([n_samples, z_dim])

  Mean for each sample under the full posterior q(z \| x,u) ∝ q(z \| x) × p(z \| u).
- `posterior_log_variance`: Tensor
  - Size([n_samples, z_dim])

  Log of the variance for each sample under the full posterior q(z \| x,u) ∝ q(z \| x) × p(z \| u).
- `z_mean`: Tensor
  - Size([n_samples, z_dim])

  Mean for each sample under the encoder's approximation of q(z \| x).
- `z_log_variance`: Tensor
  - Size([n_samples, z_dim])

  Log of the variance for each sample under the encoder's approximation of q(z \| x).
- `z_sample`: Tensor
  - Size([n_samples, z_dim])

  Generated latents `z`.
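
The full posterior combines the encoder's q(z | x) with the label prior p(z | u). Assuming both are diagonal Gaussians, as their mean and log-variance outputs suggest, their renormalized product is again Gaussian: per dimension the precisions add, 1/σ² = 1/σ_z² + 1/σ_λ², and the mean is precision-weighted, μ = σ²(μ_z/σ_z² + μ_λ/σ_λ²). The pure-Python sketch below computes these from `z_*` and `lambda_*` values and then draws a reparameterized sample. The function names are hypothetical, and the package may compute these quantities differently (for example, on tensors with numerically stabler formulas).

```python
import math
import random
from typing import List, Tuple


def gaussian_product(
    z_mean: List[float],
    z_log_variance: List[float],
    lambda_mean: List[float],
    lambda_log_variance: List[float],
) -> Tuple[List[float], List[float]]:
    """Per-dimension product of N(z_mean, var_z) and N(lambda_mean, var_lambda), renormalized."""
    post_mean, post_log_var = [], []
    for m_z, lv_z, m_l, lv_l in zip(z_mean, z_log_variance, lambda_mean, lambda_log_variance):
        prec_z, prec_l = math.exp(-lv_z), math.exp(-lv_l)  # precisions = 1 / variance
        var = 1.0 / (prec_z + prec_l)  # precisions add
        post_mean.append(var * (prec_z * m_z + prec_l * m_l))  # precision-weighted mean
        post_log_var.append(math.log(var))
    return post_mean, post_log_var


def reparameterized_sample(mean: List[float], log_var: List[float], rng: random.Random) -> List[float]:
    """z = mean + std * eps with eps ~ N(0, 1), the usual VAE reparameterization trick."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, log_var)]


# Equal unit variances: the posterior mean is the average and the variance is halved.
post_mean, post_log_var = gaussian_product([1.0, -2.0], [0.0, 0.0], [3.0, 0.0], [0.0, 0.0])
print(post_mean)                                        # [2.0, -1.0]
print([round(math.exp(lv), 6) for lv in post_log_var])  # [0.5, 0.5]
z_sample = reparameterized_sample(post_mean, post_log_var, random.Random(0))
```

The posterior variance is always smaller than either input variance, so conditioning on both `x` and `u` tightens the latent estimate; when one source is much more certain, its mean dominates.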
## Loss Function
### Poisson observation model
```
from pi_vae_pytorch.utils import compute_loss

outputs = model(x, u)  # model initialized with decoder_observation_model="poisson"

loss = compute_loss(
    x=x,
    firing_rate=outputs["firing_rate"],
    lambda_mean=outputs["lambda_mean"],
    lambda_log_variance=outputs["lambda_log_variance"],
    posterior_mean=outputs["posterior_mean"],
    posterior_log_variance=outputs["posterior_log_variance"],
    observation_model=model.decoder_observation_model,
)

loss.backward()
```
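
For intuition about the Poisson case, here is a minimal pure-Python sketch of a per-sample Poisson negative log-likelihood with the log(x!) term dropped, since it does not depend on the firing rates. It assumes this is the form of observation term that `compute_loss` uses; the helper `poisson_nll` is hypothetical. Rates are clamped first, using the documented defaults of `decoder_fr_clamp_min` and `decoder_fr_clamp_max`, so that `log` never sees zero.

```python
import math
from typing import List


def poisson_nll(x: List[float], firing_rate: List[float],
                clamp_min: float = 1e-7, clamp_max: float = 1e7) -> float:
    """Sum over neurons of (rate - x * log(rate)); log(x!) is omitted as a constant."""
    total = 0.0
    for counts, rate in zip(x, firing_rate):
        rate = min(max(rate, clamp_min), clamp_max)  # keep log(rate) finite
        total += rate - counts * math.log(rate)
    return total


counts = [0.0, 2.0, 5.0]
print(poisson_nll(counts, counts))           # rates equal to the observed counts
print(poisson_nll(counts, [1.0, 1.0, 1.0]))  # 3.0, a worse constant-rate fit
```

For each neuron the term rate - x·log(rate) is minimized at rate = x, which is why the matched rates score lower than the constant ones. A batch loss would typically average this quantity over samples.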
### Gaussian observation model
```
from pi_vae_pytorch.utils import compute_loss

outputs = model(x, u)  # model initialized with decoder_observation_model="gaussian"

loss = compute_loss(
    x=x,
    firing_rate=outputs["firing_rate"],
    lambda_mean=outputs["lambda_mean"],
    lambda_log_variance=outputs["lambda_log_variance"],
    posterior_mean=outputs["posterior_mean"],
    posterior_log_variance=outputs["posterior_log_variance"],
    observation_model=model.decoder_observation_model,
    observation_noise_model=model.observation_noise_model,
)

loss.backward()
```
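
In the Gaussian case, a common form for the observation term is a squared error scaled by the noise variance plus a log-variance penalty. The sketch below assumes the noise is described by a per-dimension log-variance (a hypothetical `obs_log_var` argument standing in for whatever `observation_noise_model` provides) and drops the constant 0.5·log(2π); it is not necessarily what `compute_loss` computes.

```python
import math
from typing import List


def gaussian_nll(x: List[float], mean: List[float], obs_log_var: List[float]) -> float:
    """Sum over dimensions of (x - mean)^2 / (2 * var) + log(var) / 2, dropping 0.5 * log(2 * pi)."""
    total = 0.0
    for xi, mi, lv in zip(x, mean, obs_log_var):
        total += (xi - mi) ** 2 / (2.0 * math.exp(lv)) + 0.5 * lv
    return total


x = [1.0, -1.0]
print(gaussian_nll(x, [0.0, 0.0], [0.0, 0.0]))  # 1.0 with unit variance
```

The log-variance penalty keeps a learned noise model from inflating the variance without bound: for a fixed squared error e², this term is minimized when the variance equals e².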
### Parameters
- `x`: Tensor
  - Size([n_samples, x_dim])

  Observations `x`.
- `firing_rate`: Tensor
  - Size([n_samples, x_dim])

  Predicted firing rates of the generated latents `z`.
- `lambda_mean`: Tensor
  - Size([n_samples, z_dim])

  Means from the label prior p(z \| u).
- `lambda_log_variance`: Tensor
  - Size([n_samples, z_dim])

  Log of the variances from the label prior p(z \| u).
- `posterior_mean`: Tensor
  - Size([n_samples, z_dim])

  Means from the full posterior q(z \| x,u) ∝ q(z \| x) × p(z \| u).
- `posterior_log_variance`: Tensor
  - Size([n_samples, z_dim])

  Log of the variances from the full posterior q(z \| x,u) ∝ q(z \| x) × p(z \| u).
- `observation_model`: str
  - One of `poisson` or `gaussian`
  - Should match the value passed to `decoder_observation_model` when initializing `PiVAE`.

  The observation model used by pi-VAE's decoder.
- `observation_noise_model`: nn.Module
  - Default: `None`
  - Only applied when `observation_model="gaussian"`

  The noise model used when pi-VAE's decoder uses a Gaussian observation model. When `PiVAE` is initialized with `decoder_observation_model="gaussian"`, pass the model's `observation_noise_model` attribute.
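
The `lambda_*` and `posterior_*` parameters correspond to the KL term of a VAE-style objective, which in pi-VAE compares the full posterior with the label prior. Assuming both are diagonal Gaussians, KL(q(z | x,u) ‖ p(z | u)) has the closed form sketched below; `gaussian_kl` is a hypothetical helper, not the library's code. A loss of this kind is typically the observation negative log-likelihood plus this KL, averaged over samples.

```python
import math
from typing import List


def gaussian_kl(posterior_mean: List[float], posterior_log_variance: List[float],
                lambda_mean: List[float], lambda_log_variance: List[float]) -> float:
    """KL(N(mu_q, var_q) || N(mu_p, var_p)) summed over latent dimensions."""
    kl = 0.0
    for mq, lvq, mp, lvp in zip(posterior_mean, posterior_log_variance,
                                lambda_mean, lambda_log_variance):
        kl += 0.5 * (lvp - lvq + (math.exp(lvq) + (mq - mp) ** 2) / math.exp(lvp) - 1.0)
    return kl


print(gaussian_kl([0.5, -1.0], [0.2, 0.0], [0.5, -1.0], [0.2, 0.0]))  # 0.0, identical distributions
print(gaussian_kl([1.0], [0.0], [0.0], [0.0]))                          # 0.5, unit-variance mean shift of 1
```

The KL is zero only when the posterior matches the label prior in every dimension, so this term pulls the latents toward the structure imposed by `u`.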
## Citation
```
@misc{zhou2020learning,
  title={Learning identifiable and interpretable latent models of high-dimensional neural activity using pi-VAE},
  author={Ding Zhou and Xue-Xin Wei},
  year={2020},
  eprint={2011.04798},
  archivePrefix={arXiv},
  primaryClass={stat.ML}
}
```