jax-tqdm

- Name: jax-tqdm
- Version: 0.3.1
- Homepage: https://github.com/jeremiecoullon/jax-tqdm
- Summary: Tqdm progress bar for JAX scans and loops
- Upload time: 2024-10-14 23:39:39
- Author: Jeremie Coullon
- License: MIT
- Requires Python: <4.0,>=3.9
- Keywords: jax, tqdm

# JAX-Tqdm

Add a [tqdm](https://github.com/tqdm/tqdm) progress bar to your JAX scans and loops.

## Installation

Install with pip:

```bash
pip install jax-tqdm
```

## Example Usage

### In `jax.lax.scan`

```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

### In `jax.lax.fori_loop`

```python
from jax_tqdm import loop_tqdm
from jax import lax

n = 10_000

@loop_tqdm(n)
def step(i, val):
    return val + 1

last_number = lax.fori_loop(0, n, step, 0)
```

### Scans & Loops Inside Vmap

For scans and loops inside `jax.vmap`, jax-tqdm can print stacked progress bars
showing the individual progress of each mapped computation. To do this, wrap
the initial value of the loop or scan in a `PBar` class, along with the
index of the progress bar. For example:

```python
from jax_tqdm import PBar, scan_tqdm
import jax

n = 10_000

@scan_tqdm(n)
def step(carry, _):
    return carry + 1, carry + 1

def map_func(i):
    # Wrap the initial value and pass the
    # progress bar index
    init = PBar(id=i, carry=0)
    final_value, _all_numbers = jax.lax.scan(
        step, init, jax.numpy.arange(n)
    )
    return (
        final_value.carry,
        _all_numbers,
    )

last_numbers, all_numbers = jax.vmap(map_func)(jax.numpy.arange(10))
```

The indices of the progress bars should be contiguous integers starting
from 0.

### Print Rate

By default, the progress bar is updated 20 times over the course of the scan/loop
(for performance purposes, see [below](#why-jax-tqdm)). This
update rate can be manually controlled with the `print_rate` keyword argument. For
example:

```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=2)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```

With `print_rate=2`, the progress bar is updated every other step.
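
The arithmetic behind both behaviours can be sketched in plain Python. `update_steps` is a hypothetical helper. It assumes two things that are consistent with the description above but are not code taken from jax-tqdm: a refresh happens whenever `i % print_rate == 0`, and the default `print_rate` is `n // 20`.

```python
def update_steps(n, print_rate=None):
    """Hypothetical helper: the iterations at which the bar would refresh,
    assuming an update whenever i % print_rate == 0."""
    if print_rate is None:
        # Assumed default: roughly 20 updates over the whole run
        print_rate = max(n // 20, 1)
    return [i for i in range(n) if i % print_rate == 0]

print(update_steps(10, print_rate=2))  # [0, 2, 4, 6, 8]
print(len(update_steps(10_000)))       # 20
```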

### Progress Bar Type

You can select the [tqdm](https://github.com/tqdm/tqdm) [submodule](https://github.com/tqdm/tqdm/tree/master?tab=readme-ov-file#submodules) manually with the `tqdm_type` option. The options are `'std'` (a plain text bar), `'notebook'` (a Jupyter widget), or `'auto'` (tqdm chooses between the two based on the environment).

```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=1, tqdm_type='std') # tqdm_type='std' or 'notebook' or 'auto'
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```
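
For reference, each `tqdm_type` string names one of tqdm's own entry points. The sketch below spells out that correspondence using plain tqdm. It is an illustration, not jax-tqdm's source, and `TQDM_CLASSES` is a hypothetical name.

```python
import io

import tqdm.auto
import tqdm.notebook
import tqdm.std

# Illustrative only: the tqdm class that each tqdm_type string refers to
TQDM_CLASSES = {
    "std": tqdm.std.tqdm,            # plain text bar, written to stderr by default
    "notebook": tqdm.notebook.tqdm,  # Jupyter widget bar (needs ipywidgets to display)
    "auto": tqdm.auto.tqdm,          # widget bar inside Jupyter, text bar elsewhere
}

# The text bar works in any environment; write it to a buffer instead of stderr
buf = io.StringIO()
bar = TQDM_CLASSES["std"](total=10, desc="std bar", file=buf)
bar.update(10)
bar.close()
```

`'std'` is the safe choice when output goes to a terminal or a log file. `'notebook'` only renders inside Jupyter.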

### Progress Bar Options

Any additional keyword arguments are passed to the [tqdm](https://github.com/tqdm/tqdm)
progress bar constructor. For example:

```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n, print_rate=1, desc='progress bar', position=0, leave=False)
def step(carry, x):
    return carry + 1, carry + 1

last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
```
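
Because these keyword arguments are ordinary tqdm constructor arguments, the same bar can be built with tqdm directly. This sketch does so without JAX. It advances the bar in chunks of `print_rate`, which is one plausible scheme for periodic host updates and is an assumption here, not a description of jax-tqdm's internals.

```python
import io

from tqdm import tqdm

n = 100
print_rate = 10
buf = io.StringIO()  # capture the bar instead of writing to stderr

# The same keyword arguments as in the decorator above, passed to tqdm directly
bar = tqdm(total=n, desc="progress bar", position=0, leave=False, file=buf)
for i in range(n):
    if (i + 1) % print_rate == 0:
        # Advance by a whole chunk once print_rate iterations have completed
        bar.update(print_rate)
bar.close()
```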

## Why JAX-Tqdm?

JAX functions are [pure](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions),
so side effects such as printing progress inside scans and loops are not supported: a Python
`print` in the loop body runs once, when the function is traced, not at every iteration.
However, the
[debug module](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-debug-callback)
has primitives for calling Python functions on the host from JAX code. These can be used
to update a Python tqdm progress bar periodically during the computation. JAX-tqdm
implements this for JAX scans and loops. To use it, add a decorator to the
body function of your scan or loop.
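
The mechanism described above can be sketched with plain JAX. This is a simplified illustration, not jax-tqdm's implementation. It records which iterations reach the host instead of drawing a bar, and the names `report`, `_do_report` and `_skip` are made up for the example. `jax.debug.callback` sends a value to a Python function on the host, and `lax.cond` limits those calls to every `print_rate`-th iteration.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 100
print_rate = 10
reported = []  # host-side record of the iterations that triggered a callback

def report(i):
    # Runs on the host, outside the compiled loop; a real progress bar
    # would advance itself here instead of recording the index
    reported.append(int(i))

def _do_report(i):
    jax.debug.callback(report, i)

def _skip(i):
    pass

def step(carry, i):
    # Only cross to the host every `print_rate` iterations, so the loop
    # calls back n // print_rate times rather than n times
    lax.cond(i % print_rate == 0, _do_report, _skip, i)
    return carry + 1, carry + 1

final, _ = lax.scan(step, 0, jnp.arange(n))
jax.effects_barrier()  # wait until all pending host callbacks have run
print(sorted(reported))  # [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
```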

Note that because the tqdm progress bar is, by default, only updated 20 times during the scan
or loop, the overhead of these host callbacks is negligible.

The code is explained in more detail in this [blog post](https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/).

## Developers

Dependencies can be installed with [poetry](https://python-poetry.org/) by running

```bash
poetry install
```

### Pre-Commit Hooks

Pre-commit hooks can be installed by running

```bash
pre-commit install
```

Pre-commit checks can then be run using

```bash
task lint
```

### Tests

Tests can be run with

```bash
task test
```

            
