# Poisson Identifiable VAE (pi-VAE) 2.0

![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mmcinnestaylor/pi-vae-pytorch/publish-to-pypi.yml?logo=github&logoColor=white&label=Publish%20to%20PyPI)
![Python Version](https://img.shields.io/badge/python-3.9+-blue.svg)
![License](https://img.shields.io/pypi/l/pi-vae-pytorch)  
![PyPI - Version](https://img.shields.io/pypi/v/pi-vae-pytorch?label=pypi%20package)
![PyPI - Downloads](https://img.shields.io/pypi/dm/pi-vae-pytorch?label=pypi%20downloads)

This is a PyTorch implementation of the [Poisson Identifiable Variational Autoencoder (pi-VAE)](https://arxiv.org/abs/2011.04798), used to construct latent variable models of neural activity while simultaneously modeling the relation between the latent and task variables (non-neural variables, e.g. sensory, motor, and other externally observable states).  

A special thank you to [Zhongxuan Wu](https://github.com/ZhongxuanWu) who helped in the design and testing of this implementation.  

### Model Versions

pi-VAE 1.0 and 2.0 differ solely in their loss function, specifically how the [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) component of the loss is computed. Additional information is available in the *Loss Function - ELBOLoss* section of this documentation.

#### Version 2.0

- This codebase is the only known publicly available implementation of the Version 2.0 loss function, and it also includes an implementation of the Version 1.0 loss function.  

#### Version 1.0

- The original implementation by [Ding Zhou](https://zhd96.github.io/) and [Xue-Xin Wei](https://sites.google.com/view/xxweineuraltheory/) in TensorFlow 1.13 is available [here](https://github.com/zhd96/pi-vae).

- Another PyTorch implementation by [Lyndon Duong](http://lyndonduong.com/) is available [here](https://github.com/lyndond/lyndond.github.io/blob/0865902edb4648a8690ed8d449573d9236a72406/code/2021-11-25-pivae.ipynb).  



## Installation

Install the latest release from PyPI using `pip`:
```
pip install pi-vae-pytorch
```

Alternatively, clone this repo and install it in editable mode:
```
git clone https://github.com/mmcinnestaylor/pi-vae-pytorch.git
cd pi-vae-pytorch
pip install -e .
```

## Model Architecture

pi-VAE consists of three main components: the encoder, the label prior estimator, and the decoder. 

### MLP Structure

The multilayer perceptron (MLP) is the primary building block of the aforementioned components. Each MLP used in this implementation is configurable by specifying the appropriate parameters when `PiVAE` is initialized:  
- number of hidden layers  
- hidden layer dimension  
    - applied to all hidden layers within a given MLP
- hidden layer activation function  
    - applied to all non-output layer activations within a given MLP

### Encoder

The model's encoder consists of a single MLP, which learns to approximate the distribution q(z \| x). 

### Label Prior Estimator

The model's label prior estimator learns to approximate the distribution p(z \| u). In the discrete label regime this module consists of two `nn.Embedding` submodules, while in the continuous label regime it consists of a single MLP.

### Decoder

The model's decoder learns to map a latent sample `z` to its predicted firing rate in the model's observation space. Inputs to the decoder are passed through the following submodules. 

#### NFlowLayer

This module consists of an MLP which maps `z` to the concatenation of `z` and `t(z)`.

#### GINBlock

Outputs from the `NFlowLayer` are passed to a series of `GINBlock` modules. Each `GINBlock` consists of a specified number of `AffineCouplingLayer` modules, each of which contains an MLP and performs an affine coupling transformation.
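As a rough illustration of the affine coupling idea, the sketch below conditions on one slice of the input and applies a scale-and-shift to the other, with the log-scales centered so the transform is volume preserving (GIN-style). The class and parameter names here are hypothetical; the package's actual `AffineCouplingLayer` internals may differ.

```python
import torch
from torch import nn

class AffineCouplingSketch(nn.Module):
    """Illustrative volume-preserving affine coupling layer (not the package's class)."""

    def __init__(self, dim: int, slice_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.slice_dim = slice_dim
        # MLP conditioned on the first slice; outputs scale and shift for the second.
        self.net = nn.Sequential(
            nn.Linear(slice_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * (dim - slice_dim)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_a, x_b = x[:, :self.slice_dim], x[:, self.slice_dim:]
        log_scale, shift = self.net(x_a).chunk(2, dim=-1)
        # Center the log-scales so they sum to zero, keeping the transform
        # volume preserving (zero log-determinant).
        log_scale = log_scale - log_scale.mean(dim=-1, keepdim=True)
        y_b = x_b * torch.exp(log_scale) + shift
        return torch.cat([x_a, y_b], dim=-1)

x = torch.randn(4, 10)
layer = AffineCouplingSketch(dim=10, slice_dim=5)
y = layer(x)  # the first slice_dim dimensions pass through unchanged
```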

## Initialization

```
pi_vae_pytorch.PiVAE(
    x_dim,
    u_dim,
    z_dim,
    discrete_labels=True,
    encoder_n_hidden_layers=2,
    encoder_hidden_layer_dim=128,
    encoder_hidden_layer_activation=nn.Tanh,
    decoder_n_gin_blocks=2,
    decoder_gin_block_depth=2,
    decoder_affine_input_layer_slice_dim=None,
    decoder_affine_n_hidden_layers=2,
    decoder_affine_hidden_layer_dim=None,
    decoder_affine_hidden_layer_activation=nn.ReLU,
    decoder_nflow_n_hidden_layers=2,
    decoder_nflow_hidden_layer_dim=None,
    decoder_nflow_hidden_layer_activation=nn.ReLU,
    decoder_observation_model='poisson',
    decoder_fr_clamp_min=1E-7,
    decoder_fr_clamp_max=1E7,
    label_prior_n_hidden_layers=2,
    label_prior_hidden_layer_dim=32,
    label_prior_hidden_layer_activation=nn.Tanh)
```

- **x_dim:** *int*  
    Dimension of observation `x`  

- **u_dim:** *int*  
    Dimension of observation labels `u`. In the discrete regime, this corresponds to the number of unique classes/labels. In the continuous regime, this corresponds to the dimension of each label.  

- **z_dim:** *int*  
    Dimension of latent `z`  

- **discrete_labels:** *bool, default=*`True`  
    - `True`: discrete or `False`: continuous  

    Flag denoting the observation's label regime. 
- **encoder_n_hidden_layers:** *int, default=*`2`  
    Number of hidden layers in the MLP of the model's encoder.  

- **encoder_hidden_layer_dim:** *int, default=*`128`  
    Dimensionality of each hidden layer in the MLP of the model's encoder.  

- **encoder_hidden_layer_activation:** *nn.Module, default=*`nn.Tanh`  
    Activation function applied to the outputs of each hidden layer in the MLP of the model's encoder.  

- **decoder_n_gin_blocks:** *int, default=*`2`  
    Number of GIN blocks used within the model's decoder.  

- **decoder_gin_block_depth:** *int, default=*`2`  
    Number of AffineCouplingLayers which comprise each GIN block.  

- **decoder_affine_input_layer_slice_dim:** *int, default=*`None` *(equivalent to* `x_dim // 2`*)*  
    Index at which to split an n-dimensional input x.  

- **decoder_affine_n_hidden_layers:** *int, default=*`2`  
    Number of hidden layers in the MLP of each AffineCouplingLayer.  

- **decoder_affine_hidden_layer_dim:** *int, default=*`None` *(equivalent to* `x_dim // 4`*)*  
    Dimensionality of each hidden layer in the MLP of each AffineCouplingLayer.  

- **decoder_affine_hidden_layer_activation:** *nn.Module, default=*`nn.ReLU`  
    Activation function applied to the outputs of each hidden layer in the MLP of each AffineCouplingLayer.  

- **decoder_nflow_n_hidden_layers:** *int, default=*`2`  
    Number of hidden layers in the MLP of the decoder's NFlowLayer.  

- **decoder_nflow_hidden_layer_dim:** *int, default=*`None` *(equivalent to* `x_dim // 4`*)*  
    Dimensionality of each hidden layer in the MLP of the decoder's NFlowLayer.  

- **decoder_nflow_hidden_layer_activation:** *nn.Module, default=*`nn.ReLU`  
    Activation function applied to the outputs of each hidden layer in the MLP of the decoder's NFlowLayer.  

- **decoder_observation_model:** *str, default=*`'poisson'`  
    - One of `gaussian` or `poisson`

    Observation model used by the model's decoder. 
- **decoder_fr_clamp_min:** *float, default=*`1E-7`  
    - Only applied when `decoder_observation_model='poisson'`

    Minimum threshold used when clamping decoded firing rates.
- **decoder_fr_clamp_max:** *float, default=*`1E7`  
    - Only applied when `decoder_observation_model='poisson'`

    Maximum threshold used when clamping decoded firing rates.
- **label_prior_n_hidden_layers:** *int, default=*`2`  
    - Only applied when `discrete_labels=False`  

    Number of hidden layers in the MLP of the label prior estimator module. 
- **label_prior_hidden_layer_dim:** *int, default=*`32`  
    - Only applied when `discrete_labels=False`

    Dimensionality of each hidden layer in the MLP of the label prior estimator module. 
- **label_prior_hidden_layer_activation:** *nn.Module, default=*`nn.Tanh`  
    - Only applied when `discrete_labels=False`

    Activation function applied to the outputs of each hidden layer in the MLP of the label prior estimator module. 

## Attributes

- **decoder:** *nn.Module*  
    The model's decoder module which projects a latent space sample into the model's observation space.  

- **decoder_observation_model:** *str*  
    The distribution of the observation space samples. One of `poisson` or `gaussian`.

- **decoder_fr_clamp_min:** *float*  
    Minimum threshold used when clamping decoded firing rates.  

- **decoder_fr_clamp_max:** *float*  
    Maximum threshold used when clamping decoded firing rates.  

- **encoder:** *nn.Module*  
    The model's encoder module which approximates q(z \| x).  

- **inference:** *bool, default=*`False`  
    Flag denoting whether the model is in inference mode. When `True` the model is in inference mode.  

- **observation_noise_model:** *nn.Module*  
    - Only used when `decoder_observation_model='gaussian'`  

    The noise model used when computing the pi-VAE's loss.  

- **label_prior:** *nn.Module*  
    The model's label prior module which approximates p(z \| u).  

## Basic operation

For every observation space sample `x` and associated label `u` provided to pi-VAE's `forward` method, the encoder and label statistics (mean & log of variance) are obtained from the encoder and label prior modules. These values are used to obtain the same statistics from the posterior q(z \| x,u). 

The [reparameterization trick](https://en.wikipedia.org/wiki/Reparameterization_trick) is performed with the resulting mean & log of variance to obtain the sample's representation in the model's latent space. This latent representation is then passed to the model's decoder module, which generates the predicted firing rate in the model's observation space. 
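The reparameterization step can be sketched in a few lines of PyTorch (`reparameterize` is an illustrative helper, not part of the package's public API):

```python
import torch

def reparameterize(mean: torch.Tensor, log_variance: torch.Tensor) -> torch.Tensor:
    """Sample z = mean + sigma * eps with eps ~ N(0, I), keeping gradients w.r.t. the statistics."""
    std = torch.exp(0.5 * log_variance)
    return mean + std * torch.randn_like(std)

mean = torch.zeros(8, 2)
log_var = torch.zeros(8, 2)  # log(1) = 0, i.e. unit variance
z = reparameterize(mean, log_var)  # Size([8, 2])
```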

### Inputs

- **x:** *Tensor of shape(n_samples, x_dim)*  
    Samples in the model's observation space.  
- **u:** *Tensor, default=*`None`  
    - *shape(n_samples)* when using discrete labels
    - *shape(n_samples, u_dim)* when using continuous labels  
    
    Label corresponding to each sample. This parameter is not used when the model is in inference mode.  

### Outputs

A `dict` with the following items: 

- **encoder_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
    Predicted firing rate of `encoder_z_sample`.  

- **encoder_z_sample:** *Tensor of shape(n_samples, z_dim)*  
    Latent space representation of each input sample computed from the encoder module's approximation of q(z \| x).  

- **encoder_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean of each input sample using the encoder module's approximation of q(z \| x).  

- **encoder_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance of each input sample using the encoder module's approximation of q(z \| x).  

- **label_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean of each input sample using the label prior module's approximation of p(z \| u).  

- **label_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance of each input sample using the label prior module's approximation of p(z \| u).  

- **posterior_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
    Predicted firing rate of `posterior_z_sample`.  

- **posterior_z_sample:** *Tensor of shape(n_samples, z_dim)*      
    Latent space representation of each input sample computed from the approximation of posterior q(z \| x,u) ~ q(z \| x) × p(z \| u).  

- **posterior_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean of each input sample using the approximation of posterior q(z \| x,u) ~ q(z \| x) × p(z \| u).  

- **posterior_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance of each input sample using the approximation of posterior q(z \| x,u) ~ q(z \| x) × p(z \| u).  

#### Inference Mode

A `dict` with the following items:  

- **encoder_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
- **encoder_z_sample:** *Tensor of shape(n_samples, z_dim)*  
- **encoder_mean:** *Tensor of shape(n_samples, z_dim)*  
- **encoder_log_variance**: *Tensor of shape(n_samples, z_dim)*  

### Examples

#### Continuous Labels

```
import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim = 100,
    u_dim = 3,
    z_dim = 2,
    discrete_labels=False
)

x = torch.randn(1, 100) # Size([n_samples, x_dim])

u = torch.randn(1, 3) # Size([n_samples, u_dim])

outputs = model(x, u) # dict
```

#### Discrete Labels

```
import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim = 100,
    u_dim = 3,
    z_dim = 2,
    discrete_labels=True
)

x = torch.randn(1, 100) # Size([n_samples, x_dim])

u = torch.randint(3, (1,)) # Size([n_samples]); labels drawn from [0, u_dim)

outputs = model(x, u) # dict
```

## Class Methods

- **decode(*x*)**  
    Projects samples in the model's latent space (`z_dim`) into the model's observation space (`x_dim`) by passing them through the model's decoder module.  

    > **Parameters:**  

    - **x**: *Tensor of shape(n_samples, z_dim)*  
        Samples to be projected into the model's observation space.  
    
    > **Returns:**  

    - **decoded**: *Tensor of shape(n_samples, x_dim)*  
        Samples projected into the model's observation space.  

- **encode(*x, return_stats=False*)**  
    Projects samples in the model's observation space (`x_dim`) into the model's latent space (`z_dim`) by passing them through the model's encoder module.  
    
    > **Parameters:**  

    - **x**: *Tensor of shape(n_samples, x_dim)*  
        Samples to be projected into the model's latent space.  

    - **return_stats**: *bool, default=False*  
        If `True`, the mean and log of the variance associated with the encoded sample are returned; otherwise only the encoded sample is returned.  
    
    > **Returns:**  

    - **encoded**: *Tensor of shape(n_samples, z_dim)*  
        Samples projected into the model's latent space.  
        
    - **encoded_mean**: *Tensor of shape(n_samples, z_dim), optional*  
        Mean associated with a projected sample.  

    - **encoded_log_variance**: *Tensor of shape(n_samples, z_dim), optional*  
        Log of the variances associated with a projected sample.  
        
- **get_label_statistics(*u, device=None*)**  
    Returns the mean and log of the variance associated with a label `u` using the label prior estimator of p(z \| u).  

    > **Parameters:**  

    - **u**: *int, float, list, tuple, or Tensor of shape(1, u_dim)*  
        Label whose statistics will be returned. An integer is expected in the discrete label regime, while a float, list, tuple or PyTorch Tensor is expected in the continuous label regime.  
    - **device**: *torch.device, default=*`None` *(uses the CPU)*  
        A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  
    
    > **Returns:**  

    - **label_mean**: *Tensor of shape(1, z_dim)*  
        Mean of label `u`.  

    - **label_log_variance**: *Tensor of shape(1, z_dim)*  
        Log of the variance of label `u`.  
    
- **sample(*u, n_samples=1, device=None*)**  
    Generates random samples in the model's observation dimension (`x_dim`). Samples are initially drawn from a Gaussian distribution in the model's latent dimension (`z_dim`) corresponding to specified label `u`. Samples are subsequently lifted to the model's observation dimension (`x_dim`) by passing them through the model's decoder.  

    > **Parameters:**  

    - **u**: *int, float, list, tuple, or Tensor of shape(1, u_dim)*  
        Label of the samples to generate. An integer is expected in the discrete label regime, while a float, list, tuple or PyTorch Tensor is expected in the continuous label regime.  

    - **n_samples**: *int, default=*`1`  
        Number of samples to generate.  

    - **device**: *torch.device, default=*`None` *(uses the CPU)*  
        A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  
    
    > **Returns:**  

    - **samples**: *Tensor of shape(n_samples, x_dim)*  
        Randomly generated sample(s).  
    
- **sample_z(*u, n_samples=1, device=None*)**  
    Generates random samples in the model's latent dimension (`z_dim`). Samples are drawn from a Gaussian distribution corresponding to specified label `u`.  

    > **Parameters:**  

    - **u**: *int, float, list, tuple, or Tensor of shape(1, u_dim)*  
        Label of the samples to generate. An integer is expected in the discrete label regime, while a float, list, tuple or PyTorch Tensor is expected in the continuous label regime.  

    - **n_samples**: *int, default=*`1`  
        Number of samples to generate.  

    - **device**: *torch.device, default=*`None` *(uses the CPU)*  
        A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  
    
    > **Returns:**  

    - **samples**: *Tensor of shape(n_samples, z_dim)*  
        Randomly generated sample(s).  

- **set_inference_mode(*state*)**  
    Toggles the model's inference state flag. When `True`, the model's `forward` method does not utilize the `u` parameter; when `False`, the `u` parameter is utilized. Useful for working with unlabeled data. *NOTE: Inference mode must be disabled during model training.*  

    > **Parameters:**  

    - **state**: *bool*  
        The desired inference state.  
    
    > **Returns:**  

    - None  

## Static Methods

- **compute_posterior(*mean_0, log_variance_0, mean_1, log_variance_1*)**  
    Computes the posterior of two distributions as a product of Gaussians.  

    > **Parameters:**  

    - **mean_0:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_0:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  

    - **mean_1:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_1:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  
    
    > **Returns:**  

    - **posterior_mean:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of the posterior distribution.  

    - **posterior_log_variance:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of the posterior distribution.  
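For two diagonal Gaussians this product has a closed form: precisions add, and the means combine precision-weighted. A standalone sketch consistent with `compute_posterior`'s signature (`product_of_gaussians` is an illustrative helper, not necessarily the package's exact implementation):

```python
import torch

def product_of_gaussians(mean_0, log_variance_0, mean_1, log_variance_1):
    """Mean and log-variance of the (unnormalized) product of two diagonal Gaussians."""
    variance_0 = torch.exp(log_variance_0)
    variance_1 = torch.exp(log_variance_1)
    # Precisions add: 1/v = 1/v0 + 1/v1, so v = v0*v1 / (v0 + v1).
    posterior_variance = (variance_0 * variance_1) / (variance_0 + variance_1)
    # Means combine precision-weighted.
    posterior_mean = (mean_0 * variance_1 + mean_1 * variance_0) / (variance_0 + variance_1)
    return posterior_mean, torch.log(posterior_variance)

posterior_mean, posterior_log_variance = product_of_gaussians(
    torch.zeros(4, 2), torch.zeros(4, 2), # N(0, 1)
    torch.ones(4, 2), torch.zeros(4, 2),  # N(1, 1)
)
# Equal variances: the posterior mean is the midpoint and the variance halves.
```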

## Loss Function - ELBOLoss

pi-VAE learns the deep generative model and the approximate posterior q(z \| x, u) of the true posterior p(z \| x, u) by maximizing the evidence lower bound (ELBO) of p(x \| u). This loss function is implemented in the included `ELBOLoss` class.  

### Initialization

```
pi_vae_pytorch.ELBOLoss(
    version=2,
    alpha=0.5,
    observation_model='poisson',
    device=None)
```  

- **version:** *int, default=*`2`  
    - Either `1` or `2`  

    The version of the loss function.  
    - **Version 1:** Computes the KL divergence between the posterior and the label prior.  
    - **Version 2:** Computes the KL divergence between the posterior and the label prior as well as between the encoder and label prior. These two values are then weighted by the `alpha` parameter.  

- **alpha:** *float, default=*`0.5`  
    - Only applied when `version=2`  
    - Must reside within [0, 1]  

    Weights the contribution of the encoder KL loss and posterior KL loss to the total KL loss. 

    ```
    kl_loss = (alpha * encoder_kl_loss) + ((1 - alpha) * posterior_kl_loss)
    ```   

- **observation_model:** *str, default=*`'poisson'`  
    - Either `poisson` or `gaussian`  
    - Should use the same value passed to `decoder_observation_model` when initializing pi-VAE.  

    The observation model used by pi-VAE's decoder.  

- **device:** *torch.device, default=*`None` *(uses the CPU)*  
    - Only applied when `observation_model='gaussian'`  

    A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  

### Inputs

- **x:** *Tensor of shape(n_samples, x_dim)*  
    Sample(s) in the model's observation space.  

- **posterior_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
    Predicted firing rate of latent(s) generated from posterior q(z \| x,u).  

- **posterior_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean from posterior q(z \| x,u) ~ q(z \| x) × p(z \| u).  

- **posterior_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance from posterior q(z \| x,u) ~ q(z \| x) × p(z \| u).  

- **label_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean from the label prior estimator which approximates p(z \| u).  

- **label_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance from the label prior estimator which approximates p(z \| u).  

- **encoder_mean:** *Tensor of shape(n_samples, z_dim), default=*`None`  
    - Only used when `version=2`  
    
    Mean from the encoder which approximates p(z \| x).  

- **encoder_log_variance:** *Tensor of shape(n_samples, z_dim), default=*`None`  
    - Only used when `version=2`  

    Log of variance from the encoder which approximates p(z \| x).  

- **observation_noise_model:** *nn.Module, default=*`None`  
    - Only used when `observation_model='gaussian'`  
    
    The noise model used when pi-VAE's decoder utilizes a Gaussian observation model. When pi-VAE is initialized with `decoder_observation_model='gaussian'`, the model's `observation_noise_model` attribute should be used.

### Outputs

- **loss:** *Tensor of shape(1)*  
    The total loss of the samples.

### Static Methods

- **compute_kl_loss(*mean_0, log_variance_0, mean_1, log_variance_1*)**  
    Computes the [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between two distributions.  

    > **Parameters:**  

    - **mean_0:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_0:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  

    - **mean_1:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_1:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  
    
    > **Returns:**  

    - **kl_loss**: *Tensor of shape(1)*  
        The Kullback-Leibler divergence loss.  
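For diagonal Gaussians the divergence has a well-known closed form. A standalone sketch (`gaussian_kl` is an illustrative helper; its reduction over dimensions and batch may differ from `compute_kl_loss`):

```python
import torch

def gaussian_kl(mean_0, log_variance_0, mean_1, log_variance_1):
    """KL(N(mean_0, var_0) || N(mean_1, var_1)) for diagonal Gaussians."""
    kl = 0.5 * (
        log_variance_1 - log_variance_0
        + (torch.exp(log_variance_0) + (mean_0 - mean_1) ** 2) / torch.exp(log_variance_1)
        - 1.0
    )
    # Sum over latent dimensions, average over the batch.
    return kl.sum(dim=-1).mean()

m = torch.zeros(4, 2)
kl = gaussian_kl(m, m, m, m) # identical distributions give zero divergence
```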

### Examples

#### Poisson observation model

```
from pi_vae_pytorch import ELBOLoss

loss_fn = ELBOLoss()

outputs = model(x, u) # Initialized with decoder_observation_model='poisson'

loss = loss_fn(
    x=x,
    posterior_firing_rate=outputs['posterior_firing_rate'],
    posterior_mean=outputs['posterior_mean'],
    posterior_log_variance=outputs['posterior_log_variance'],
    label_mean=outputs['label_mean'],
    label_log_variance=outputs['label_log_variance'],
    encoder_mean=outputs['encoder_mean'],
    encoder_log_variance=outputs['encoder_log_variance']
)

loss.backward()
```

#### Gaussian observation model

```
import torch
from pi_vae_pytorch import ELBOLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device) # Initialized with decoder_observation_model='gaussian'

loss_fn = ELBOLoss(observation_model='gaussian', device=device)

outputs = model(x, u) 

loss = loss_fn(
    x=x,
    posterior_firing_rate=outputs['posterior_firing_rate'],
    posterior_mean=outputs['posterior_mean'],
    posterior_log_variance=outputs['posterior_log_variance'],
    label_mean=outputs['label_mean'],
    label_log_variance=outputs['label_log_variance'],
    encoder_mean=outputs['encoder_mean'],
    encoder_log_variance=outputs['encoder_log_variance'],
    observation_noise_model=model.observation_noise_model
)

loss.backward()
```

## Citation

```
@misc{zhou2020learning,
    title={Learning identifiable and interpretable latent models of high-dimensional neural activity using pi-VAE}, 
    author={Ding Zhou and Xue-Xin Wei},
    year={2020},
    eprint={2011.04798},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}
```

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "pi-vae-pytorch",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": "Marlan McInnes-Taylor <mmcinnestaylor@gmail.com>",
    "keywords": "vae, variational autoencoder, pi-vae, poisson identifiable vae, poisson identifiable variational autoencoder, identifiable vae, identifiable variational autoencoder",
    "author": null,
    "author_email": "Marlan McInnes-Taylor <mmcinnestaylor@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/50/bd/d903a056c8369e736ea76de4373eb8ef6b72bb584a2be6a6e26d770934b2/pi_vae_pytorch-1.0.0b6.tar.gz",
    "platform": null,
    "description": "# Poisson Identifiable VAE (pi-VAE) 2.0\n\n![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mmcinnestaylor/pi-vae-pytorch/publish-to-pypi.yml?logo=github&logoColor=white&label=Publish%20to%20PyPI)\n![Python Version](https://img.shields.io/badge/python-3.9+-blue.svg)\n![License](https://img.shields.io/pypi/l/pi-vae-pytorch)  \n![PyPI - Version](https://img.shields.io/pypi/v/pi-vae-pytorch?label=pypi%20package)\n![PyPI - Downloads](https://img.shields.io/pypi/dm/pi-vae-pytorch?label=pypi%20downloads)\n\nThis is a Pytorch implementation of [Poisson Identifiable Variational Autoencoder (pi-VAE)](https://arxiv.org/abs/2011.04798), used to construct latent variable models of neural activity while simultaneously modeling the relation between the latent and task variables (non-neural variables, e.g. sensory, motor, and other externally observable states).  \n\nA special thank you to [Zhongxuan Wu](https://github.com/ZhongxuanWu) who helped in the design and testing of this implementation.  \n\n### Model Versions\n\npi-VAE 1.0 and 2.0 differ solely in their loss function, specifically how the [Kullback\u2013Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) component of the loss is computed. Additional information is available in the *Loss Function - ELBOLoss* section of this documentation.\n\n#### Version 2.0\n\n- This codebase is the only known publically available implementation, and also includes an implementation of the Version 1.0 loss function.  
\n\n#### Version 1.0\n\n- The original implementation by [Ding Zhou](https://zhd96.github.io/) and [Xue-Xin Wei](https://sites.google.com/view/xxweineuraltheory/) in Tensorflow 1.13 is available [here](https://github.com/zhd96/pi-vae).\n\n- Another Pytorch implementation by [Lyndon Duong](http://lyndonduong.com/) is available [here](https://github.com/lyndond/lyndond.github.io/blob/0865902edb4648a8690ed8d449573d9236a72406/code/2021-11-25-pivae.ipynb).  \n\n\n\n## Installation\n\nIt is possible to install this project using `pip`:\n```\npip install pi-vae-pytorch\n```\n\nIt is also possible to clone this repo and install it using `pip`: \n```\ngit clone https://github.com/mmcinnestaylor/pi-vae-pytorch.git\ncd pip-vae-pytorch\npip install -e .\n```\n\n## Model Architecture\n\npi-VAE is comprised of three main components: the encoder, the label prior estimator, and the decoder. \n\n### MLP Structure\n\nThe Multi Layer Perceptron (MLP) is the primary building block of the aforementioned components. Each MLP used in this implementation is configurable by specifying the appropriate parameters when `PiVAE` is initialized:  \n- number of hidden layers  \n- hidden layer dimension  \n    - applied to all hidden layers within a given MLP\n- hidden layer activation function  \n    - applied to all non-output layer activations within a given MLP\n\n### Encoder\n\nThe model's encoder is comprised of a single MLP, which learns to approximate the distribution q(z \\| x). \n\n### Label Prior Estimator\n\nThe model's label prior estimator learns to approximate the distribution p(z \\| u). In the discrete label regime this module is comprised of two `nn.Embedding` submodules, while in the continuous label regime the module is comprised of a single MLP.\n\n### Decoder\n\nThe model's decoder learns to map a latent sample `z` to its predicted firing rate in the model's observation space. Inputs to the decoder are passed through the following submodules. 
\n\n#### NFlowLayer\n\nThis module is comprised of a MLP which maps `z` to the concatenation of `z` and `t(z)`.\n\n#### GINBlock\n\nOutputs from the `NFlowLayer` are passed to a series of `GINBlock` modules. Each `GINBlock` is comprised of a specified number of `AffineCouplingLayer` modules. Each `AffineCouplingLayer` is comprised of a MLP and performs an affine coupling transformation.\n\n## Initialization\n\n```\npi_vae_pytorch.PiVAE(\n    x_dim,\n    u_dim,\n    z_dim,\n    discrete_labels=True,\n    encoder_n_hidden_layers=2,\n    encoder_hidden_layer_dim=128,\n    encoder_hidden_layer_activation=nn.Tanh,\n    decoder_n_gin_blocks=2,\n    decoder_gin_block_depth=2,\n    decoder_affine_input_layer_slice_dim=None,\n    decoder_affine_n_hidden_layers=2,\n    decoder_affine_hidden_layer_dim=None,\n    decoder_affine_hidden_layer_activation=nn.ReLU,\n    decoder_nflow_n_hidden_layers=2,\n    decoder_nflow_hidden_layer_dim=None,\n    decoder_nflow_hidden_layer_activation=nn.ReLU,\n    decoder_observation_model='poisson',\n    decoder_fr_clamp_min=1E-7,\n    decoder_fr_clamp_max=1E7,\n    label_prior_n_hidden_layers=2,\n    label_prior_hidden_layer_dim=32,\n    label_prior_hidden_layer_activation=nn.Tanh)\n```\n\n- **x_dim:** *int*  \n    Dimension of observation `x`  \n\n- **u_dim:** *int*  \n    Dimension of observation labels `u`. In the discrete regime, this corresponds to the number of unique classes/labels. In the continuous regime, this corresponds to the dimension of each label.  \n\n- **z_dim:** *int*  \n    Dimension of latent `z`  \n\n- **discrete_labels:** *bool, default=*`True`  \n    - `True`: discrete or `False`: continuous  \n\n    Flag denoting the observation's label regime. \n- **encoder_n_hidden_layers:** *int, default=*`2`  \n    Number of hidden layers in the MLP of the model's encoder.  \n\n- **encoder_hidden_layer_dim:** *int, default=*`128`  \n    Dimensionality of each hidden layer in the MLP of the model's encoder.  
\n\n- **encoder_hidden_layer_activation:** *nn.Module, default=*`nn.Tanh`  \n    Activation function applied to the outputs of each hidden layer in the MLP of the model's encoder.  \n\n- **decoder_n_gin_blocks:** *int, default=*`2`  \n    Number of GIN blocks used within the model's decoder.  \n\n- **decoder_gin_block_depth:** *int, default=*`2`  \n    Number of AffineCouplingLayers which comprise each GIN block.  \n\n- **decoder_affine_input_layer_slice_dim:** *int, default=*`None` *(equivalent to* `x_dim // 2`*)*  \n    Index at which to split an n-dimensional input x.  \n\n- **decoder_affine_n_hidden_layers:** *int, default=*`2`  \n    Number of hidden layers in the MLP of the model's encoder.  \n\n- **decoder_affine_hidden_layer_dim:** *int, default=*`None` *(equivalent to* `x_dim // 4`*)*  \n    Dimensionality of each hidden layer in the MLP of each AffineCouplingLayer.  \n\n- **decoder_affine_hidden_layer_activation:** *nn.Module, default=*`nn.ReLU`  \n    Activation function applied to the outputs of each hidden layer in the MLP of each AffineCouplingLayer.  \n\n- **decoder_nflow_n_hidden_layers:** *int, default=*`2`  \n    Number of hidden layers in the MLP of the decoder's NFlowLayer.  \n\n- **decoder_nflow_hidden_layer_dim:** *int, default=*`None` *(equivalent to* `x_dim // 4`*)*  \n    Dimensionality of each hidden layer in the MLP of the decoder's NFlowLayer.  \n\n- **decoder_nflow_hidden_layer_activation:** *nn.Module, default=*`nn.ReLU`  \n    Activation function applied to the outputs of each hidden layer in the MLP of the decoder's NFlowLayer.  \n\n- **decoder_observation_model:** *str, default=*`'poisson'`  \n    - One of `gaussian` or `poisson`\n\n    Observation model used by the model's decoder. 
- **decoder_fr_clamp_min:** *float, default=*`1E-7`  
    - Only applied when `decoder_observation_model='poisson'`  

    Minimum threshold used when clamping decoded firing rates.  

- **decoder_fr_clamp_max:** *float, default=*`1E7`  
    - Only applied when `decoder_observation_model='poisson'`  

    Maximum threshold used when clamping decoded firing rates.  

- **label_prior_n_hidden_layers:** *int, default=*`2`  
    - Only applied when `discrete_labels=False`  

    Number of hidden layers in the MLP of the label prior estimator module.  

- **label_prior_hidden_layer_dim:** *int, default=*`32`  
    - Only applied when `discrete_labels=False`  

    Dimensionality of each hidden layer in the MLP of the label prior estimator module.  

- **label_prior_hidden_layer_activation:** *nn.Module, default=*`nn.Tanh`  
    - Only applied when `discrete_labels=False`  

    Activation function applied to the outputs of each hidden layer in the MLP of the label prior estimator module.  

## Attributes

- **decoder:** *nn.Module*  
    The model's decoder module, which projects a latent space sample into the model's observation space.  

- **decoder_observation_model:** *str*  
    The distribution of the observation space samples. One of `poisson` or `gaussian`.  

- **decoder_fr_clamp_min:** *float*  
    Minimum threshold used when clamping decoded firing rates.  

- **decoder_fr_clamp_max:** *float*  
    Maximum threshold used when clamping decoded firing rates.  

- **encoder:** *nn.Module*  
    The model's encoder module, which approximates q(z \| x).  

- **inference:** *bool, default=*`False`  
    Flag denoting the model's inference mode. When `True`, the model is in inference mode.  

- **observation_noise_model:** *nn.Module*  
    - Only used when `decoder_observation_model='gaussian'`  

    The noise model used when computing pi-VAE's loss.  
- **label_prior:** *nn.Module*  
    The model's label prior module, which approximates p(z \| u).  

## Basic operation

For every observation space sample `x` and associated label `u` provided to pi-VAE's `forward` method, the encoder and label statistics (mean & log of variance) are obtained from the encoder and label prior modules. These values are used to obtain the same statistics from the posterior q(z \| x,u).  

The [reparameterization trick](https://en.wikipedia.org/wiki/Reparameterization_trick) is performed with the resulting mean & log of variance to obtain the sample's representation in the model's latent space. This latent representation is then passed to the model's decoder module, which generates the predicted firing rate in the model's observation space.  

### Inputs

- **x:** *Tensor of shape(n_samples, x_dim)*  
    Samples in the model's observation space.  

- **u:** *Tensor, default=*`None`  
    - *shape(n_samples)* when using discrete labels  
    - *shape(n_samples, u_dim)* when using continuous labels  

    Label corresponding to each sample. This parameter is not used when the model is in inference mode.  

### Outputs

A `dict` with the following items:  

- **encoder_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
    Predicted firing rate of `encoder_z_sample`.  

- **encoder_z_sample:** *Tensor of shape(n_samples, z_dim)*  
    Latent space representation of each input sample computed from the encoder module's approximation of q(z \| x).  

- **encoder_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean of each input sample using the encoder module's approximation of q(z \| x).  

- **encoder_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance of each input sample using the encoder module's approximation of q(z \| x).  
- **label_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean of each input sample using the label prior module's approximation of p(z \| u).  

- **label_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance of each input sample using the label prior module's approximation of p(z \| u).  

- **posterior_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
    Predicted firing rate of `posterior_z_sample`.  

- **posterior_z_sample:** *Tensor of shape(n_samples, z_dim)*  
    Latent space representation of each input sample computed from the approximation of posterior q(z \| x,u) ~ q(z \| x) &times; p(z \| u).  

- **posterior_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean of each input sample using the approximation of posterior q(z \| x,u) ~ q(z \| x) &times; p(z \| u).  

- **posterior_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance of each input sample using the approximation of posterior q(z \| x,u) ~ q(z \| x) &times; p(z \| u).  
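The posterior statistics above come from fusing the encoder's q(z \| x) with the label prior's p(z \| u) as a product of diagonal Gaussians (see `compute_posterior` below). As a dependency-free sketch of that elementwise arithmetic, not the library's actual Tensor-based implementation, with a hypothetical helper name:

```python
import math

def product_of_gaussians(mean_0, log_var_0, mean_1, log_var_1):
    """Elementwise product of two diagonal Gaussians.

    Precisions (1/variance) add, and the fused mean is the
    precision-weighted average of the two input means.
    """
    post_mean, post_log_var = [], []
    for m0, lv0, m1, lv1 in zip(mean_0, log_var_0, mean_1, log_var_1):
        v0, v1 = math.exp(lv0), math.exp(lv1)
        v = 1.0 / (1.0 / v0 + 1.0 / v1)            # fused variance
        post_mean.append(v * (m0 / v0 + m1 / v1))  # precision-weighted mean
        post_log_var.append(math.log(v))
    return post_mean, post_log_var

# With equal unit variances the fused mean is the simple average
# and the fused variance is halved.
mean, log_var = product_of_gaussians([0.0], [0.0], [2.0], [0.0])
```

Because precisions add, the fused distribution is always at least as concentrated as either input, which is why the posterior sample tracks the encoder and label prior more tightly than either alone.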
#### Inference Mode

A `dict` with the following items:  

- **encoder_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
- **encoder_z_sample:** *Tensor of shape(n_samples, z_dim)*  
- **encoder_mean:** *Tensor of shape(n_samples, z_dim)*  
- **encoder_log_variance:** *Tensor of shape(n_samples, z_dim)*  

### Examples

#### Continuous Labels

```
import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim = 100,
    u_dim = 3,
    z_dim = 2,
    discrete_labels=False
)

x = torch.randn(1, 100) # Size([n_samples, x_dim])

u = torch.randn(1, 3) # Size([n_samples, u_dim])

outputs = model(x, u) # dict
```

#### Discrete Labels

```
import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim = 100,
    u_dim = 3,
    z_dim = 2,
    discrete_labels=True
)

x = torch.randn(1, 100) # Size([n_samples, x_dim])

u = torch.randint(3, (1,)) # labels drawn from [0, u_dim); Size([n_samples])

outputs = model(x, u) # dict
```

## Class Methods

- **decode(*x*)**  
    Projects samples in the model's latent space (`z_dim`) into the model's observation space (`x_dim`) by passing them through the model's decoder module.  

    > **Parameters:**  

    - **x**: *Tensor of shape(n_samples, z_dim)*  
        Samples to be projected into the model's observation space.  

    > **Returns:**  

    - **decoded**: *Tensor of shape(n_samples, x_dim)*  
        Samples projected into the model's observation space.  

- **encode(*x, return_stats=False*)**  
    Projects samples in the model's observation space (`x_dim`) into the model's latent space (`z_dim`) by passing them through the model's encoder module.  

    > **Parameters:**  

    - **x**: *Tensor of shape(n_samples, x_dim)*  
        Samples to be projected into the model's latent space.  
    - **return_stats**: *bool, default=*`False`  
        If `True`, the mean and log of the variance associated with the encoded sample are returned; otherwise only the encoded sample is returned.  

    > **Returns:**  

    - **encoded**: *Tensor of shape(n_samples, z_dim)*  
        Samples projected into the model's latent space.  

    - **encoded_mean**: *Tensor of shape(n_samples, z_dim), optional*  
        Mean associated with a projected sample.  

    - **encoded_log_variance**: *Tensor of shape(n_samples, z_dim), optional*  
        Log of the variance associated with a projected sample.  

- **get_label_statistics(*u, device=None*)**  
    Returns the mean and log of the variance associated with a label `u` using the label prior estimator of p(z \| u).  

    > **Parameters:**  

    - **u**: *int, float, list, tuple, or Tensor of shape(1, u_dim)*  
        Label whose statistics will be returned. An integer is expected in the discrete label regime, while a float, list, tuple, or PyTorch Tensor is expected in the continuous label regime.  

    - **device**: *torch.device, default=*`None` *(uses the CPU)*  
        A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  

    > **Returns:**  

    - **label_mean**: *Tensor of shape(1, z_dim)*  
        Mean of label `u`.  

    - **label_log_variance**: *Tensor of shape(1, z_dim)*  
        Log of the variance of label `u`.  

- **sample(*u, n_samples=1, device=None*)**  
    Generates random samples in the model's observation dimension (`x_dim`). Samples are initially drawn from a Gaussian distribution in the model's latent dimension (`z_dim`) corresponding to the specified label `u`. 
    Samples are subsequently lifted to the model's observation dimension (`x_dim`) by passing them through the model's decoder.  

    > **Parameters:**  

    - **u**: *int, float, list, tuple, or Tensor of shape(1, u_dim)*  
        Label of the samples to generate. An integer is expected in the discrete label regime, while a float, list, tuple, or PyTorch Tensor is expected in the continuous label regime.  

    - **n_samples**: *int, default=*`1`  
        Number of samples to generate.  

    - **device**: *torch.device, default=*`None` *(uses the CPU)*  
        A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  

    > **Returns:**  

    - **samples**: *Tensor of shape(n_samples, x_dim)*  
        Randomly generated sample(s).  

- **sample_z(*u, n_samples=1, device=None*)**  
    Generates random samples in the model's latent dimension (`z_dim`). Samples are drawn from a Gaussian distribution corresponding to the specified label `u`.  

    > **Parameters:**  

    - **u**: *int, float, list, tuple, or Tensor of shape(1, u_dim)*  
        Label of the samples to generate. An integer is expected in the discrete label regime, while a float, list, tuple, or PyTorch Tensor is expected in the continuous label regime.  

    - **n_samples**: *int, default=*`1`  
        Number of samples to generate.  

    - **device**: *torch.device, default=*`None` *(uses the CPU)*  
        A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  

    > **Returns:**  

    - **samples**: *Tensor of shape(n_samples, z_dim)*  
        Randomly generated sample(s).  
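Conceptually, `sample_z` draws each latent as z = mean + std &times; eps with eps ~ N(0, 1), using the label statistics from p(z \| u). A dependency-free sketch of that draw (the helper name is hypothetical; the library operates on Tensors, not Python lists):

```python
import math
import random

def draw_latents(label_mean, label_log_variance, n_samples=1, seed=0):
    # Reparameterization-style draw: z = mean + std * eps, eps ~ N(0, 1),
    # where std = exp(0.5 * log_variance).
    rng = random.Random(seed)
    std = [math.exp(0.5 * lv) for lv in label_log_variance]
    return [
        [m + s * rng.gauss(0.0, 1.0) for m, s in zip(label_mean, std)]
        for _ in range(n_samples)
    ]

# Three z_dim=2 draws for a label with mean (1, -1) and near-zero variance;
# each draw lands essentially on the mean.
samples = draw_latents([1.0, -1.0], [-40.0, -40.0], n_samples=3)
```

`sample` is then this same draw followed by a pass through the decoder, which lifts each latent into the observation dimension.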
- **set_inference_mode(*state*)**  
    Toggles the model's inference state flag. When `True`, the model's `forward` method does not utilize the `u` parameter. When `False`, the `u` parameter is utilized. Useful for working with unlabeled data. *NOTE: Inference mode must be disabled during model training.*  

    > **Parameters:**  

    - **state**: *bool*  
        The desired inference state.  

    > **Returns:**  

    - None  

## Static Methods

- **compute_posterior(*mean_0, log_variance_0, mean_1, log_variance_1*)**  
    Computes the posterior of two distributions as a product of Gaussians.  

    > **Parameters:**  

    - **mean_0:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_0:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  

    - **mean_1:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_1:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  

    > **Returns:**  

    - **posterior_mean:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of the posterior distribution.  

    - **posterior_log_variance:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of the posterior distribution.  

## Loss Function - ELBOLoss

pi-VAE learns the deep generative model and the approximate posterior q(z \| x, u) of the true posterior p(z \| x, u) by maximizing the evidence lower bound (ELBO) of p(x \| u). This loss function is implemented in the included `ELBOLoss` class.  

### Initialization

```
pi_vae_pytorch.ELBOLoss(
    version=2,
    alpha=0.5,
    observation_model='poisson',
    device=None)
```

- **version:** *int, default=*`2`  
    - Either `1` or `2`  

    The version of the loss function.  
    - **Version 1:** Computes the KL divergence between the posterior and the label prior.  
    - **Version 2:** Computes the KL divergence between the posterior and the label prior, as well as between the encoder and the label prior. These two values are then weighted by the `alpha` parameter.  

- **alpha:** *float, default=*`0.5`  
    - Only applied when `version=2`  
    - Must reside within [0, 1]  

    Weights the contribution of the encoder KL loss and posterior KL loss to the total KL loss.  

    ```
    kl_loss = (alpha * encoder_kl_loss) + ((1 - alpha) * posterior_kl_loss)
    ```

- **observation_model:** *str, default=*`'poisson'`  
    - Either `poisson` or `gaussian`  
    - Should use the same value passed to `decoder_observation_model` when initializing pi-VAE.  

    The observation model used by pi-VAE's decoder.  

- **device:** *torch.device, default=*`None` *(uses the CPU)*  
    - Only applied when `observation_model='gaussian'`  

    A [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) object representing the device on which operations will be performed. Should match the `torch.device` on which the model resides.  

### Inputs

- **x:** *Tensor of shape(n_samples, x_dim)*  
    Sample(s) in the model's observation space.  

- **posterior_firing_rate:** *Tensor of shape(n_samples, x_dim)*  
    Predicted firing rate of latent(s) generated from posterior q(z \| x,u).  

- **posterior_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean from posterior q(z \| x,u) ~ q(z \| x) &times; p(z \| u).  

- **posterior_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance from posterior q(z \| x,u) ~ q(z \| x) &times; p(z \| u).  

- **label_mean:** *Tensor of shape(n_samples, z_dim)*  
    Mean from the label prior estimator which approximates p(z \| u).  
- **label_log_variance:** *Tensor of shape(n_samples, z_dim)*  
    Log of variance from the label prior estimator which approximates p(z \| u).  

- **encoder_mean:** *Tensor of shape(n_samples, z_dim), default=*`None`  
    - Only used when `version=2`  

    Mean from the encoder which approximates q(z \| x).  

- **encoder_log_variance:** *Tensor of shape(n_samples, z_dim), default=*`None`  
    - Only used when `version=2`  

    Log of variance from the encoder which approximates q(z \| x).  

- **observation_noise_model:** *nn.Module, default=*`None`  
    - Only used when `observation_model='gaussian'`  

    The noise model used when pi-VAE's decoder utilizes a Gaussian observation model. When pi-VAE is initialized with `decoder_observation_model='gaussian'`, the model's `observation_noise_model` attribute should be used.  

### Outputs

- **loss:** *Tensor of shape(1)*  
    The total loss of the samples.  

### Static Methods

- **compute_kl_loss(*mean_0, log_variance_0, mean_1, log_variance_1*)**  
    Computes the [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between two distributions.  

    > **Parameters:**  

    - **mean_0:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_0:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  

    - **mean_1:** *Tensor of shape(n_samples, sample_dim)*  
        Mean of a distribution.  

    - **log_variance_1:** *Tensor of shape(n_samples, sample_dim)*  
        Log of variance of a distribution.  

    > **Returns:**  

    - **kl_loss**: *Tensor of shape(1)*  
        The Kullback-Leibler divergence loss.  
### Examples

#### Poisson observation model

```
from pi_vae_pytorch import ELBOLoss

loss_fn = ELBOLoss()

outputs = model(x, u) # Initialized with decoder_observation_model='poisson'

loss = loss_fn(
    x=x,
    posterior_firing_rate=outputs['posterior_firing_rate'],
    posterior_mean=outputs['posterior_mean'],
    posterior_log_variance=outputs['posterior_log_variance'],
    label_mean=outputs['label_mean'],
    label_log_variance=outputs['label_log_variance'],
    encoder_mean=outputs['encoder_mean'],
    encoder_log_variance=outputs['encoder_log_variance']
)

loss.backward()
```

#### Gaussian observation model

```
import torch
from pi_vae_pytorch import ELBOLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device) # Initialized with decoder_observation_model='gaussian'

loss_fn = ELBOLoss(observation_model='gaussian', device=device)

outputs = model(x, u)

loss = loss_fn(
    x=x,
    posterior_firing_rate=outputs['posterior_firing_rate'],
    posterior_mean=outputs['posterior_mean'],
    posterior_log_variance=outputs['posterior_log_variance'],
    label_mean=outputs['label_mean'],
    label_log_variance=outputs['label_log_variance'],
    encoder_mean=outputs['encoder_mean'],
    encoder_log_variance=outputs['encoder_log_variance'],
    observation_noise_model=model.observation_noise_model
)

loss.backward()
```

## Citation

```
@misc{zhou2020learning,
    title={Learning identifiable and interpretable latent models of high-dimensional neural activity using pi-VAE}, 
    author={Ding Zhou and Xue-Xin Wei},
    year={2020},
    eprint={2011.04798},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}
```