mpi4jax
=======

|JOSS paper| |PyPI Version| |Conda Version| |Tests| |codecov| |Documentation Status|

``mpi4jax`` enables zero-copy, multi-host communication of `JAX <https://jax.readthedocs.io/>`_ arrays, even from jitted code and from GPU memory.

But why?
--------

The JAX framework `has great performance for scientific computing workloads <https://github.com/dionhaefner/pyhpc-benchmarks>`_, but its `multi-host capabilities <https://jax.readthedocs.io/en/latest/jax.html#jax.pmap>`_ are still limited.

With ``mpi4jax``, you can scale your JAX-based simulations to *entire CPU and GPU clusters* (without ever leaving ``jax.jit``).

In the spirit of differentiable programming, ``mpi4jax`` also supports differentiating through some MPI operations.

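
As a sketch of what differentiating through a collective looks like, the snippet below takes the gradient of a function that calls ``mpi4jax.allreduce``. It assumes the mpi4jax 0.6-style API, where collectives return a ``(result, token)`` tuple, and that ``allreduce`` is differentiable for ``op=MPI.SUM``; ``global_sum`` is a hypothetical name. On a single process the allreduce is the identity, so the gradient is a vector of ones.

.. code:: python

   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD

   def global_sum(x):
       # Assumes the mpi4jax 0.6 API: allreduce returns (result, token)
       x_sum, _ = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
       # Reduce to a scalar so that jax.grad can be applied
       return x_sum.sum()

   # Differentiate through the collective, then compile the gradient function
   grad_fn = jax.jit(jax.grad(global_sum))
   g = grad_fn(jnp.ones(3))

   if comm.Get_rank() == 0:
       print(g)
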
Installation
------------

``mpi4jax`` is available through ``pip`` and ``conda``:

.. code:: bash

   $ pip install mpi4jax                     # pip
   $ conda install -c conda-forge mpi4jax    # conda

Depending on which JAX backend you want to use, install ``mpi4jax`` as follows:

.. code:: bash

   # CPU only: pip install 'jax[cpu]'
   $ pip install mpi4jax

   # CUDA 12 from pip wheels: pip install -U 'jax[cuda12]'
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation

   # CUDA 12 with a local CUDA installation: pip install -U 'jax[cuda12_local]'
   # (replace XXX with the root directory of your CUDA installation)
   $ CUDA_ROOT=XXX pip install mpi4jax

For more information on JAX GPU distributions, `see the JAX installation instructions <https://github.com/google/jax#installation>`_.

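
To confirm which backend JAX actually picked up after installation (for example, that a CUDA build is in use rather than the CPU fallback), you can query it from Python. This is a minimal sketch that uses only JAX itself:

.. code:: python

   import jax

   # Platform JAX uses by default, e.g. "cpu" or "gpu"
   backend = jax.default_backend()
   # Devices visible to this process on the default platform
   devices = jax.devices()

   print(backend, devices)
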
In case your MPI installation is not detected correctly, `it can help to install mpi4py separately <https://mpi4py.readthedocs.io/en/stable/install.html>`_. When using a pre-installed ``mpi4py``, you *must* use ``--no-build-isolation`` when installing ``mpi4jax``:

.. code:: bash

   # if mpi4py is already installed
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation

`Our documentation includes some more advanced installation examples. <https://mpi4jax.readthedocs.io/en/latest/installation.html>`_

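
If you suspect the wrong MPI library is being picked up, a quick check is to ask ``mpi4py`` which implementation it was built against before installing ``mpi4jax``. A minimal sketch using standard ``mpi4py`` calls:

.. code:: python

   import mpi4py
   from mpi4py import MPI

   # MPI standard version implemented by the linked library, e.g. (3, 1)
   major, minor = MPI.Get_version()
   # Vendor-specific string identifying the implementation (Open MPI, MPICH, ...)
   library = MPI.Get_library_version()
   # Build configuration recorded when mpi4py was installed (e.g. the mpicc used)
   config = mpi4py.get_config()

   print(f"MPI standard {major}.{minor}")
   print(library.strip())
   print(config)
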
Example usage
-------------

.. code:: python

   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD
   rank = comm.Get_rank()

   @jax.jit
   def foo(arr):
       arr = arr + rank
       arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
       return arr_sum

   a = jnp.zeros((3, 3))
   result = foo(a)

   if rank == 0:
       print(result)

Running this script on 4 processes gives:

.. code:: bash

   $ mpirun -n 4 python example.py
   [[6. 6. 6.]
    [6. 6. 6.]
    [6. 6. 6.]]

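
The printed values follow from the arithmetic: rank r adds r to an array of zeros, and ``MPI.SUM`` adds the contributions of all ranks, so every entry is 0 + 1 + 2 + 3 = 6 (in general n(n - 1)/2 for n processes). As a sanity check, the same computation can be reproduced on a single process with plain JAX and no MPI:

.. code:: python

   import jax.numpy as jnp

   n_procs = 4
   a = jnp.zeros((3, 3))

   # Each rank contributes a + rank; MPI.SUM adds all contributions element-wise
   expected = sum(a + rank for rank in range(n_procs))

   print(expected)
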
``allreduce`` is just one example of the MPI primitives you can use. `See all supported operations here <https://mpi4jax.readthedocs.org/en/latest/api.html>`_.
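
As a sketch of another primitive, the following broadcasts an array from rank 0 to all other ranks with ``mpi4jax.bcast``, again inside ``jax.jit``. It assumes the same mpi4jax 0.6-style API as the ``allreduce`` example, where the call returns a ``(result, token)`` tuple; ``broadcast_from_root`` is a hypothetical helper. On non-root ranks the input only determines the shape and dtype of the result.

.. code:: python

   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD
   rank = comm.Get_rank()

   @jax.jit
   def broadcast_from_root(arr):
       # Data is read only on root=0; every rank receives a new array
       # Assumes the mpi4jax 0.6 API: bcast returns (result, token)
       result, _ = mpi4jax.bcast(arr, root=0, comm=comm)
       return result

   # Every rank starts with different data ...
   local = jnp.full((2,), float(rank))
   # ... and afterwards holds rank 0's values (all zeros)
   out = broadcast_from_root(local)
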

Community guidelines
--------------------

If you have a question or feature request, or want to report a bug, feel free to `open an issue <https://github.com/mpi4jax/mpi4jax/issues>`_.

We welcome contributions of any kind `through pull requests <https://github.com/mpi4jax/mpi4jax/pulls>`_. For information on running our tests, debugging, and contribution guidelines, please `refer to the corresponding documentation page <https://mpi4jax.readthedocs.org/en/latest/developers.html>`_.

How to cite
-----------

If you use ``mpi4jax`` in your work, please consider citing the following article:

::

   @article{mpi4jax,
     doi = {10.21105/joss.03419},
     url = {https://doi.org/10.21105/joss.03419},
     year = {2021},
     publisher = {The Open Journal},
     volume = {6},
     number = {65},
     pages = {3419},
     author = {Dion Häfner and Filippo Vicentini},
     title = {mpi4jax: Zero-copy MPI communication of JAX arrays},
     journal = {Journal of Open Source Software}
   }

.. |Tests| image:: https://github.com/mpi4jax/mpi4jax/workflows/Tests/badge.svg
   :target: https://github.com/mpi4jax/mpi4jax/actions?query=branch%3Amaster
.. |codecov| image:: https://codecov.io/gh/mpi4jax/mpi4jax/branch/master/graph/badge.svg
   :target: https://codecov.io/gh/mpi4jax/mpi4jax
.. |PyPI Version| image:: https://img.shields.io/pypi/v/mpi4jax
   :target: https://pypi.org/project/mpi4jax/
.. |Conda Version| image:: https://img.shields.io/conda/vn/conda-forge/mpi4jax.svg
   :target: https://anaconda.org/conda-forge/mpi4jax
.. |Documentation Status| image:: https://readthedocs.org/projects/mpi4jax/badge/?version=latest
   :target: https://mpi4jax.readthedocs.io/en/latest/?badge=latest
.. |JOSS paper| image:: https://joss.theoj.org/papers/10.21105/joss.03419/status.svg
   :target: https://doi.org/10.21105/joss.03419