# JAX-PT
JAX-PT is a rewrite of the [FAST-PT](https://github.com/jablazek/FAST-PT/) codebase that is compatible with JAX's automatic differentiation and JIT compilation tools. It can be integrated into a full JAX computation pipeline or used on its own. When JIT-compiled, the main JAX-PT functions (the same functions FAST-PT provides) can run 5-100x faster than FAST-PT 4.0, depending on the function. For more in-depth examples of JAX-PT's features and functionality, see [examples](https://github.com/vschac/JAX-PT/tree/main/examples/jpt_example.py).
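To make the JAX-compatibility claim concrete, here is a minimal sketch that uses only core JAX (`jax.jit` and `jax.grad`), not the JAX-PT API. The toy power-law spectrum `toy_power` and the scalar statistic `summary` are hypothetical stand-ins for a pipeline step; the statistic is compiled and then differentiated with respect to the amplitude.

```python
import jax
import jax.numpy as jnp


def toy_power(k, amplitude, index):
    # Hypothetical power-law spectrum P(k) = A * k**n, not a JAX-PT function
    return amplitude * k**index


def summary(amplitude, k):
    # Hypothetical scalar statistic; jax.grad requires a scalar output
    return jnp.sum(toy_power(k, amplitude, -1.5))


k = jnp.logspace(-3, 1, 100)

# JIT-compile the statistic and its derivative with respect to the amplitude
summary_jit = jax.jit(summary)
dsummary_damp = jax.jit(jax.grad(summary, argnums=0))

value = summary_jit(2.0, k)
gradient = dsummary_damp(2.0, k)

# The statistic is linear in the amplitude, so the gradient equals sum(k**-1.5)
print(value, gradient, jnp.sum(k**-1.5))
```

In a real pipeline the toy function would be replaced by JAX-PT calls, which is what makes gradients of downstream statistics available without finite differencing.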
## FAST-PT
FAST-PT is a code to calculate quantities in cosmological perturbation theory
at 1-loop (including, e.g., corrections to the matter power spectrum). The code
utilizes Fourier methods combined with analytic expressions to reduce the
computation time to scale as N log N, where N is the number of grid points in
the input linear power spectrum.
Related papers: [arXiv:1603.04826](https://arxiv.org/abs/1603.04826), [arXiv:1609.05978](https://arxiv.org/abs/1609.05978), [arXiv:1708.09247](https://arxiv.org/abs/1708.09247).
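The N log N scaling comes from expanding the input spectrum in complex power laws with a single FFT, after which each power law can be handled analytically. The sketch below shows only that decomposition step, under stated assumptions: it omits the analytic mode-coupling kernels and the padding/filtering details a production code needs, the function names `power_law_coefficients` and `reconstruct` are hypothetical, and the toy spectrum and bias exponent `nu = -2` are illustrative choices.

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64 so the round trip is clean
import jax.numpy as jnp


def power_law_coefficients(k, pk, nu):
    """Split P(k), sampled on a log-spaced grid, into complex power laws.

    Returns (c_m, eta_m) such that at every grid point
    P(k_n) = sum_m c_m * (k_n / k_0) ** (nu + 1j * eta_m).
    """
    n = k.size
    dlnk = (jnp.log(k[-1]) - jnp.log(k[0])) / (n - 1)  # uniform step in ln k
    biased = pk * (k / k[0]) ** (-nu)  # divide out the k**nu envelope
    c_m = jnp.fft.fft(biased) / n  # a single FFT: O(N log N)
    m = jnp.fft.fftfreq(n) * n  # integer frequencies, centred on zero
    eta_m = 2 * jnp.pi * m / (n * dlnk)
    return c_m, eta_m


def reconstruct(k, c_m, eta_m, nu):
    # Direct O(N^2) sum over all power laws, used here only as a check
    log_x = jnp.log(k / k[0])
    power_laws = jnp.exp((nu + 1j * eta_m[None, :]) * log_x[:, None])
    return jnp.real(power_laws @ c_m)


# Hypothetical smooth toy spectrum on a log-spaced grid (not real cosmology)
k = jnp.logspace(-3, 1, 256)
pk = k / (1.0 + (k / 0.02) ** 2)

c_m, eta_m = power_law_coefficients(k, pk, nu=-2.0)
pk_back = reconstruct(k, c_m, eta_m, nu=-2.0)
print("max relative error:", jnp.max(jnp.abs(pk_back / pk - 1)))
```

The expensive step is the FFT; the O(N^2) `reconstruct` exists only to confirm that the coefficients reproduce the input on the grid.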
## Installation
### Default Installation
```bash
pip install jax-pt
```
### Dev Installation
```bash
pip install "jax-pt[dev]"
```
## GPU Usage
JAX-PT lets you choose the device your computations run on. When creating a `JAXPT` instance, pass `'cpu'`, `'gpu'`, or any `jax.Device` object to the `device` keyword argument:
```python
import jax
import jax.numpy as jnp
from jaxpt import JAXPT
# Check available devices
print("Available devices:", jax.devices())
k = jnp.logspace(-3, 1, 1000)
# Create JAXPT instance (defaults to CPU)
jpt = JAXPT(k, warmup="moderate")
# Run on the GPU (requires a GPU-enabled JAX installation)
jpt = JAXPT(k, warmup="moderate", device="gpu")
# Or pass a specific jax.Device
devices = jax.devices()
jpt = JAXPT(k, warmup="moderate", device=devices[0])  # any entry of jax.devices() works
```
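If the same script has to run on machines with and without a GPU, a small fallback helper can choose the device before constructing `JAXPT`. This is a sketch: `pick_device` is a hypothetical helper, not part of JAX-PT, and it relies on `jax.devices(backend)` raising `RuntimeError` when the requested backend is unavailable. The JAX-PT call is left as a comment so the snippet runs with core JAX alone.

```python
import jax
import jax.numpy as jnp


def pick_device(preferred="gpu"):
    """Return the first device of the preferred backend, falling back to CPU.

    jax.devices(backend) raises RuntimeError when that backend is not
    available, e.g. when only the CPU build of jaxlib is installed.
    """
    try:
        return jax.devices(preferred)[0]
    except RuntimeError:
        return jax.devices("cpu")[0]


device = pick_device()
print("Using device:", device, "platform:", device.platform)

k = jnp.logspace(-3, 1, 1000)
# Pass the chosen device to JAX-PT using the constructor shown above:
# from jaxpt import JAXPT
# jpt = JAXPT(k, warmup="moderate", device=device)
```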
Remember to install the JAX CUDA build that matches your CUDA version.
For example:
```bash
pip install "jax[cuda12]"
```
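After installing, one way to confirm that JAX actually picked up the GPU is to query the default backend and list the visible devices. A minimal check using only core JAX; if it reports `cpu` on a GPU machine, the CUDA build was most likely not installed or not loaded.

```python
import jax
import jax.numpy as jnp

# "gpu" if a GPU-enabled jaxlib loaded successfully, otherwise "cpu"
backend = jax.default_backend()
print("Default backend:", backend)
print("Devices:", jax.devices())

# Run a small computation on the default device and wait for it to finish
x = jnp.ones((100, 100))
y = (x @ x).block_until_ready()
print("y[0, 0] =", float(y[0, 0]))
```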
Raw data
{
"_id": null,
"home_page": null,
"name": "jax-pt",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.8",
"maintainer_email": null,
"keywords": "cosmology, perturbation-theory, jax, gpu, astrophysics, fastpt",
"author": null,
"author_email": "Vincent Schacknies <vincent.schacknies@icloud.com>",
"download_url": "https://files.pythonhosted.org/packages/fa/02/1f882cc85444756ce220e2e501b07a6b46e9fbef1347a1325a8eb5ea3467/jax_pt-1.0.0.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "JAX-accelerated FAST-PT for computing perturbation theory power spectra",
"version": "1.0.0",
"project_urls": {
"Bug Tracker": "https://github.com/vschac/jax-pt/issues",
"Documentation": "https://github.com/vschac/JAX-PT/blob/main/README.md",
"Homepage": "https://github.com/vschac/jax-pt",
"Repository": "https://github.com/vschac/jax-pt"
},
"split_keywords": [
"cosmology",
" perturbation-theory",
" jax",
" gpu",
" astrophysics",
" fastpt"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "bb3e11a6111718d439ad6377f9e383fd8eb6654d456a60b567c423b1bd2785db",
"md5": "56ad835f30272c725cb5c766c17f735d",
"sha256": "00c731307197359e93cb81733715ebb4a420b9441828152e122fdcb539cb91f3"
},
"downloads": -1,
"filename": "jax_pt-1.0.0-py3-none-any.whl",
"has_sig": false,
"md5_digest": "56ad835f30272c725cb5c766c17f735d",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.8",
"size": 17958,
"upload_time": "2025-08-10T21:54:19",
"upload_time_iso_8601": "2025-08-10T21:54:19.004102Z",
"url": "https://files.pythonhosted.org/packages/bb/3e/11a6111718d439ad6377f9e383fd8eb6654d456a60b567c423b1bd2785db/jax_pt-1.0.0-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "fa021f882cc85444756ce220e2e501b07a6b46e9fbef1347a1325a8eb5ea3467",
"md5": "952032cc0ee425d0e2dff2a6cd6ac076",
"sha256": "4314eb2e090c1abef06d42a0cc8aebb1feb3081c06dd363e95d729ab730ce1c6"
},
"downloads": -1,
"filename": "jax_pt-1.0.0.tar.gz",
"has_sig": false,
"md5_digest": "952032cc0ee425d0e2dff2a6cd6ac076",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.8",
"size": 22990,
"upload_time": "2025-08-10T21:54:20",
"upload_time_iso_8601": "2025-08-10T21:54:20.386469Z",
"url": "https://files.pythonhosted.org/packages/fa/02/1f882cc85444756ce220e2e501b07a6b46e9fbef1347a1325a8eb5ea3467/jax_pt-1.0.0.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-08-10 21:54:20",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "vschac",
"github_project": "jax-pt",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"requirements": [
{
"name": "jax",
"specs": [
[
">=",
"0.4.16"
]
]
},
{
"name": "jaxlib",
"specs": [
[
">=",
"0.4.16"
]
]
},
{
"name": "fast-pt",
"specs": [
[
">=",
"4.0.0"
]
]
}
],
"lcname": "jax-pt"
}