mpi4jax
=======

:Version: 0.6.1.post3
:Summary: Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡
:Home page: https://github.com/mpi4jax/mpi4jax
:Author: Filippo Vicentini
:License: MIT
:Requires Python: >=3.8
:Uploaded: 2024-12-18 16:30:39

|JOSS paper| |PyPI Version| |Conda Version| |Tests| |codecov| |Documentation Status|

``mpi4jax`` enables zero-copy, multi-host communication of `JAX <https://jax.readthedocs.io/>`_ arrays, even from jitted code and from GPU memory.


But why?
--------

The JAX framework `has great performance for scientific computing workloads <https://github.com/dionhaefner/pyhpc-benchmarks>`_, but its `multi-host capabilities <https://jax.readthedocs.io/en/latest/jax.html#jax.pmap>`_ are still limited.

With ``mpi4jax``, you can scale your JAX-based simulations to *entire CPU and GPU clusters* (without ever leaving ``jax.jit``).

In the spirit of differentiable programming, ``mpi4jax`` also supports differentiating through some MPI operations.
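
Here is a minimal sketch of differentiating through MPI. It takes ``jax.grad`` of a function that sums ``x**2`` element-wise across all ranks with ``mpi4jax.allreduce``. The sketch assumes that ``allreduce`` with ``op=MPI.SUM`` is one of the differentiable operations (check the API documentation for the exact list). It also uses the ``(result, token)`` return convention shown in the example usage. ``global_sum_of_squares`` is a hypothetical name.

.. code:: python

   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD


   @jax.jit
   def global_sum_of_squares(x):
       # Sum x**2 element-wise over all ranks, then reduce to a scalar loss.
       # Assumption: allreduce with MPI.SUM supports reverse-mode autodiff.
       summed, _ = mpi4jax.allreduce(x**2, op=MPI.SUM, comm=comm)
       return jnp.sum(summed)


   # Each rank holds different local data.
   x = jnp.arange(3.0) + comm.Get_rank()
   grad = jax.grad(global_sum_of_squares)(x)
   print(comm.Get_rank(), grad)

On a single process the gradient is simply ``2 * x``. With several ranks, how the per-rank gradient relates to the global sum depends on how the transpose of ``allreduce`` is defined, so check that against the documentation before relying on it.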


Installation
------------

``mpi4jax`` is available through ``pip`` and ``conda``:

.. code:: bash

   $ pip install mpi4jax                     # Pip
   $ conda install -c conda-forge mpi4jax    # conda

Depending on which JAX backend you want to use, install ``mpi4jax`` as follows:

.. code:: bash

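   # CPU only, after `pip install 'jax[cpu]'`
   $ pip install mpi4jax

   # NVIDIA GPU with CUDA from pip, after `pip install -U 'jax[cuda12]'`
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation

   # NVIDIA GPU with a local CUDA installation, after `pip install -U 'jax[cuda12_local]'`
   # (replace XXX with the root directory of that CUDA installation)
   $ CUDA_ROOT=XXX pip install mpi4jax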

For more information on JAX GPU distributions, see the `JAX installation instructions <https://github.com/google/jax#installation>`_.
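
Before running on GPUs, it can help to check on each rank that JAX actually sees a GPU and that your ``mpi4jax`` build has CUDA support. The sketch below assumes ``mpi4jax.has_cuda_support()`` is available in your version (verify it in the API reference). ``jax.default_backend()`` and ``jax.devices()`` are standard JAX functions.

.. code:: python

   from mpi4py import MPI
   import jax
   import mpi4jax

   rank = MPI.COMM_WORLD.Get_rank()

   # Active JAX backend on this rank (e.g. "cpu" or "gpu") and the devices it sees.
   backend = jax.default_backend()
   devices = jax.devices()

   # Assumption: has_cuda_support() reports whether this mpi4jax build was
   # compiled with CUDA support, which is needed to communicate GPU arrays.
   cuda_ok = mpi4jax.has_cuda_support()

   print(f"rank {rank}: backend={backend}, devices={devices}, mpi4jax CUDA support={cuda_ok}")

Launch it with e.g. ``mpirun -n 2 python check_gpu.py`` (file name hypothetical) and each rank prints one line. On a correctly configured GPU node you would expect ``backend=gpu``.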

In case your MPI installation is not detected correctly, `it can help to install mpi4py separately <https://mpi4py.readthedocs.io/en/stable/install.html>`_. When using a pre-installed ``mpi4py``, you *must* use ``--no-build-isolation`` when installing ``mpi4jax``:

.. code:: bash

   # if mpi4py is already installed
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation
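
If you suspect the wrong MPI is being picked up, a small diagnostic can show which MPI library a pre-installed ``mpi4py`` is linked against. The sketch below uses only ``mpi4py``. Whether the reported library is the right one for your cluster is for you to judge.

.. code:: python

   import mpi4py
   from mpi4py import MPI

   # Version of the mpi4py bindings themselves.
   print("mpi4py:", mpi4py.__version__)

   # MPI standard version implemented by the linked library, as (major, minor).
   print("MPI standard:", MPI.Get_version())

   # Vendor string of the MPI library mpi4py is linked against
   # (e.g. Open MPI or MPICH); only the first line is shown.
   lines = MPI.Get_library_version().strip().splitlines()
   library = lines[0] if lines else "<unknown>"
   print("MPI library:", library)

   # Build configuration recorded when mpi4py was compiled (e.g. the mpicc wrapper used).
   print("build config:", mpi4py.get_config())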

Our documentation includes `more advanced installation examples <https://mpi4jax.readthedocs.io/en/latest/installation.html>`_.


Example usage
-------------

.. code:: python

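   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD
   rank = comm.Get_rank()

   @jax.jit
   def foo(arr):
       # Each rank adds its own rank number to the input ...
       arr = arr + rank
       # ... then all ranks sum their arrays element-wise.
       # The second return value is not needed here.
       arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
       return arr_sum

   a = jnp.zeros((3, 3))
   result = foo(a)

   # Every rank holds the same result, so print it only once.
   if rank == 0:
       print(result)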

Saving this script as ``example.py`` and running it on 4 processes gives:

.. code:: bash

   $ mpirun -n 4 python example.py
   [[6. 6. 6.]
    [6. 6. 6.]
    [6. 6. 6.]]
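
Each of the 4 ranks adds its own rank (0 to 3) to an array of zeros. ``allreduce`` with ``MPI.SUM`` then adds these arrays element-wise across ranks, so every entry becomes 0 + 1 + 2 + 3 = 6. The NumPy sketch below reproduces that arithmetic in a single process, without MPI. ``simulate_allreduce_sum`` is a hypothetical helper, not part of ``mpi4jax``.

.. code:: python

   import numpy as np


   def simulate_allreduce_sum(n_procs, shape=(3, 3)):
       """Mimic the example on one process: collect each rank's array, then sum them."""
       # Rank r contributes zeros(shape) + r, like `arr + rank` in foo().
       contributions = [np.zeros(shape) + rank for rank in range(n_procs)]
       # MPI.SUM combines the arrays element-wise; every rank receives this same result.
       return np.sum(contributions, axis=0)


   result = simulate_allreduce_sum(4)
   print(result)  # the same 3x3 array of 6s as the MPI run

   # In general, every entry equals 0 + 1 + ... + (n - 1) = n * (n - 1) / 2.
   assert np.all(result == 4 * 3 / 2)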

``allreduce`` is just one example of the MPI primitives you can use. `See all supported operations here <https://mpi4jax.readthedocs.org/en/latest/api.html>`_.
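
As a second example, the sketch below broadcasts an array from rank 0 to all other ranks from inside a jitted function. It assumes ``mpi4jax.bcast`` takes the array and a ``root`` rank and, like ``allreduce`` above, returns a ``(result, token)`` pair. Please check the API reference linked above for the exact signature. ``share_from_root`` is a hypothetical name.

.. code:: python

   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD
   rank = comm.Get_rank()


   @jax.jit
   def share_from_root(arr):
       # Every rank passes an array of the same shape and dtype; afterwards all
       # ranks hold rank 0's values. Assumed: bcast(x, root, comm=...) -> (result, token).
       result, _ = mpi4jax.bcast(arr, root=0, comm=comm)
       return result


   # Only rank 0 holds meaningful data; the other ranks pass same-shaped placeholders.
   local = jnp.arange(4.0) if rank == 0 else jnp.zeros(4)
   shared = share_from_root(local)
   print(rank, shared)

Run it with e.g. ``mpirun -n 4 python bcast_example.py`` (file name hypothetical). Every rank should print rank 0's values, ``[0. 1. 2. 3.]``.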

Community guidelines
--------------------

If you have a question or feature request, or want to report a bug, feel free to `open an issue <https://github.com/mpi4jax/mpi4jax/issues>`_.

We welcome contributions of any kind `through pull requests <https://github.com/mpi4jax/mpi4jax/pulls>`_. For information on running our tests, debugging, and contribution guidelines, please `refer to the corresponding documentation page <https://mpi4jax.readthedocs.org/en/latest/developers.html>`_.

How to cite
-----------

If you use ``mpi4jax`` in your work, please consider citing the following article:

::

  @article{mpi4jax,
    doi = {10.21105/joss.03419},
    url = {https://doi.org/10.21105/joss.03419},
    year = {2021},
    publisher = {The Open Journal},
    volume = {6},
    number = {65},
    pages = {3419},
    author = {Dion Häfner and Filippo Vicentini},
    title = {mpi4jax: Zero-copy MPI communication of JAX arrays},
    journal = {Journal of Open Source Software}
  }

.. |Tests| image:: https://github.com/mpi4jax/mpi4jax/workflows/Tests/badge.svg
   :target: https://github.com/mpi4jax/mpi4jax/actions?query=branch%3Amaster
.. |codecov| image:: https://codecov.io/gh/mpi4jax/mpi4jax/branch/master/graph/badge.svg
   :target: https://codecov.io/gh/mpi4jax/mpi4jax
.. |PyPI Version| image:: https://img.shields.io/pypi/v/mpi4jax
   :target: https://pypi.org/project/mpi4jax/
.. |Conda Version| image:: https://img.shields.io/conda/vn/conda-forge/mpi4jax.svg
   :target: https://anaconda.org/conda-forge/mpi4jax
.. |Documentation Status| image:: https://readthedocs.org/projects/mpi4jax/badge/?version=latest
   :target: https://mpi4jax.readthedocs.io/en/latest/?badge=latest
.. |JOSS paper| image:: https://joss.theoj.org/papers/10.21105/joss.03419/status.svg
   :target: https://doi.org/10.21105/joss.03419

            
