# Ramsey
[Project status](https://www.repostatus.org/#active)
[CI](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml)
[Coverage](https://codecov.io/gh/ramsey-devs/ramsey)
[Documentation](https://ramsey.readthedocs.io/en/latest/?badge=latest)
[PyPI](https://pypi.org/project/ramsey/)
> Probabilistic deep learning using JAX
## About
Ramsey is a library for probabilistic deep learning using [JAX](https://github.com/google/jax),
[Flax](https://github.com/google/flax) and [NumPyro](https://github.com/pyro-ppl/numpyro). Its scope covers
- neural processes (vanilla, attentive, Markovian, convolutional, ...),
- neural Laplace and Fourier operator models,
- etc.
## Example usage
You can, for instance, construct a simple neural process like this:
```python
from flax import nnx

from ramsey import NP
from ramsey.nn import MLP  # just a flax.nnx module


def get_neural_process(in_features, out_features):
    dim = 128
    model = NP(
        latent_encoder=(
            MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1)),
        ),
        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2)),
    )
    return model


neural_process = get_neural_process(1, 1)
```
The neural process above takes a decoder and a latent encoder as arguments, where the latent encoder is given as a tuple of two networks. All of these are typically `flax.nnx` MLPs, but
Ramsey is flexible enough that you can replace them with other modules, for instance CNNs or RNNs.
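In the example, the second latent-encoder MLP outputs `dim * 2` values and the decoder outputs `out_features * 2` values, i.e. twice the size of the quantity each one describes. A common reading, and an assumption here rather than a statement about Ramsey's internals, is that each output vector is split in half into the means and the unconstrained scales of a diagonal Gaussian, with a softplus making the scales positive. The stdlib-only sketch below illustrates that split for a single point; `softplus`, `split_gaussian_params` and `gaussian_log_prob` are hypothetical helpers, not Ramsey functions.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)); non-negative for every input.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def split_gaussian_params(outputs: list[float], min_scale: float = 1e-6):
    """Split a vector of length 2 * d into d means and d positive scales.

    Hypothetical helper: assumes the first half holds the means and the
    second half holds unconstrained (pre-softplus) scales.
    """
    if len(outputs) % 2 != 0:
        raise ValueError("expected an even number of outputs")
    d = len(outputs) // 2
    means = outputs[:d]
    # min_scale keeps each scale strictly positive even if softplus underflows to 0.
    scales = [softplus(s) + min_scale for s in outputs[d:]]
    return means, scales


def gaussian_log_prob(y: list[float], means: list[float], scales: list[float]) -> float:
    # Sum of independent univariate Normal log-densities over the d dimensions.
    total = 0.0
    for yi, mu, sigma in zip(y, means, scales):
        z = (yi - mu) / sigma
        total += -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    return total


# Decoder output for one target point with out_features = 1: [mean, raw_scale].
# softplus(log(e - 1)) = 1, so the resulting scale is approximately 1.0.
means, scales = split_gaussian_params([0.0, math.log(math.e - 1.0)])
print(means, scales)
print(gaussian_log_prob([0.0], means, scales))  # approximately -0.5 * log(2 * pi)
```

The same split applies to the latent path, where the `dim * 2` outputs would parameterize a `dim`-dimensional latent variable.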
Ramsey provides a unified interface in which each model implements (at least) a `__call__` and a `loss` method,
to transform a set of inputs and to compute a training loss, respectively:
```python
from jax import random as jr
from ramsey.data import sample_from_sine_function
data = sample_from_sine_function(jr.key(0))
x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
x_target, y_target = data.x, data.y
# make a prediction
pred = neural_process(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
)

# compute the loss
loss = neural_process.loss(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
    y_target=y_target,
)
```
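For orientation, the standard latent-variable neural process that this example resembles has the generative model and training objective below. That Ramsey's `loss` computes exactly the negative of this objective is an assumption here; the documentation gives the precise form. $C$ denotes the context points, $T$ the target points, $h$ and $g$ the two networks in `latent_encoder`, and $f$ the decoder (symbols chosen for illustration).

```math
\begin{aligned}
r_i &= h(x_i, y_i), \qquad r_C = \frac{1}{|C|} \sum_{i \in C} r_i, \\
(\mu_z, \sigma_z) &= g(r_C), \qquad q(z \mid C) = \mathcal{N}\big(z;\, \mu_z,\, \operatorname{diag}(\sigma_z^2)\big), \\
(\mu_y, \sigma_y) &= f(x_j, z), \qquad p(y_j \mid x_j, z) = \mathcal{N}\big(y_j;\, \mu_y,\, \operatorname{diag}(\sigma_y^2)\big), \\
\mathcal{L} &= \mathbb{E}_{q(z \mid C \cup T)}\Big[\sum_{j \in T} \log p(y_j \mid x_j, z)\Big] - \mathrm{KL}\big(q(z \mid C \cup T) \,\|\, q(z \mid C)\big).
\end{aligned}
```

Training maximizes $\mathcal{L}$, i.e. minimizes $-\mathcal{L}$. In the snippet above, `x_target = data.x` contains the first 20 points used as context, so $C \subseteq T$ and $C \cup T = T$.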
## Installation
To install from PyPI, call:
```bash
pip install ramsey
```
To install a release directly from GitHub, replace `<RELEASE>` with the tag of the release you want (e.g., the latest one) and run the following on the
command line:
```bash
pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>
```
If you plan to use Ramsey on a GPU or TPU, see also the installation instructions for [JAX](https://github.com/google/jax).
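After installing, one way to confirm which versions of Ramsey and its main dependencies ended up in the environment is the standard-library `importlib.metadata` module. The sketch below is a generic check, not part of Ramsey; `installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


# Report the versions of Ramsey and the libraries it builds on.
for name in ("ramsey", "jax", "flax", "numpyro"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

If you installed an accelerator-enabled JAX, calling `jax.devices()` lists the devices JAX can see.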
## Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled
["good first issue"](https://github.com/ramsey-devs/ramsey/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
In order to contribute:
1) clone Ramsey and install [`uv`](https://github.com/astral-sh/uv),
2) create a new local branch with `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
3) install all dependencies via `uv sync --all-extras`,
4) implement your contribution and ideally a test case,
5) test it by calling `make format`, `make lints` and `make tests` on the (Unix) command line,
6) submit a PR 🙂
## Why Ramsey
Just as the names of other probabilistic languages are inspired by researchers in the field
(e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, [Frank Ramsey](https://plato.stanford.edu/entries/ramsey/).