fedjax


*   Name: fedjax
*   Version: 0.0.17
*   Home page: https://github.com/google/fedjax
*   Summary: Federated learning simulation with JAX.
*   Upload time: 2023-07-12 09:37:12
*   Author: FedJAX Team
*   Requires Python: >=3.8
*   License: Apache 2.0
*   Keywords: federated python machine learning

# FedJAX: Federated learning simulation with JAX

[![Build and minimal test](https://github.com/google/fedjax/actions/workflows/build_and_minimal_test.yml/badge.svg)](https://github.com/google/fedjax/actions/workflows/build_and_minimal_test.yml)
[![Documentation Status](https://readthedocs.org/projects/fedjax/badge/?version=latest)](https://fedjax.readthedocs.io/en/latest/?badge=latest)
![PyPI version](https://img.shields.io/pypi/v/fedjax)

[**Documentation**](https://fedjax.readthedocs.io/) |
[**Paper**](https://arxiv.org/abs/2108.02117)

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in
the early stages and the API will likely continue to change.

## What is FedJAX?

FedJAX is a [JAX]-based open-source library for
[Federated Learning](https://ai.googleblog.com/2017/04/federated-learning-collaborative.html)
simulations that emphasizes ease of use in research. With its simple primitives
for implementing federated learning algorithms, prepackaged datasets, models and
algorithms, and fast simulation speed, FedJAX aims to make developing and
evaluating federated algorithms faster and easier for researchers. FedJAX works
on accelerators (GPU and TPU) without much additional effort. Additional details
and benchmarks can be found in our [paper](https://arxiv.org/abs/2108.02117).

## Installation

You will need Python 3.8 or later for version 0.0.17. Please check
[the PyPI page](https://pypi.org/project/fedjax/) for the up-to-date version
requirement.

First, install JAX. For a CPU-only version:

```
pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version
```

For other devices (e.g. GPU), follow
[these instructions](https://github.com/google/jax#installation).
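
To confirm which backend JAX actually picked up after installation, a short check like the sketch below can help. It relies only on `jax.devices()` and `jax.default_backend()`; the exact device names printed depend on your machine and JAX version.

```python
import jax

# List the devices JAX can see. A CPU-only install typically reports one CPU device.
devices = jax.devices()
# Name of the backend JAX will use by default for computations.
backend = jax.default_backend()
print(f"backend = {backend}, devices = {devices}")
```

If an accelerator is installed but the backend is reported as CPU, the JAX installation instructions linked above are the place to look.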

Then, install FedJAX from PyPI:

```
pip install fedjax
```

Or, to install the latest development version directly from GitHub:

```
pip install --upgrade git+https://github.com/google/fedjax.git
```
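
After installing, a quick sanity check of the environment might look like the sketch below. It uses only the Python standard library. The helper name `installed_version` is made up for illustration, and the 3.8 minimum is the requirement listed for version 0.0.17, so other releases may differ.

```python
import sys
from importlib import metadata


def installed_version(package):
  """Returns the installed version string of `package`, or None if absent."""
  try:
    return metadata.version(package)
  except metadata.PackageNotFoundError:
    return None


# Python 3.8 is the minimum listed for fedjax 0.0.17 (an assumption for other releases).
python_ok = sys.version_info >= (3, 8)
print(f"Python >= 3.8: {python_ok}")
for package in ("jax", "jaxlib", "fedjax"):
  print(f"{package}: {installed_version(package)}")
```

A value of `None` for any package means it is not installed in the current environment.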

## Getting Started

Below is a simple example to verify FedJAX is installed correctly.

```python
import fedjax
import jax
import jax.numpy as jnp
import numpy as np

# {'client_id': client_dataset}.
fd = fedjax.InMemoryFederatedData({
    'a': {
        'x': np.array([1.0, 2.0, 3.0]),
        'y': np.array([2.0, 4.0, 6.0]),
    },
    'b': {
        'x': np.array([4.0]),
        'y': np.array([12.0])
    }
})
# Initial model parameters.
params = jnp.array(0.5)
# Mean squared error.
mse_loss = lambda params, batch: jnp.mean(
    (jnp.dot(batch['x'], params) - batch['y'])**2)
# Loss for clients 'a' and 'b'.
print(f"client a loss = {mse_loss(params, fd.get_client('a').all_examples())}")
print(f"client b loss = {mse_loss(params, fd.get_client('b').all_examples())}")
```
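
To go one step beyond computing losses, the sketch below runs a single round of a federated-averaging-style update on the same toy data. It is an illustration only, not FedJAX's own federated averaging implementation: it uses plain dicts in place of `fedjax.InMemoryFederatedData`, and the helpers `client_update` and `federated_round` and the learning rate of 0.01 are hypothetical choices made for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the federated data above: plain dicts keyed by client id.
clients = {
    'a': {'x': np.array([1.0, 2.0, 3.0]), 'y': np.array([2.0, 4.0, 6.0])},
    'b': {'x': np.array([4.0]), 'y': np.array([12.0])},
}


def mse_loss(params, batch):
  """Mean squared error of the scalar linear model y = params * x."""
  return jnp.mean((jnp.dot(batch['x'], params) - batch['y'])**2)


# Gradient of the loss with respect to params (the first argument).
grad_fn = jax.grad(mse_loss)


def client_update(params, batch, learning_rate=0.01):
  """One local gradient descent step on a single client's data."""
  return params - learning_rate * grad_fn(params, batch)


def federated_round(params, clients):
  """Averages client models, weighting each by its number of examples."""
  total = sum(len(batch['x']) for batch in clients.values())
  new_params = 0.0
  for batch in clients.values():
    weight = len(batch['x']) / total
    new_params += weight * client_update(params, batch)
  return new_params


params = jnp.array(0.5)
new_params = federated_round(params, clients)
print(f"params after one round = {new_params}")
```

With these toy values the averaged parameter moves from 0.5 to roughly 0.805, which lowers the loss on both clients. Weighting by example count means client 'a', with three examples, contributes three times as much to the average as client 'b'.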

The following tutorial notebooks provide an introduction to FedJAX:

*   [Federated datasets](https://fedjax.readthedocs.io/en/latest/notebooks/dataset_tutorial.html)
*   [Working with models in FedJAX](https://fedjax.readthedocs.io/en/latest/notebooks/model_tutorial.html)
*   [Federated learning algorithms](https://fedjax.readthedocs.io/en/latest/notebooks/algorithms_tutorial.html)

You can also take a look at some of our working examples:

*   [Federated Averaging](examples/fed_avg.py)
*   [Full EMNIST example](examples/emnist_fed_avg.py)


## Citing FedJAX

To cite this repository:

```
@article{fedjax2021,
  title={{F}ed{JAX}: Federated learning simulation with {JAX}},
  author={Jae Hun Ro and Ananda Theertha Suresh and Ke Wu},
  journal={arXiv preprint arXiv:2108.02117},
  year={2021}
}
```

## Useful pointers

*   [JAX documentation](https://jax.readthedocs.io/en/latest/index.html)
*   [Common gotchas in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
*   [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
*   [Haiku documentation](https://dm-haiku.readthedocs.io/en/latest/)

[JAX]: https://github.com/google/jax
[Haiku]: https://github.com/deepmind/dm-haiku
[Stax]: https://github.com/google/jax/blob/main/jax/example_libraries/stax.py
[Optax]: https://github.com/deepmind/optax

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/google/fedjax",
    "name": "fedjax",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.8",
    "maintainer_email": "",
    "keywords": "federated python machine learning",
    "author": "FedJAX Team",
    "author_email": "no-reply@google.com",
    "download_url": "https://files.pythonhosted.org/packages/cf/64/c7c7929d9bdec871d6fe637e583b015ffe03ed5c6613c294488d7d47eeb3/fedjax-0.0.17.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "Apache 2.0",
    "summary": "Federated learning simulation with JAX.",
    "version": "0.0.17",
    "project_urls": {
        "Homepage": "https://github.com/google/fedjax"
    },
    "split_keywords": [
        "federated",
        "python",
        "machine",
        "learning"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "81ba6c5195fbbe38d6bd92936f18fc85fbaa0042c1e8b6a0424ccc30e82472fb",
                "md5": "e5a88248ecabb68643e70bca8431987f",
                "sha256": "42be1d21a57843ccdf3f2af802fe6f8fcdf8530da9b2812b178d18af8ac8c0a1"
            },
            "downloads": -1,
            "filename": "fedjax-0.0.17-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "e5a88248ecabb68643e70bca8431987f",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.8",
            "size": 616764,
            "upload_time": "2023-07-12T09:37:10",
            "upload_time_iso_8601": "2023-07-12T09:37:10.202307Z",
            "url": "https://files.pythonhosted.org/packages/81/ba/6c5195fbbe38d6bd92936f18fc85fbaa0042c1e8b6a0424ccc30e82472fb/fedjax-0.0.17-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "cf64c7c7929d9bdec871d6fe637e583b015ffe03ed5c6613c294488d7d47eeb3",
                "md5": "ad5489c22462405f9dda99af803b4b2f",
                "sha256": "8eab7a82b41b02095e804e50cb09c676edc1170affe4b881b34e320fac4e7b0c"
            },
            "downloads": -1,
            "filename": "fedjax-0.0.17.tar.gz",
            "has_sig": false,
            "md5_digest": "ad5489c22462405f9dda99af803b4b2f",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.8",
            "size": 133618,
            "upload_time": "2023-07-12T09:37:12",
            "upload_time_iso_8601": "2023-07-12T09:37:12.328544Z",
            "url": "https://files.pythonhosted.org/packages/cf/64/c7c7929d9bdec871d6fe637e583b015ffe03ed5c6613c294488d7d47eeb3/fedjax-0.0.17.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2023-07-12 09:37:12",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "google",
    "github_project": "fedjax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "fedjax"
}
        