essm-jax


Name: essm-jax (JSON)
Version: 1.0.2 (PyPI version JSON)
download
home_page: None
Summary: Extended State Space Modelling in JAX
upload_time: 2024-11-05 21:36:04
maintainer: None
docs_url: None
author: None
requires_python: >3.9
license: Apache Software License
keywords: kalman non-linear ekf modelling
VCS
bugtrack_url
requirements: jax jaxlib numpy tensorflow_probability
Travis-CI: No Travis.
coveralls test coverage: No coveralls.
# Extended State Space Models in JAX

Given a potentially non-linear state space model, this library solves the forward (filtering) and backward (smoothing)
inference steps. For linear state space models, these are equivalent to the Kalman and Rauch-Tung-Striebel recursions.

Supports Python 3.10+.

## Example

All you need to do is define the transition function, the observation function, and the initial state prior. The two
functions return, and the prior is, a `MultivariateNormalLinearOperator` distribution from `tensorflow_probability` (the
example below uses its subclass `MultivariateNormalTriL`).

```python
import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp

from essm_jax.essm import ExtendedStateSpaceModel

tfpd = tfp.distributions


def transition_fn(z, t, t_next, *args):
    # Latent dynamics: z is shifted by a sinusoidal drift whose argument scales
    # with both t and z, plus isotropic process noise with variance 0.1.
    # `t_next` and any extra args are accepted but unused in this example.
    dim = np.size(z)
    drift = jnp.sin(2 * jnp.pi * t / 10 * z)
    process_cov = 0.1 * jnp.eye(dim)
    return tfpd.MultivariateNormalTriL(z + drift, jnp.linalg.cholesky(process_cov))


def observation_fn(z, t, *args):
    # The latent state is observed directly (identity mean), with isotropic noise
    # whose variance is 0.01 * t.
    # NOTE(review): at t == 0 this covariance is all zeros (singular) — confirm the
    # model's time grid never evaluates observations at t == 0.
    noise_cov = t * 0.01 * jnp.eye(np.size(z))
    scale_tril = jnp.linalg.cholesky(noise_cov)
    return tfpd.MultivariateNormalTriL(z, scale_tril)


n = 1  # latent dimension; observation_fn uses mean = z, so the observation dimension is also n

initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=initial_state_prior,
    materialise_jacobians=False,  # Fast
    more_data_than_params=False  # Set True when the observation dimension exceeds the latent dimension, for a speed-up.
)

T = 100  # number of time steps
samples = essm.sample(jax.random.PRNGKey(0), num_time=T)

# Suppose we only observe every 3rd observation
mask = jnp.arange(T) % 3 != 0

# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)

# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)

# Smoothed latent distribution, p(z[t] | x[:]), i.e. each latent conditioned on all
# observations, both past and future.
# Including new estimate for prior state p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)

# Forward simulate the model
forward_samples = essm.forward_simulate(
    key=jax.random.PRNGKey(0),
    num_time=25,
    filter_result=filter_result
)

# Use matplotlib.pyplot directly; the `pylab` interface is discouraged by Matplotlib.
import matplotlib.pyplot as plt

plt.plot(samples.t, samples.latent[:, 0], label='latent')
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')
plt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')
plt.legend()
plt.show()

plt.plot(samples.t, samples.observation[:, 0], label='observation')
plt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')
plt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')
plt.legend()
plt.show()
```

## Online Filtering

Take a look at [examples](./docs/examples) to learn how to do online filtering for interactive applications.

## Change Log

- 13 August 2024: Initial release 1.0.0.
- 14 August 2024: 1.0.1 released. Added sparse util, an incremental API for online filtering, and support for arbitrary dt.

## Star History

<a href="https://star-history.com/#joshuaalbert/essm_jax&Date">
  <picture>
    <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date&theme=dark" />
    <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date" />
    <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date" />
  </picture>
</a>

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "essm-jax",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">3.9",
    "maintainer_email": null,
    "keywords": "kalman, non-linear, EKF, modelling",
    "author": null,
    "author_email": "\"Joshua G. Albert\" <albert@strw.leidenuniv.nl>",
    "download_url": "https://files.pythonhosted.org/packages/1b/7c/ced965cf174bd9e926f8fc6ce53e0af37a59b57bbb401849dd789a0b763c/essm_jax-1.0.2.tar.gz",
    "platform": null,
    "description": "# Extended State Space Models in JAX\n\nGiven a potentially non-linear state space model this allows you to solve the forward and backward inference steps, for\nlinear space space models this is equivalent to the Kalman and Rauch-Tung-Striebel recursions.\n\nSupport for Python 3.10+.\n\n## Example\n\nAll you need to do is define the transition and observation functions, and the initial state prior. These are all in\nterms of `MultivariateNormalLinearOperator` distributions from `tensorflow_probability`.\n\n```python\nimport jax\nimport numpy as np\nimport tensorflow_probability.substrates.jax as tfp\nfrom jax import numpy as jnp\n\nfrom essm_jax.essm import ExtendedStateSpaceModel\n\ntfpd = tfp.distributions\n\n\ndef transition_fn(z, t, t_next, *args):\n    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)\n    cov = 0.1 * jnp.eye(np.size(z))\n    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))\n\n\ndef observation_fn(z, t, *args):\n    mean = z\n    cov = t * 0.01 * jnp.eye(np.size(z))\n    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))\n\n\nn = 1\n\ninitial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))\n\nessm = ExtendedStateSpaceModel(\n    transition_fn=transition_fn,\n    observation_fn=observation_fn,\n    initial_state_prior=initial_state_prior,\n    materialise_jacobians=False,  # Fast\n    more_data_than_params=False  # if observation is bigger than latent we can speed it up.\n)\n\nT = 100\nsamples = essm.sample(jax.random.PRNGKey(0), num_time=T)\n\n# Suppose we only observe every 3rd observation\nmask = jnp.arange(T) % 3 != 0\n\n# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])\nlog_prob = essm.log_prob(samples.observation, mask=mask)\nprint(log_prob)\n\n# Filtered latent distribution, p(z[t] | x[:t])\nfilter_result = essm.forward_filter(samples.observation, mask=mask)\n\n# Smoothed latent distribution, p(z[t] | x[:]), i.e. 
