# **KalMax**: Kalman-based neural decoding in JAX
**KalMax** = **Kal**man smoothing of **Max**imum likelihood estimates in JAX.

You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts from $N$ cells over $T$ time bins) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a $D$-dimensional continuous variable, e.g. position), and `KalMax` provides JAX-optimised functions and classes for:
1. **Fitting rate maps** using kernel density estimation (KDE)
2. **Calculating likelihood** maps $p(\mathbf{s}_t|\mathbf{x})$
3. **Kalman filter / smoother**
<img src="figures/display_figures/input_data.png" width=350>
#### Why are these functionalities combined into one package?

Because likelihood estimation combined with Kalman filtering makes for powerful neural decoding. Kalman filtering/smoothing the maximum likelihood estimates (rather than the spikes themselves) sidesteps the weaknesses of both parent methods: naive Kalman filters assume spikes are linearly related to position (they rarely are), while maximum likelihood decoding ignores the temporal continuity of the trajectory. The combined approach outperforms both at no extra computational cost.
<img src="figures/display_figures/filter_comparisons.gif" width=850>
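
A minimal 1-D sketch of the idea above, assuming a random-walk state model and known per-timestep observation variances: each maximum likelihood estimate `y[t]` is treated as a noisy measurement of position whose variance `r[t]` would, in practice, come from the width of the likelihood. This is a plain JAX illustration written for this README, not `KalMax`'s own implementation; the function `kalman_filter_1d` and all of its parameters are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def kalman_filter_1d(y, r, q, mu0, var0):
    """Kalman-filter 1-D observations y[t] with per-step observation variances r[t].

    Assumed (hypothetical) model, for illustration only:
        state:       z_t = z_{t-1} + w_t,  w_t ~ N(0, q)     (random walk)
        observation: y_t = z_t + v_t,      v_t ~ N(0, r[t])
    """
    def step(carry, inputs):
        mu, var = carry
        y_t, r_t = inputs
        # Predict: a random walk keeps the mean and inflates the variance by q
        mu_pred, var_pred = mu, var + q
        # Update: the gain weights each estimate by its relative certainty
        gain = var_pred / (var_pred + r_t)
        mu_new = mu_pred + gain * (y_t - mu_pred)
        var_new = (1.0 - gain) * var_pred
        return (mu_new, var_new), (mu_new, var_new)

    init = (jnp.asarray(mu0, dtype=y.dtype), jnp.asarray(var0, dtype=y.dtype))
    _, (mus, variances) = jax.lax.scan(step, init, (y, r))
    return mus, variances


# Toy data: a slowly drifting true position seen through noisy per-step estimates
key_z, key_y = jax.random.split(jax.random.PRNGKey(0))
T = 500
z_true = jnp.cumsum(0.01 * jax.random.normal(key_z, (T,)))   # random-walk "position"
r = jnp.full((T,), 0.1**2)                                    # variance of each estimate
y = z_true + jnp.sqrt(r) * jax.random.normal(key_y, (T,))     # noisy per-step estimates

mus, variances = kalman_filter_1d(y, r, 0.01**2, 0.0, 1.0)
mse_raw = jnp.mean((y - z_true) ** 2)
mse_filtered = jnp.mean((mus - z_true) ** 2)
print(f"raw MSE {float(mse_raw):.2e}, filtered MSE {float(mse_filtered):.2e}")
```

Because uncertain estimates (large `r[t]`) receive a small Kalman gain, the filter leans on the dynamics precisely when the likelihood is broad, which is where maximum likelihood decoding on its own is least reliable.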
Core `KalMax` functions are optimised and JIT-compiled with JAX, making them **very fast**. For example, `KalMax`'s Kalman filter is >13 times faster than the equivalent NumPy implementation in the popular [`pykalman`](https://github.com/pykalman/pykalman/tree/master) library (see [demo](./kalmax_demo.ipynb)).
<img src="figures/display_figures/kalman_speed_comparison.png" width=150>
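
When reproducing speed comparisons like the one above, note that JAX compiles a jitted function on its first call and dispatches work asynchronously. The generic timing sketch below excludes compilation and waits for results with `block_until_ready()`; the `ema` workload is a hypothetical stand-in, not a `KalMax` function.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def ema(x):
    """Hypothetical stand-in workload: exponential moving average via lax.scan."""
    def step(carry, x_t):
        carry = 0.9 * carry + 0.1 * x_t
        return carry, carry

    _, out = jax.lax.scan(step, jnp.zeros((), dtype=x.dtype), x)
    return out


x = jax.random.normal(jax.random.PRNGKey(0), (100_000,))

# Warm-up call: triggers tracing and XLA compilation, which we do not want to time
ema(x).block_until_ready()

n_repeats = 10
start = time.perf_counter()
for _ in range(n_repeats):
    # block_until_ready() waits for the asynchronous computation to finish
    result = ema(x).block_until_ready()
seconds_per_call = (time.perf_counter() - start) / n_repeats
print(f"{seconds_per_call * 1e3:.3f} ms per call")
```

The same pattern (warm up once, then time repeated calls that end in a blocking wait) applies to any jit-compiled function; for functions returning tuples, `jax.block_until_ready(outputs)` blocks on the whole pytree.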
# Install
```
git clone https://github.com/TomGeorge1234/KalMax.git
cd KalMax
pip install -e .
```
The `-e` flag is optional; it installs the package in editable (developer) mode.

Alternatively, install directly from GitHub:
```
pip install git+https://github.com/TomGeorge1234/KalMax.git
```
# Usage
A full demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/KalMax/blob/main/kalmax_demo.ipynb) is provided in [`kalmax_demo.ipynb`](./kalmax_demo.ipynb). Pseudocode is provided below.
```python
import kalmax
import jax.numpy as jnp
```
```python
# 0. PREPARE DATA IN JAX ARRAYS
S_train = jnp.array(...) # (T, N_CELLS) train spike counts
Z_train = jnp.array(...) # (T, DIMS) train continuous variable (X above, e.g. position)
S_test = jnp.array(...)  # (T_TEST, N_CELLS) test spike counts
bins = jnp.array(...)    # (N_BINS, DIMS) coordinates at which to estimate receptive fields / likelihoods
N_CELLS = S_train.shape[1] # number of recorded cells
DIMS = Z_train.shape[1]    # dimensionality of the continuous variable
```
<img src="figures/display_figures/data.png" width=850>
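
`bins` is simply an array of coordinates. For a 2-D environment, one way to build it is a regular grid flattened to shape `(N_BINS, DIMS)`. The 1 m by 1 m arena and 2 cm spacing below are illustrative assumptions, not values required by `KalMax`.

```python
import jax.numpy as jnp

# Assumed environment: a 1 m x 1 m square arena, 50 bin centres per side (2 cm spacing)
xs = jnp.linspace(0.01, 0.99, 50)
ys = jnp.linspace(0.01, 0.99, 50)
XX, YY = jnp.meshgrid(xs, ys)                          # each (50, 50), "xy" indexing
bins = jnp.stack([XX.ravel(), YY.ravel()], axis=-1)    # (N_BINS, DIMS) = (2500, 2)
```

A 1-D track uses the same idea with a single axis, e.g. `jnp.linspace(0.0, 1.0, 100)[:, None]`.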
```python
# 1. FIT RECEPTIVE FIELDS using kalmax.kde.kde
firing_rate = kalmax.kde.kde(
    bins = bins,
    trajectory = Z_train,
    spikes = S_train,
    kernel = kalmax.kernels.gaussian_kernel,
    kernel_kwargs = {'covariance': 0.01**2 * jnp.eye(DIMS)}, # kernel bandwidth (covariance)
    ) # --> (N_CELLS, N_BINS)
```
<img src="figures/display_figures/receptive_fields.png" width=850>
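
For intuition, a kernel density estimate of a rate map divides a kernel-smoothed spike count by kernel-smoothed occupancy at each bin. The sketch below is a plain-JAX version of that idea with an isotropic Gaussian kernel; it is not `KalMax`'s implementation, and the helper `gaussian_kde_rate_map` along with its time-step `dt` and bandwidth `sigma` parameters are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def gaussian_kde_rate_map(bins, trajectory, spikes, sigma, dt):
    """Firing-rate maps (N_CELLS, N_BINS) by Gaussian-kernel density estimation.

    bins:       (N_BINS, DIMS) coordinates at which to evaluate the rates
    trajectory: (T, DIMS) position at each time step
    spikes:     (T, N_CELLS) spike counts per time step
    sigma:      kernel bandwidth, in the same units as the positions
    dt:         duration of one time step, in seconds
    """
    # Squared distance between every bin and every trajectory sample: (N_BINS, T)
    sq_dist = jnp.sum((bins[:, None, :] - trajectory[None, :, :]) ** 2, axis=-1)
    # Unnormalised Gaussian kernel; its normalising constant cancels in the ratio below
    kernel = jnp.exp(-0.5 * sq_dist / sigma**2)
    spikes = jnp.asarray(spikes, dtype=kernel.dtype)
    spike_density = kernel @ spikes          # kernel-weighted spike counts, (N_BINS, N_CELLS)
    occupancy = kernel.sum(axis=1) * dt      # kernel-weighted time spent, (N_BINS,)
    rates = spike_density / (occupancy[:, None] + 1e-12)  # guard against empty bins
    return rates.T                           # (N_CELLS, N_BINS), in Hz


# Toy check: one cell firing at 5 Hz only in the right half of a 1-D track
T, dt = 2000, 0.1
trajectory = jnp.linspace(0.0, 1.0, T)[:, None]                   # (T, 1)
true_rate = jnp.where(trajectory[:, 0] > 0.5, 5.0, 0.0)           # (T,)
spikes = jax.random.poisson(jax.random.PRNGKey(0), true_rate * dt)[:, None]  # (T, 1)
bins = jnp.linspace(0.0, 1.0, 21)[:, None]                        # (21, 1)

rates = gaussian_kde_rate_map(bins, trajectory, spikes, sigma=0.05, dt=dt)
```

Dividing by occupancy is what turns smoothed spike counts into a rate: without it, frequently visited locations would look like high-firing locations.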
```python
# 2.1 CALCULATE LIKELIHOODS using kalmax.kde.poisson_log_likelihood
log_likelihoods = kalmax.kde.poisson_log_likelihood(
    spikes = S_test,
    mean_rate = firing_rate,
    ) # --> (T_TEST, N_BINS)

# 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian_vmap
MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
    x = bins,
    likelihoods = jnp.exp(log_likelihoods),
    ) # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
```
<img src="figures/display_figures/likelihood_maps_fitted.png" width=850>
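
The two steps above can be sketched in plain JAX as follows, assuming spikes are Poisson with expected count `firing_rate * dt` in each time bin and that cells are conditionally independent given position. The Gaussian is fitted here by moment matching (weighted mean and covariance of the normalised likelihood over `bins`), which is an illustrative choice and may differ from what `kalmax.utils.fit_gaussian_vmap` does; the helpers `poisson_log_likelihood`, `fit_gaussian` and the `dt` argument are hypothetical. Subtracting each time step's maximum before exponentiating avoids underflow when many cells are combined.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def poisson_log_likelihood(spikes, firing_rate, dt):
    """log p(s_t | x) for every time step and bin, shape (T, N_BINS).

    spikes:      (T, N_CELLS) spike counts
    firing_rate: (N_CELLS, N_BINS) rates in Hz
    """
    lam = firing_rate * dt + 1e-12          # expected counts per bin, (N_CELLS, N_BINS)
    # Sum over cells of  s * log(lam) - lam - log(s!)
    return (spikes @ jnp.log(lam)
            - lam.sum(axis=0)[None, :]
            - gammaln(spikes + 1.0).sum(axis=1, keepdims=True))


def fit_gaussian(bins, log_lik_t):
    """Moment-match a Gaussian to one likelihood map over bins (bins: (N_BINS, DIMS))."""
    w = jnp.exp(log_lik_t - log_lik_t.max())   # numerically stable: largest weight is 1
    w = w / w.sum()
    mean = w @ bins                            # (DIMS,)
    mode = bins[jnp.argmax(log_lik_t)]         # (DIMS,)
    diff = bins - mean
    cov = (w[:, None] * diff).T @ diff         # (DIMS, DIMS)
    return mean, mode, cov


# Vectorise over time steps; bins are shared
fit_gaussian_all = jax.vmap(fit_gaussian, in_axes=(None, 0))

# Toy usage: two cells with opposite linear tuning on a 1-D track
bins = jnp.linspace(0.0, 1.0, 101)[:, None]                               # (101, 1)
firing_rate = jnp.stack([10.0 * bins[:, 0], 10.0 * (1.0 - bins[:, 0])])   # (2, 101)
spikes = jnp.array([[3, 0], [0, 3]])                                      # (T=2, N_CELLS=2)
log_lik = poisson_log_likelihood(spikes, firing_rate, dt=0.5)             # (2, 101)
means, modes, covs = fit_gaussian_all(bins, log_lik)
```

In this toy case the first time step's likelihood is proportional to $x^3$, so its mode sits at the end of the track while its mean (about 0.8) is pulled inwards; the mean and mode can differ noticeably for skewed likelihoods.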
```python
# 3. KALMAN FILTER / SMOOTH using kalmax.kalman.KalmanFilter
kalman_filter = kalmax.kalman.KalmanFilter(
    dim_Z = DIMS, # latent state: the continuous variable
    dim_Y = DIMS, # observations: the MLE estimates from step 2, not the raw spikes
    # SEE DEMO FOR HOW TO FIT/SET THESE
    F=F, # state transition matrix
    Q=Q, # state noise covariance
    H=H, # observation matrix
    R=R, # observation noise covariance
    )

# [FILTER]
mus_f, sigmas_f = kalman_filter.filter(
    Y = Y,           # (T_TEST, DIMS) observations, e.g. the MLE estimates from step 2.2 (see demo)
    mu0 = mu0,       # initial state mean
    sigma0 = sigma0, # initial state covariance
    ) # --> (T_TEST, DIMS), (T_TEST, DIMS, DIMS)

# [SMOOTH]
mus_s, sigmas_s = kalman_filter.smooth(
    mus_f = mus_f,
    sigmas_f = sigmas_f,
    ) # --> (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
```
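
The README defers the choice of `F`, `Q`, `H` and `R` to the demo. One common approach, shown here as an assumption rather than as `KalMax`'s method, is to fit a linear-Gaussian transition model to the training trajectory by least squares: `F` regresses $\mathbf{z}_{t+1}$ on $\mathbf{z}_t$ and `Q` is the covariance of the residuals. If the observations are direct estimates of the state (as with MLE positions), `H` is the identity. The helper `fit_linear_dynamics` is hypothetical.

```python
import jax
import jax.numpy as jnp


def fit_linear_dynamics(Z):
    """Least-squares fit of z_{t+1} = F z_t + w_t,  w_t ~ N(0, Q).

    Z: (T, DIMS) training trajectory. Returns F (DIMS, DIMS) and Q (DIMS, DIMS).
    """
    Z_prev, Z_next = Z[:-1], Z[1:]
    # Solve Z_prev @ F.T ~= Z_next in the least-squares sense
    F_T, *_ = jnp.linalg.lstsq(Z_prev, Z_next)
    F = F_T.T
    residuals = Z_next - Z_prev @ F.T
    Q = jnp.cov(residuals, rowvar=False)   # process-noise covariance
    return F, Q


# Toy check: simulate a 2-D trajectory with known dynamics, then recover them
F_true = jnp.array([[0.99, 0.0], [0.0, 0.95]])
noise = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (5000, 2))   # true Q = 1e-4 * I


def step(z, w):
    z_next = F_true @ z + w
    return z_next, z_next


_, Z_train = jax.lax.scan(step, jnp.array([1.0, -1.0]), noise)
F, Q = fit_linear_dynamics(Z_train)
H = jnp.eye(2)   # assumption: observations are direct (noisy) estimates of the state
```

Under the same assumptions, a natural per-timestep observation covariance is the fitted likelihood covariance (`MLE_covs`) rather than a single fixed `R`; whether and how `KalmanFilter` accepts a time-varying `R` is shown in the demo, not here.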
Raw data
{
"_id": null,
"home_page": "https://github.com/TomGeorge1234/KalMax",
"name": "kalmax",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.6",
"maintainer_email": null,
"keywords": null,
"author": "Tom George",
"author_email": null,
"download_url": "https://files.pythonhosted.org/packages/24/5b/03151ba22357fe06789672050952520cb78ef5f8884a8a49cefa83c23dbd/kalmax-0.0.0.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": null,
"summary": "Kalman based neural decoding in Jax",
"version": "0.0.0",
"project_urls": {
"Homepage": "https://github.com/TomGeorge1234/KalMax"
},
"split_keywords": [],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "4da4f861f94c0f1b82fce57fd666c7a835d4004e6f5ecfa4c260845f0c96166e",
"md5": "ccf8405bf56f78961daae9b0013a8c60",
"sha256": "94b994a341201d3308a94a287a2607096521fbcc51dc68e00614c5815be776e3"
},
"downloads": -1,
"filename": "kalmax-0.0.0-py3-none-any.whl",
"has_sig": false,
"md5_digest": "ccf8405bf56f78961daae9b0013a8c60",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.6",
"size": 16021,
"upload_time": "2024-09-23T18:32:05",
"upload_time_iso_8601": "2024-09-23T18:32:05.297111Z",
"url": "https://files.pythonhosted.org/packages/4d/a4/f861f94c0f1b82fce57fd666c7a835d4004e6f5ecfa4c260845f0c96166e/kalmax-0.0.0-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "245b03151ba22357fe06789672050952520cb78ef5f8884a8a49cefa83c23dbd",
"md5": "45aa30fc3f5bf3218e8972dfab41ba7a",
"sha256": "1482927282d20762af50f70e2d9c86e75d9c2a0d2a22eddcd04cfad8e385e5a7"
},
"downloads": -1,
"filename": "kalmax-0.0.0.tar.gz",
"has_sig": false,
"md5_digest": "45aa30fc3f5bf3218e8972dfab41ba7a",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.6",
"size": 16246,
"upload_time": "2024-09-23T18:32:07",
"upload_time_iso_8601": "2024-09-23T18:32:07.180676Z",
"url": "https://files.pythonhosted.org/packages/24/5b/03151ba22357fe06789672050952520cb78ef5f8884a8a49cefa83c23dbd/kalmax-0.0.0.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-09-23 18:32:07",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "TomGeorge1234",
"github_project": "KalMax",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "kalmax"
}