# CosmoPower-JAX
<p align="center">
<img src="https://user-images.githubusercontent.com/25639122/235351711-39be2b50-dbcb-4964-adbf-f38ffc74ef5f.jpeg" width="300" height="202.5"
alt="CPJ_logo"/>
</p>
<div align="center">

[arXiv:2305.06347](https://arxiv.org/abs/2305.06347)

</div>
`CosmoPower-JAX` is an extension of the [CosmoPower](https://github.com/alessiospuriomancini/cosmopower) framework to emulate cosmological power spectra in a differentiable way. With `CosmoPower-JAX` you can efficiently run Hamiltonian Monte Carlo with hundreds of parameters (for example, nuisance parameters describing systematic effects) on CPUs and GPUs, in a fraction of the time required by traditional methods. We provide some examples of how to use the neural emulators below, and more applications [in our paper](https://arxiv.org/abs/2305.06347). You can also have a look at [our poster](https://github.com/dpiras/dpiras.github.io/blob/master/assets/images/poster_CPJ.pdf) presented at the [ML-IAP/CCA-2023](https://indico.iap.fr/event/1/overview) conference, which includes a video on how to use `CosmoPower-JAX`.

Of course, with `CosmoPower-JAX` you can also obtain efficient and differentiable predictions of cosmological power spectra. We show how to achieve this in fewer than 5 lines of code below.
## Installation
To install `CosmoPower-JAX`, you can simply use `pip`:
```
pip install cosmopower-jax
```
We recommend doing it in a fresh `conda` environment, to avoid clashes (e.g. `conda create -n cpj python=3.9 && conda activate cpj`).
Alternatively, you can:
```
git clone https://github.com/dpiras/cosmopower-jax.git
cd cosmopower-jax
pip install .
```
The latter will also give you access to a Jupyter notebook with some examples.
## Usage & example
After the installation, getting a cosmological power spectrum prediction is as simple as (e.g. for the CMB temperature power spectrum):
```
import numpy as np
from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
# omega_b, omega_cdm, h, tau, n_s, ln10^10A_s
cosmo_params = np.array([0.025, 0.11, 0.68, 0.1, 0.97, 3.1])
emulator = CPJ(probe='cmb_tt')
emulator_predictions = emulator.predict(cosmo_params)
```
Similarly, you can compute derivatives of the predictions with respect to the input parameters:

```
emulator_derivatives = emulator.derivative(cosmo_params)
```
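
As an independent sanity check of automatic derivatives, you can compare them against finite differences. The NumPy sketch below is not part of `CosmoPower-JAX`: `finite_difference_jacobian` and `toy_model` are hypothetical names, and `toy_model` only stands in for a function such as `emulator.predict`. The output layout of `emulator.derivative` is not described here, so check its shape before comparing the two arrays.

```python
import numpy as np

def finite_difference_jacobian(f, x, eps=1e-6):
    """Central-difference estimate of df_i/dx_j for f: R^n -> R^m.

    Returns an array of shape (m, n). Slower than autodiff, but useful
    as an independent cross-check of derivative results.
    """
    x = np.asarray(x, dtype=float)
    f0 = np.ravel(np.asarray(f(x), dtype=float))
    jac = np.empty((f0.size, x.size))
    for j in range(x.size):
        # Perturb one input at a time, symmetrically around x
        step = np.zeros_like(x)
        step[j] = eps
        jac[:, j] = (np.ravel(f(x + step)) - np.ravel(f(x - step))) / (2 * eps)
    return jac

# Hypothetical toy stand-in for an emulator: 3 outputs from 2 inputs
def toy_model(theta):
    a, b = theta
    return np.array([a**2, a * b, np.sin(b)])

# Analytic Jacobian at (a, b) = (1, 0.5): [[2a, 0], [b, a], [0, cos(b)]]
J = finite_difference_jacobian(toy_model, [1.0, 0.5])
```

If the model evaluates in single precision (JAX's default), a larger step such as `eps=1e-3` is usually safer than `1e-6`, since very small steps amplify rounding errors.
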
Rather than passing an array, you can also pass a dictionary, as in the original `CosmoPower` syntax:

```
cosmo_params = {'omega_b': [0.025],
                'omega_cdm': [0.11],
                'h': [0.68],
                'tau_reio': [0.1],
                'n_s': [0.97],
                'ln10^{10}A_s': [3.1],
               }
emulator = CPJ(probe='cmb_tt')
emulator_predictions = emulator.predict(cosmo_params)
```
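
The array and dictionary inputs carry the same information, with the array order given by the comment in the first example. The helper below makes that mapping explicit; it is a hypothetical sketch, not a `CosmoPower-JAX` function, and the order it encodes is inferred from the two `cmb_tt` examples only. Each dictionary value is a list, so several parameter sets can be stacked into an `(n_samples, n_params)` array; whether `emulator.predict` accepts such a 2D batch is an assumption to check against the notebook.

```python
import numpy as np

# Order of the 'cmb_tt' array input, inferred from the examples (assumption)
CMB_TT_PARAM_ORDER = ['omega_b', 'omega_cdm', 'h', 'tau_reio', 'n_s', 'ln10^{10}A_s']

def params_dict_to_array(params, order=CMB_TT_PARAM_ORDER):
    """Hypothetical helper: stack a dict of lists into an (n_samples, n_params) array."""
    missing = [k for k in order if k not in params]
    if missing:
        raise KeyError(f"missing parameters: {missing}")
    # One column per parameter, in the fixed order; lists of equal length n give n rows
    columns = [np.atleast_1d(np.asarray(params[k], dtype=float)) for k in order]
    return np.stack(columns, axis=-1)
```
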
We also support reusing original `CosmoPower` models, which you can now use in JAX without retraining. In that case, you should:
```
git clone https://github.com/dpiras/cosmopower-jax.git
cd cosmopower-jax
```
and move your model `.pkl` file(s) into the folder `cosmopower_jax/trained_models`. At this point:
- if you only need to call your models from within the `cosmopower-jax` folder you are in, you should be good to go;
- otherwise, first run `pip install .`, and you should then be able to call your custom models from anywhere.

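
If you prefer to script the copying step, a minimal sketch with the Python standard library could look like the following. `copy_models_into_repo` is a hypothetical helper, and it assumes `repo_root` points to the cloned `cosmopower-jax` folder, which already contains `cosmopower_jax/trained_models`.

```python
import shutil
from pathlib import Path

def copy_models_into_repo(model_files, repo_root):
    """Hypothetical helper: copy .pkl files into <repo_root>/cosmopower_jax/trained_models.

    Returns the list of destination paths.
    """
    target_dir = Path(repo_root) / "cosmopower_jax" / "trained_models"
    if not target_dir.is_dir():
        raise FileNotFoundError(f"{target_dir} not found; is repo_root the cloned cosmopower-jax folder?")
    copied = []
    for src in map(Path, model_files):
        if src.suffix != ".pkl":
            raise ValueError(f"expected a .pkl file, got {src.name}")
        dest = target_dir / src.name
        shutil.copy2(src, dest)  # copy contents and metadata
        copied.append(dest)
    return copied
```

For example, `copy_models_into_repo(['/path/to/my_model.pkl'], 'cosmopower-jax')`, where both paths are placeholders.
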
To call a custom model, you can then run:
```
from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ
emulator_custom = CPJ(probe='custom_log', filename='<custom_filename>.pkl')
```
where `<custom_filename>.pkl` is the filename (only, no path) of your custom model, and `custom_log` indicates that your model was trained on log-spectra, so all predictions are returned as 10 raised to the power of the network output. Alternatively, you can pass `custom_pca`, and you will automatically get the predictions for a model trained with `PCAplusNN`. For custom models, the parameter dictionary should of course contain the parameter keys your model was trained with.

You can also pass the full `filepath` of the trained model instead: in this case, do not specify `filename`, and give the full `filepath` including the file suffix.
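
Schematically, the two custom modes differ in how the raw network output is turned into a spectrum. The sketch below only illustrates the idea: the function names, argument names and shapes are hypothetical, and the actual `PCAplusNN` models may apply additional normalisation steps stored in the trained `.pkl` file.

```python
import numpy as np

def undo_log10(network_output):
    """'custom_log' idea: the network predicts log10 of the spectrum, so return 10**output."""
    return 10.0 ** np.asarray(network_output)

def reconstruct_from_pca(coeffs, components, mean, std):
    """'custom_pca' idea (schematic): map predicted PCA coefficients back to a spectrum.

    coeffs: (n_samples, n_pcas); components: (n_pcas, n_modes);
    mean, std: (n_modes,) normalisation of the training spectra.
    """
    # Project back onto the PCA basis, then undo the standardisation
    return (np.asarray(coeffs) @ np.asarray(components)) * std + mean
```
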
We provide a full walkthrough and all instructions in the accompanying [Jupyter notebook](https://github.com/dpiras/cosmopower-jax/blob/main/notebooks/emulators_example.ipynb), and we describe `CosmoPower-JAX` in detail in the release paper. We currently do not provide the code to train a neural-network model in JAX; if you would like to re-train a JAX-based neural network on different data, [raise an issue](https://github.com/dpiras/cosmopower-jax/issues) or contact [Davide Piras](mailto:davide.piras@unige.ch).
### Note if you are using `TensorFlow>=2.14`
If you are reusing a model trained with `CosmoPower` and have `TensorFlow` version 2.14 or higher, you might get an error when trying to load the model, even in `CosmoPower-JAX`. This is [a known issue](https://github.com/alessiospuriomancini/cosmopower/issues/22). In this case, run the `convert_tf214.py` script available in this repository to transform your `.pkl` file into a different format (based on `NumPy`) that `CosmoPower-JAX` can read. You only have to do the conversion once for each `.pkl` file; make sure to run `pip install .` after the conversion, and everything else should remain unchanged.
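
To decide whether the conversion is needed on a given machine, you can check the installed `TensorFlow` version. This is a hypothetical helper using only the standard library; it looks up the `tensorflow` distribution name only, so variants installed under a different name (e.g. `tensorflow-cpu`) would need to be queried separately.

```python
from importlib.metadata import PackageNotFoundError, version

def needs_tf214_conversion(tf_version):
    """True if a TensorFlow version string is >= 2.14 (e.g. '2.15.0', '2.14.0rc1')."""
    # Compare only the major and minor components
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return (major, minor) >= (2, 14)

def installed_tf_version():
    """Installed TensorFlow version string, or None if 'tensorflow' is not installed."""
    try:
        return version("tensorflow")
    except PackageNotFoundError:
        return None
```

For instance, `v = installed_tf_version()` followed by `if v is not None and needs_tf214_conversion(v): ...` flags the case where `convert_tf214.py` should be run.
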
## Contributing and contacts
Feel free to [fork](https://github.com/dpiras/cosmopower-jax/fork) this repository to work on it; otherwise, please [raise an issue](https://github.com/dpiras/cosmopower-jax/issues) or contact [Davide Piras](mailto:davide.piras@unige.ch).
## Citation
If you use `CosmoPower-JAX` in your work, please cite both papers as follows:
```
@article{Piras23,
    author = {{Piras}, Davide and {Spurio Mancini}, Alessio},
    title = "{CosmoPower-JAX: high-dimensional Bayesian inference
              with differentiable cosmological emulators}",
    journal = {The Open Journal of Astrophysics},
    keywords = {Astrophysics - Cosmology and Nongalactic Astrophysics,
                Astrophysics - Instrumentation and Methods for Astrophysics,
                Computer Science - Machine Learning},
    year = 2023,
    month = jul,
    volume = {6},
    eid = {20},
    pages = {20},
    doi = {10.21105/astro.2305.06347},
    archivePrefix = {arXiv},
    eprint = {2305.06347},
    primaryClass = {astro-ph.CO}
}

@article{SpurioMancini2022,
    title={CosmoPower: emulating cosmological power spectra for
           accelerated Bayesian inference from next-generation surveys},
    volume={511},
    ISSN={1365-2966},
    url={http://dx.doi.org/10.1093/mnras/stac064},
    DOI={10.1093/mnras/stac064},
    number={2},
    journal={Monthly Notices of the Royal Astronomical Society},
    publisher={Oxford University Press (OUP)},
    author={Spurio Mancini, Alessio and Piras, Davide and
            Alsing, Justin and Joachimi, Benjamin and Hobson, Michael P},
    year={2022},
    month={Jan},
    pages={1771–1788}
}
```
## License
`CosmoPower-JAX` is released under the GPL-3 license (see [LICENSE](https://github.com/dpiras/cosmopower-jax/blob/main/LICENSE)), subject to the non-commercial use condition (see [LICENSE_EXT](https://github.com/dpiras/cosmopower-jax/blob/main/LICENSE_EXT)).

```
CosmoPower-JAX
Copyright (C) 2023 Davide Piras & contributors

This program is released under the GPL-3 license (see LICENSE),
subject to a non-commercial use condition (see LICENSE_EXT).

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
```