past latents given all future observations\n# Including new estimate for prior state p(z[0])\nsmooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)\nprint(smooth_result)\n\n# Forward simulate the model\nforward_samples = essm.forward_simulate(\n    key=jax.random.PRNGKey(0),\n    num_time=25,\n    filter_result=filter_result\n)\n\nimport pylab as plt\n\nplt.plot(samples.t, samples.latent[:, 0], label='latent')\nplt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')\nplt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')\nplt.legend()\nplt.show()\n\nplt.plot(samples.t, samples.observation[:, 0], label='observation')\nplt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')\nplt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')\nplt.legend()\nplt.show()\n```\n\n## Online Filtering\n\nTake a look at [examples](./docs/examples) to learn how to do online filtering, for interactive application.\n\n# Change Log\n\n13 August 2024: Initial release 1.0.0.\n14 August 2024: 1.0.1 released. Added sparse util. Add incremental API for online filtering. Arbitrary dt.\n\n## Star History\n\n<a href=\"https://star-history.com/#joshuaalbert/jaxns&Date\">\n  <picture>\n    <source media=\"(prefers-color-scheme: dark)\" srcset=\"https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date&theme=dark\" />\n    <source media=\"(prefers-color-scheme: light)\" srcset=\"https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date\" />\n    <img alt=\"Star History Chart\" src=\"https://api.star-history.com/svg?repos=joshuaalbert/essm_jax&type=Date\" />\n  </picture>\n</a>\n",
    "bugtrack_url": null,
    "license": "Apache Software License",
    "summary": "Extended State Space Modelling in JAX",
    "version": "1.0.2",
    "project_urls": {
        "Homepage": "https://github.com/joshuaalbert/essm_jax"
    },
    "split_keywords": [
        "kalman",
        " non-linear",
        " ekf",
        " modelling"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "70a9b8a526540c76a74875acb0b5fb8aaa7c018ec1813176ebb1010588da5893",
                "md5": "0c266dd334472249d29ea412d0056799",
                "sha256": "94396db655800c26031376dc75e14881123cf626da74d68850093f6d0073cce6"
            },
            "downloads": -1,
            "filename": "essm_jax-1.0.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "0c266dd334472249d29ea412d0056799",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">3.9",
            "size": 30543,
            "upload_time": "2024-11-05T21:36:03",
            "upload_time_iso_8601": "2024-11-05T21:36:03.057877Z",
            "url": "https://files.pythonhosted.org/packages/70/a9/b8a526540c76a74875acb0b5fb8aaa7c018ec1813176ebb1010588da5893/essm_jax-1.0.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "1b7cced965cf174bd9e926f8fc6ce53e0af37a59b57bbb401849dd789a0b763c",
                "md5": "d9625e567812be2d4dcc5a4d046889c6",
                "sha256": "34056ec2f652c33b937b6e12af94daed469db8d7def0a3ada5db2126ddf0dc2a"
            },
            "downloads": -1,
            "filename": "essm_jax-1.0.2.tar.gz",
            "has_sig": false,
            "md5_digest": "d9625e567812be2d4dcc5a4d046889c6",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">3.9",
            "size": 29393,
            "upload_time": "2024-11-05T21:36:04",
            "upload_time_iso_8601": "2024-11-05T21:36:04.574492Z",
            "url": "https://files.pythonhosted.org/packages/1b/7c/ced965cf174bd9e926f8fc6ce53e0af37a59b57bbb401849dd789a0b763c/essm_jax-1.0.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-11-05 21:36:04",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "joshuaalbert",
    "github_project": "essm_jax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [
        {
            "name": "jax",
            "specs": []
        },
        {
            "name": "jaxlib",
            "specs": []
        },
        {
            "name": "numpy",
            "specs": [
                [
                    "<",
                    "2"
                ]
            ]
        },
        {
            "name": "tensorflow_probability",
            "specs": []
        }
    ],
    "lcname": "essm-jax"
}
        
Elapsed time: 0.36654s