jax-dataloader

- Name: jax-dataloader
- Version: 0.1.0
- Home page: https://github.com/birkhoffg/jax-dataloader
- Summary: Dataloader for jax
- Upload time: 2024-02-15 18:35:12
- Author: BirkhoffG
- Requires Python: >=3.8
- License: Apache Software License 2.0
- Keywords: python jax dataloader pytorch tensorflow datasets huggingface

# Dataloader for JAX


![Python](https://img.shields.io/pypi/pyversions/jax-dataloader.svg)
![CI
status](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/nbdev.yaml/badge.svg)
![Docs](https://github.com/BirkhoffG/jax-dataloader/actions/workflows/deploy.yaml/badge.svg)
![pypi](https://img.shields.io/pypi/v/jax-dataloader.svg) ![GitHub
License](https://img.shields.io/github/license/BirkhoffG/jax-dataloader.svg)
<a href="https://static.pepy.tech/badge/jax-dataloader"><img src="https://static.pepy.tech/badge/jax-dataloader" alt="Downloads"></a>

## Overview

`jax_dataloader` brings a *pytorch-like* dataloader API to `jax`. It
supports:

- **4 dataset types for downloading and pre-processing data**:
  - [jax dataset](https://birkhoffg.github.io/jax-dataloader/dataset/)
  - [huggingface datasets](https://github.com/huggingface/datasets)
  - [pytorch
    Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
  - [tensorflow dataset](https://www.tensorflow.org/datasets)
- **3 backends for iteratively loading batches**:
  - [jax
    dataloader](https://birkhoffg.github.io/jax-dataloader/core.html#jax-dataloader)
  - [pytorch
    dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
  - [tensorflow dataset](https://www.tensorflow.org/datasets)

A minimal `jax-dataloader` example:

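``` python
import jax.numpy as jnp
import jax_dataloader as jdl

# A small in-memory dataset; a pytorch, huggingface, or tensorflow
# dataset can be passed instead (see the compatibility table below)
dataset = jdl.ArrayDataset(jnp.arange(100).reshape(10, 10), jnp.arange(10))

dataloader = jdl.DataLoader(
    dataset,        # the dataset to load batches from
    backend='jax',  # one of 'jax', 'pytorch', or 'tensorflow'
    batch_size=5,   # number of samples per batch
)

batch = next(iter(dataloader))  # fetch the first batch
```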

## Installation

The latest `jax-dataloader` release can be installed directly from PyPI:

``` sh
pip install jax-dataloader
```

or install directly from the repository:

``` sh
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```

<div>

> **Note**
>
> We keep `jax-dataloader`’s dependencies minimal: it only installs
> `jax`-related dependencies and `plum-dispatch` (for backend
> dispatching). If you want to use the `pytorch`, huggingface
> `datasets`, or `tensorflow` integrations, we recommend installing
> those dependencies manually.
>
> You can also run `pip install jax-dataloader[all]` to install
> everything (not recommended).

</div>
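Because these integrations are optional, it can help to check which of them are importable before choosing a dataset type or backend. The sketch below uses only the standard library's `importlib.util.find_spec`; the module names listed are the usual import names of those packages (an assumption, adjust as needed), and `available_integrations` is a hypothetical helper, not part of `jax_dataloader`.

```python
import importlib.util

# Usual import names of the optional integrations (assumed; adjust as needed)
OPTIONAL_MODULES = ("torch", "datasets", "tensorflow", "tensorflow_datasets")


def available_integrations(modules=OPTIONAL_MODULES):
    """Hypothetical helper: map each top-level module name to whether it is importable.

    `find_spec` locates a top-level module without importing it, so the check
    stays cheap even for heavy packages such as TensorFlow.
    """
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in available_integrations().items():
        print(f"{name:>20}: {'installed' if ok else 'missing'}")
```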

## Usage

[`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader)
follows an API similar to that of the pytorch dataloader.

- The `dataset` should be an instance of a subclass of
  `jax_dataloader.core.Dataset`, `torch.utils.data.Dataset`, (the
  huggingface) `datasets.Dataset`, or `tf.data.Dataset`.
- The `backend` should be one of `"jax"`, `"pytorch"`, or
  `"tensorflow"`. This argument selects which backend dataloader is
  used to load batches.

Note that not every dataset is compatible with every backend. See the
compatibility table below (rows are backends, columns are dataset types):

| Backend        | `jdl.Dataset` | `torch.utils.data.Dataset` | `tf.data.Dataset` | `datasets.Dataset` |
|:---------------|:--------------|:---------------------------|:------------------|:-------------------|
| `"jax"`        | ✅            | ❌                         | ❌                | ✅                 |
| `"pytorch"`    | ✅            | ✅                         | ❌                | ✅                 |
| `"tensorflow"` | ✅            | ❌                         | ✅                | ✅                 |
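The compatibility table can also be expressed as a small lookup structure, which is handy for validating a `(backend, dataset type)` pair before building a dataloader. This is a hypothetical sketch: the string labels mirror the type names used in the Usage bullets and are not identifiers exported by `jax_dataloader`.

```python
# Compatibility table encoded as backend -> set of supported dataset types.
# The string labels are chosen for this sketch, not names from jax_dataloader.
COMPATIBILITY = {
    "jax": {"jdl.Dataset", "datasets.Dataset"},
    "pytorch": {"jdl.Dataset", "torch.utils.data.Dataset", "datasets.Dataset"},
    "tensorflow": {"jdl.Dataset", "tf.data.Dataset", "datasets.Dataset"},
}


def is_compatible(backend: str, dataset_type: str) -> bool:
    """Hypothetical helper: True if the table marks this pair with a check mark."""
    if backend not in COMPATIBILITY:
        raise ValueError(
            f"unknown backend {backend!r}; expected one of {sorted(COMPATIBILITY)}"
        )
    return dataset_type in COMPATIBILITY[backend]


def backends_for(dataset_type: str) -> list:
    """List every backend (in table order) that can load the given dataset type."""
    return [b for b, types in COMPATIBILITY.items() if dataset_type in types]
```

For example, `backends_for("torch.utils.data.Dataset")` returns `["pytorch"]`, matching the rule that a pytorch Dataset needs the pytorch backend.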

### Using [`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)

The `jax_dataloader.core.ArrayDataset` is an easy way to wrap multiple
`jax.numpy` arrays into one Dataset. For example, we can create an
[`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)
as follows:

``` python
import jax.numpy as jnp
import jax_dataloader as jdl

# Create features `X` (10 samples x 10 features) and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
```
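Conceptually, wrapping several arrays into one dataset means that all arrays share the same leading (sample) dimension, `len` returns that size, and indexing returns the matching slice from each array. The class below is a hypothetical, NumPy-based illustration of that idea; it is not `jdl.ArrayDataset`'s actual implementation.

```python
import numpy as np


class MinimalArrayDataset:
    """Hypothetical stand-in illustrating an array-backed dataset (not jdl.ArrayDataset)."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("at least one array is required")
        n = len(arrays[0])
        # Every array must hold one entry per sample along axis 0
        if any(len(a) != n for a in arrays):
            raise ValueError("all arrays must have the same length along axis 0")
        self.arrays = tuple(np.asarray(a) for a in arrays)

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # Apply the same index (int, slice, or index array) to every array
        return tuple(a[index] for a in self.arrays)


X = np.arange(100).reshape(10, 10)
y = np.arange(10)
ds = MinimalArrayDataset(X, y)
x0, y0 = ds[0]  # first feature row and its label
```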

This `arr_ds` can be loaded by *every* backend.

``` python
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
```
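The `batch_size` and `shuffle` arguments control how samples are grouped into batches. As a rough mental model, the NumPy sketch below shows one common way to implement shuffled mini-batching over aligned arrays. `iterate_batches` and its `drop_last` and `seed` parameters are hypothetical; this is not `jax_dataloader`'s internal code.

```python
import numpy as np


def iterate_batches(arrays, batch_size, shuffle=False, drop_last=False, seed=0):
    """Hypothetical sketch of mini-batching with optional shuffling.

    Yields one tuple per batch, holding the same rows from every array.
    """
    n = len(arrays[0])
    rng = np.random.default_rng(seed)
    # Shuffle sample indices once per pass over the data (one epoch)
    indices = rng.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        if drop_last and len(batch_idx) < batch_size:
            break  # skip the final, incomplete batch
        yield tuple(a[batch_idx] for a in arrays)


X = np.arange(100).reshape(10, 10)
y = np.arange(10)
batches = list(iterate_batches((X, y), batch_size=5, shuffle=True))
```

With 10 samples and `batch_size=5`, this yields two batches of 5 rows. Shuffling changes which rows land in each batch, but because the same index array is applied to `X` and `y`, features and labels stay aligned.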

### Using Huggingface Datasets

The huggingface [datasets](https://github.com/huggingface/datasets)
library is a modern tool for downloading, pre-processing, and sharing
datasets. `jax_dataloader` supports passing huggingface datasets directly.

``` python
from datasets import load_dataset
```

For example, we load the `"squad"` dataset from `datasets`:

``` python
hf_ds = load_dataset("squad")
```

Then, we can use `jax_dataloader` to load batches of `hf_ds`.

``` python
# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
```
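Each row of a huggingface dataset such as `squad` is a dictionary of named fields. A common way to batch such records is to collate a list of dicts into a dict of lists, one entry per field. The standard-library helper below is a hypothetical illustration of that idea; it is not claimed to be the exact batch format `jax_dataloader` returns.

```python
def collate_dicts(records):
    """Hypothetical helper: turn a list of dict records into a dict of lists.

    Assumes every record has the same keys, as rows of a datasets.Dataset do.
    """
    if not records:
        return {}
    keys = records[0].keys()
    return {key: [record[key] for record in records] for key in keys}


# Toy records shaped like question-answering rows (placeholder values)
records = [
    {"question": "question 1", "context": "context 1"},
    {"question": "question 2", "context": "context 2"},
]
batch = collate_dicts(records)
```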

### Using Pytorch Datasets

The [pytorch Dataset](https://pytorch.org/docs/stable/data.html) and its
ecosystem (e.g.,
[torchvision](https://pytorch.org/vision/stable/index.html),
[torchtext](https://pytorch.org/text/stable/index.html),
[torchaudio](https://pytorch.org/audio/stable/index.html)) provide many
built-in datasets. `jax_dataloader` supports passing a pytorch Dataset
directly.

<div>

> **Note**
>
> Unfortunately, a [pytorch
> Dataset](https://pytorch.org/docs/stable/data.html) only works with
> `backend='pytorch'`. See the example below.

</div>

``` python
from torchvision.datasets import MNIST
import numpy as np
```

We load the MNIST dataset from `torchvision`. The `ToNumpy` transform
converts each image to a NumPy array.

``` python
class ToNumpy:
    """Convert an image (e.g. a PIL image) to a float NumPy array."""
    def __call__(self, pic):
        return np.array(pic, dtype=float)
```
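To see what this transform produces without downloading MNIST, the sketch below applies the same `ToNumpy` transform to a nested Python list standing in for a tiny grayscale image (torchvision's MNIST yields PIL images, which `np.array` also accepts). The `fake_pic` values are made up for illustration.

```python
import numpy as np


class ToNumpy:
    """Convert an image-like input to a float NumPy array."""

    def __call__(self, pic):
        # np.array accepts PIL images as well as nested lists
        return np.array(pic, dtype=float)


# A 2x2 nested list stands in for a grayscale image (hypothetical data)
fake_pic = [[0, 128], [255, 64]]
arr = ToNumpy()(fake_pic)
```

Note that `dtype=float` produces `float64` arrays; JAX uses 32-bit floats by default, so `np.float32` may be a reasonable alternative if memory matters.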

``` python
pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)
```

This `pt_ds` can **only** be loaded via the `"pytorch"` backend.

``` python
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
```

### Using Tensorflow Datasets

`jax_dataloader` supports passing [tensorflow
datasets](https://www.tensorflow.org/datasets) directly.

``` python
import tensorflow_datasets as tfds
import tensorflow as tf
```

For instance, we can load the MNIST dataset from `tensorflow_datasets`

``` python
tf_ds = tfds.load('mnist', split='test', as_supervised=True)
```

and then use `jax_dataloader` to iterate over it.

``` python
dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
```
