# jaxdf - JAX-based Discretization Framework
[![Support](https://dcbadge.vercel.app/api/server/VtUb4fFznt?style=flat)](https://discord.gg/VtUb4fFznt)
[![License: LGPL v3](https://img.shields.io/badge/License-LGPL%20v3-blue.svg)](https://www.gnu.org/licenses/lgpl-3.0)
[![codecov](https://codecov.io/gh/ucl-bug/jaxdf/branch/main/graph/badge.svg?token=FIUYOCFDYL)](https://codecov.io/gh/ucl-bug/jaxdf)
[![CI](https://github.com/ucl-bug/jaxdf/actions/workflows/tests.yml/badge.svg)](https://github.com/ucl-bug/jaxdf/actions/workflows/tests.yml)
[**Overview**](#overview)
| [**Example**](#example)
| [**Installation**](#installation)
| [**Documentation**](https://ucl-bug.github.io/jaxdf/)
| [**Support**](#support)
<br/>
## Overview
Jaxdf is a package based on [JAX](https://jax.readthedocs.io/en/stable/) that provides a coding framework for creating differentiable numerical simulators with arbitrary discretizations.
The primary goal of Jaxdf is to help build numerical models of physical systems, such as wave propagation, or numerical solvers for partial differential equations, in a way that is easy to tailor to the user's research requirements. These models are pure functions that can be integrated into arbitrary differentiable programs written in [JAX](https://jax.readthedocs.io/en/stable/). For instance, they can be used as layers within neural networks, or to construct a physics loss function.
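To illustrate what "pure function inside a differentiable program" means in practice, here is a minimal plain-JAX sketch. It does not use the jaxdf API. The names `fd_laplacian` and `physics_layer` are hypothetical, and the discretization is a simple periodic finite-difference stencil chosen only for brevity. The operator is used as a "layer" after a learned linear map, batched with `jax.vmap`, and differentiated with respect to the weights with `jax.grad`.

```python
import jax
import jax.numpy as jnp

def fd_laplacian(u, dx):
    """Second-order central-difference Laplacian of a 1D periodic field (hypothetical helper)."""
    return (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2

def physics_layer(w, x, dx=0.1):
    """Hypothetical layer: a learned linear map followed by the operator (d^2/dx^2 + sin)."""
    u = w @ x  # learned linear part
    return fd_laplacian(u, dx) + jnp.sin(u)

def loss(w, batch):
    # vmap applies the pure layer to every field in the batch
    out = jax.vmap(lambda x: physics_layer(w, x))(batch)
    return jnp.mean(out**2)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = 0.1 * jax.random.normal(key_w, (16, 16))
batch = jax.random.normal(key_x, (4, 16))

# Because every function is pure, jit and grad compose freely
grads = jax.jit(jax.grad(loss))(w, batch)
print(grads.shape)  # (16, 16)
```

Swapping `fd_laplacian` for a different discretization leaves the rest of the program unchanged, which is the kind of flexibility Jaxdf aims to provide in a more structured way.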
<br/>
## Example
The script below builds the non-linear operator **(∇<sup>2</sup> + sin)** using a Fourier spectral discretization on a square 2D domain, and uses it to define a loss function. The gradient of this loss function is then computed with JAX's automatic differentiation.
```python
from jaxdf import operators as jops
from jaxdf import FourierSeries, operator
from jaxdf.geometry import Domain
from jax import numpy as jnp
from jax import jit, grad


# Defining operator
@operator
def custom_op(u, *, params=None):
    grad_u = jops.gradient(u)                      # gradient of u
    diag_jacobian = jops.diag_jacobian(grad_u)     # second derivatives along each axis
    laplacian = jops.sum_over_dims(diag_jacobian)  # sum over axes -> Laplacian
    sin_u = jops.compose(u)(jnp.sin)               # pointwise sin(u)
    return laplacian + sin_u

# Defining discretizations
domain = Domain((128, 128), (1., 1.))
parameters = jnp.ones((128, 128, 1))
u = FourierSeries(parameters, domain)

# Define a differentiable loss function
@jit
def loss(u):
    v = custom_op(u)
    return jnp.mean(jnp.abs(v.on_grid)**2)

gradient = grad(loss)(u)  # gradient is a FourierSeries
```
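To make "Fourier spectral discretization" more concrete, the sketch below computes the same operator directly on raw grid values using FFTs. It is not jaxdf's implementation and may differ from it in details. It assumes a periodic domain and reads the second argument of `Domain((128, 128), (1., 1.))` as the grid spacing, so `dx = 1.0`. It also drops the trailing channel axis, so `gradient` here is a plain array rather than a `FourierSeries`.

```python
import jax
import jax.numpy as jnp

def spectral_laplacian(u, dx):
    """Fourier spectral Laplacian of a 2D field on a periodic grid with spacing dx."""
    nx, ny = u.shape
    kx = 2 * jnp.pi * jnp.fft.fftfreq(nx, d=dx)
    ky = 2 * jnp.pi * jnp.fft.fftfreq(ny, d=dx)
    k2 = kx[:, None] ** 2 + ky[None, :] ** 2  # |k|^2 on the 2D frequency grid
    # Differentiation becomes multiplication by -|k|^2 in Fourier space
    return jnp.fft.ifft2(-k2 * jnp.fft.fft2(u)).real

def custom_op(u, dx=1.0):
    """(Laplacian + sin) applied to the grid values u."""
    return spectral_laplacian(u, dx) + jnp.sin(u)

@jax.jit
def loss(u):
    v = custom_op(u)
    return jnp.mean(jnp.abs(v) ** 2)

u = jnp.ones((128, 128))
gradient = jax.grad(loss)(u)  # plain array with the same shape as u
```

For a constant field the Laplacian vanishes, so the loss reduces to sin(1)<sup>2</sup> and each gradient entry to sin(2)/128<sup>2</sup>, which is a convenient sanity check.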
<br/>
## Installation
Before installing `jaxdf`, make sure that [JAX is already installed](https://github.com/google/jax#installation) on your system. If you intend to use `jaxdf` on NVIDIA GPUs, install the GPU-enabled version of JAX by following those instructions.
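A quick way to check which hardware JAX will run on, before installing `jaxdf`, is the snippet below. It only uses JAX's own `jax.default_backend()` and `jax.devices()`; if the GPU build was installed correctly, GPU devices should appear in the list.

```python
import jax

# Platform JAX will use by default, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()
print("Default backend:", backend)

# All devices visible to JAX
for device in jax.devices():
    print(device)
```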
To install `jaxdf` from PyPI, use the `pip` command:
```bash
pip install jaxdf
```
For development, either clone the repository or download and extract the compressed archive. Then navigate to the root folder in a terminal and run the following commands:
```bash
pip install --upgrade poetry
poetry install
```
This will install the dependencies and the package itself (in editable mode).
## Support
[![Support](https://dcbadge.vercel.app/api/server/VtUb4fFznt?style=flat)](https://discord.gg/VtUb4fFznt)
If you encounter any issues with the code or wish to suggest new features, please feel free to open an issue. If you seek guidance, wish to discuss something, or simply want to say hi, don't hesitate to write a message in our [Discord channel](https://discord.gg/VtUb4fFznt).
<br/>
## Contributing
Contributions are absolutely welcome! Most contributions start with an issue. Please don't hesitate to open issues to request features, give feedback on performance, or simply reach out.
To make a pull request, please read the detailed [Contributing guide](CONTRIBUTING.md); in short, keep the following guidelines in mind:
- If you add a new feature or fix a bug:
  - Make sure it is covered by tests
  - Add a line in the changelog using `kacl-cli`
- If you change something in the documentation, make sure that the documentation site can be built correctly using `mkdocs serve`
<br/>
<br/>
## Citation
[![arXiv](https://img.shields.io/badge/arXiv-2111.05218-b31b1b.svg?style=flat)](https://arxiv.org/abs/2111.05218)
An initial version of this package was presented at the [Differentiable Programming workshop](https://diffprogramming.mit.edu/) at NeurIPS 2021.
```bibtex
@article{stanziola2021jaxdf,
author={Stanziola, Antonio and Arridge, Simon and Cox, Ben T. and Treeby, Bradley E.},
title={A research framework for writing differentiable PDE discretizations in JAX},
year={2021},
journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}
```
<br/>
#### Acknowledgements
- Some of the packaging of this repository is done by editing [this template from @rochacbruno](https://github.com/rochacbruno/python-project-template)
- The multiple-dispatch mechanism is based on `plum`; check out this amazing project: https://github.com/wesselb/plum
#### Related projects
1. [`odl`](https://github.com/odlgroup/odl): the Operator Discretization Library (ODL) is a Python library for fast prototyping, focusing on (but not restricted to) inverse problems.
2. [`deepXDE`](https://deepxde.readthedocs.io/en/latest/): a TensorFlow and PyTorch library for scientific machine learning.
3. [`SciML`](https://sciml.ai/): SciML is a NumFOCUS-sponsored open source software organization created to unify the packages for scientific machine learning.