# Extended State Space Models in JAX
Given a potentially non-linear state space model, this library solves the forward (filtering) and backward (smoothing)
inference steps. For linear state space models, this is equivalent to the Kalman filter and Rauch-Tung-Striebel smoother recursions.
Support for Python 3.10+.
## Example
All you need to do is define the transition and observation functions, and the initial state prior. These are all in
terms of `MultivariateNormalLinearOperator` distributions from `tensorflow_probability`.
```python
import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp
from essm_jax.essm import ExtendedStateSpaceModel
# Short alias for the TensorFlow Probability distributions namespace (JAX substrate).
tfpd = tfp.distributions
def transition_fn(z, t, t_next, *args):
    """Transition model: distribution of the next latent state given the current one.

    The mean is ``z + sin(2 * pi * t / 10 * z)`` and the covariance is ``0.1 * I``.
    ``t_next`` and any extra positional ``args`` are accepted but unused here.
    """
    dim = np.size(z)
    drift = jnp.sin(2 * jnp.pi * t / 10 * z)
    mean = z + drift
    cov = 0.1 * jnp.eye(dim)
    scale_tril = jnp.linalg.cholesky(cov)
    return tfpd.MultivariateNormalTriL(mean, scale_tril)
def observation_fn(z, t, *args):
    """Observation model: distribution of the observation given the latent state ``z`` at time ``t``.

    The mean is ``z`` itself (identity observation) and the covariance is
    ``t * 0.01 * I``, so the noise grows linearly with time.
    NOTE(review): that covariance is positive definite only for t > 0; confirm
    the model is never asked to observe at t <= 0.
    """
    dim = np.size(z)
    noise_cov = t * 0.01 * jnp.eye(dim)
    scale_tril = jnp.linalg.cholesky(noise_cov)
    return tfpd.MultivariateNormalTriL(z, scale_tril)
n = 1  # latent dimension

initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=initial_state_prior,
    materialise_jacobians=False,  # Fast
    more_data_than_params=False  # if observation is bigger than latent we can speed it up.
)

# Use independent PRNG keys for sampling and for forward simulation: reusing a
# single key for two random draws makes them statistically correlated in JAX.
sample_key, simulate_key = jax.random.split(jax.random.PRNGKey(0))

T = 100
samples = essm.sample(sample_key, num_time=T)

# Suppose we only observe every 3rd observation
mask = jnp.arange(T) % 3 != 0

# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)

# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)

# Smoothed latent distribution, p(z[t] | x[:]), i.e. each latent conditioned on all
# observations, both past and future.
# Including new estimate for prior state p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)

# Forward simulate the model beyond the end of the filtered data
forward_samples = essm.forward_simulate(
    key=simulate_key,
    num_time=25,
    filter_result=filter_result
)

# matplotlib.pyplot is the recommended plotting interface (pylab is discouraged).
import matplotlib.pyplot as plt

plt.plot(samples.t, samples.latent[:, 0], label='latent')
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')
plt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')
plt.legend()
plt.show()

plt.plot(samples.t, samples.observation[:, 0], label='observation')
plt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')
plt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')
plt.legend()
plt.show()
```
## Online Filtering
Take a look at the [examples](./docs/examples) to learn how to do online filtering for interactive applications.
# Change Log
13 August 2024: Initial release 1.0.0.
14 August 2024: 1.0.1 released. Added sparse util. Add incremental API for online filtering. Arbitrary dt.
## Star History
<a href="https://star-history.com/#joshuaalbert/essm_jax&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date" />
</picture>
</a>
Raw data
{
"_id": null,
"home_page": null,
"name": "essm-jax",
"maintainer": null,
"docs_url": null,
"requires_python": ">3.9",
"maintainer_email": null,
"keywords": "kalman, non-linear, EKF, modelling",
"author": null,
"author_email": "\"Joshua G. Albert\" <albert@strw.leidenuniv.nl>",
"download_url": "https://files.pythonhosted.org/packages/1b/7c/ced965cf174bd9e926f8fc6ce53e0af37a59b57bbb401849dd789a0b763c/essm_jax-1.0.2.tar.gz",
"platform": null,
"description": "# Extended State Space Models in JAX\n\nGiven a potentially non-linear state space model this allows you to solve the forward and backward inference steps, for\nlinear space space models this is equivalent to the Kalman and Rauch-Tung-Striebel recursions.\n\nSupport for Python 3.10+.\n\n## Example\n\nAll you need to do is define the transition and observation functions, and the initial state prior. These are all in\nterms of `MultivariateNormalLinearOperator` distributions from `tensorflow_probability`.\n\n```python\nimport jax\nimport numpy as np\nimport tensorflow_probability.substrates.jax as tfp\nfrom jax import numpy as jnp\n\nfrom essm_jax.essm import ExtendedStateSpaceModel\n\ntfpd = tfp.distributions\n\n\ndef transition_fn(z, t, t_next, *args):\n mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)\n cov = 0.1 * jnp.eye(np.size(z))\n return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))\n\n\ndef observation_fn(z, t, *args):\n mean = z\n cov = t * 0.01 * jnp.eye(np.size(z))\n return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))\n\n\nn = 1\n\ninitial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))\n\nessm = ExtendedStateSpaceModel(\n transition_fn=transition_fn,\n observation_fn=observation_fn,\n initial_state_prior=initial_state_prior,\n materialise_jacobians=False, # Fast\n more_data_than_params=False # if observation is bigger than latent we can speed it up.\n)\n\nT = 100\nsamples = essm.sample(jax.random.PRNGKey(0), num_time=T)\n\n# Suppose we only observe every 3rd observation\nmask = jnp.arange(T) % 3 != 0\n\n# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])\nlog_prob = essm.log_prob(samples.observation, mask=mask)\nprint(log_prob)\n\n# Filtered latent distribution, p(z[t] | x[:t])\nfilter_result = essm.forward_filter(samples.observation, mask=mask)\n\n# Smoothed latent distribution, p(z[t] | x[:]), i.e. 
past latents given all future observations\n# Including new estimate for prior state p(z[0])\nsmooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)\nprint(smooth_result)\n\n# Forward simulate the model\nforward_samples = essm.forward_simulate(\n key=jax.random.PRNGKey(0),\n num_time=25,\n filter_result=filter_result\n)\n\nimport pylab as plt\n\nplt.plot(samples.t, samples.latent[:, 0], label='latent')\nplt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')\nplt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')\nplt.legend()\nplt.show()\n\nplt.plot(samples.t, samples.observation[:, 0], label='observation')\nplt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')\nplt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')\nplt.legend()\nplt.show()\n```\n\n## Online Filtering\n\nTake a look at [examples](./docs/examples) to learn how to do online filtering, for interactive application.\n\n# Change Log\n\n13 August 2024: Initial release 1.0.0.\n14 August 2024: 1.0.1 released. Added sparse util. Add incremental API for online filtering. Arbitrary dt.\n\n## Star History\n\n<a href=\"https://star-history.com/#joshuaalbert/jaxns&Date\">\n <picture>\n <source media=\"(prefers-color-scheme: dark)\" srcset=\"https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date&theme=dark\" />\n <source media=\"(prefers-color-scheme: light)\" srcset=\"https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date\" />\n <img alt=\"Star History Chart\" src=\"https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date\" />\n </picture>\n</a>\n",
"bugtrack_url": null,
"license": "Apache Software License",
"summary": "Extended State Space Modelling in JAX",
"version": "1.0.2",
"project_urls": {
"Homepage": "https://github.com/joshuaalbert/essm_jax"
},
"split_keywords": [
"kalman",
" non-linear",
" ekf",
" modelling"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "70a9b8a526540c76a74875acb0b5fb8aaa7c018ec1813176ebb1010588da5893",
"md5": "0c266dd334472249d29ea412d0056799",
"sha256": "94396db655800c26031376dc75e14881123cf626da74d68850093f6d0073cce6"
},
"downloads": -1,
"filename": "essm_jax-1.0.2-py3-none-any.whl",
"has_sig": false,
"md5_digest": "0c266dd334472249d29ea412d0056799",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">3.9",
"size": 30543,
"upload_time": "2024-11-05T21:36:03",
"upload_time_iso_8601": "2024-11-05T21:36:03.057877Z",
"url": "https://files.pythonhosted.org/packages/70/a9/b8a526540c76a74875acb0b5fb8aaa7c018ec1813176ebb1010588da5893/essm_jax-1.0.2-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "1b7cced965cf174bd9e926f8fc6ce53e0af37a59b57bbb401849dd789a0b763c",
"md5": "d9625e567812be2d4dcc5a4d046889c6",
"sha256": "34056ec2f652c33b937b6e12af94daed469db8d7def0a3ada5db2126ddf0dc2a"
},
"downloads": -1,
"filename": "essm_jax-1.0.2.tar.gz",
"has_sig": false,
"md5_digest": "d9625e567812be2d4dcc5a4d046889c6",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">3.9",
"size": 29393,
"upload_time": "2024-11-05T21:36:04",
"upload_time_iso_8601": "2024-11-05T21:36:04.574492Z",
"url": "https://files.pythonhosted.org/packages/1b/7c/ced965cf174bd9e926f8fc6ce53e0af37a59b57bbb401849dd789a0b763c/essm_jax-1.0.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-11-05 21:36:04",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "joshuaalbert",
"github_project": "essm_jax",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"requirements": [
{
"name": "jax",
"specs": []
},
{
"name": "jaxlib",
"specs": []
},
{
"name": "numpy",
"specs": [
[
"<",
"2"
]
]
},
{
"name": "tensorflow_probability",
"specs": []
}
],
"lcname": "essm-jax"
}