jax-smi


- Name: jax-smi
- Version: 1.0.3
- Home page: https://github.com/ayaka14732/jax-smi
- Summary: JAX Synergistic Memory Inspector
- Upload time: 2023-03-10 04:02:02
- Author: Ayaka Mikazuki
- Requires Python: >=3.8, <4
- Keywords: jax, machine-learning
- Requirements: none recorded

# JAX Synergistic Memory Inspector

![jax-smi demo](https://raw.githubusercontent.com/ayaka14732/jax-smi/main/demo/1.gif)

`jax-smi` inspects the memory usage of a JAX process in real time. It is similar to `nvidia-smi`, but works on multiple platforms: CPU, GPU and TPU.

On TPU platforms, `jax-smi` is the only way to monitor TPU memory usage. On GPU platforms, `jax-smi` is also preferable to `nvidia-smi`, which cannot report the real-time memory usage of JAX processes because JAX, by default, [pre-allocates 90% of the GPU memory](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html).

This project is supported by Cloud TPUs from Google's [TPU Research Cloud](https://sites.research.google/trc/about/) (TRC).

## Installation

Install Go (the `go` command). On Ubuntu, this is usually done with:

```sh
sudo apt-get install golang
```

If you followed [tpu-starter](https://github.com/ayaka14732/tpu-starter) to set up the TPU environment, `go` should already be installed.

Then install `jax-smi` with:

```sh
pip install jax-smi
```

## Usage

In your JAX script:

```python
from jax_smi import initialise_tracking
initialise_tracking()
# some computation...
```

Open a shell and run:

```sh
jax-smi
```

## Approach

A separate thread saves the memory profile to `/dev/shm/memory.prof` once per second using [`jax.profiler.save_device_memory_profile()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.save_device_memory_profile.html).

The memory profile is then inspected with `go tool pprof -tags /dev/shm/memory.prof`.

See <https://twitter.com/ayaka14732/status/1565013139594551296> for more details.
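
The periodic-save step can be sketched roughly as follows. This is an illustration under stated assumptions, not the actual `jax-smi` source: the names `start_tracking`, `save_fn` and `_save_with_jax` are hypothetical, and the saver is injectable only so that the loop can be tried without JAX installed. The only JAX call used is `jax.profiler.save_device_memory_profile()`, as named above.

```python
import threading
from typing import Callable

PROFILE_PATH = "/dev/shm/memory.prof"  # path named in the Approach section


def _save_with_jax(path: str) -> None:
    # Imported lazily so the sketch can be exercised without JAX installed.
    import jax.profiler

    jax.profiler.save_device_memory_profile(path)


def start_tracking(
    interval: float = 1.0,
    path: str = PROFILE_PATH,
    save_fn: Callable[[str], None] = _save_with_jax,
) -> threading.Event:
    """Start a daemon thread that calls save_fn(path) every `interval` seconds.

    Returns an Event; call .set() on it to stop the loop.
    (Hypothetical helper, not part of the jax-smi API.)
    """
    stop = threading.Event()

    def loop() -> None:
        while not stop.is_set():
            save_fn(path)
            # Sleeps for `interval`, but wakes immediately once stop is set.
            stop.wait(interval)

    threading.Thread(target=loop, daemon=True).start()
    return stop
```

Because the thread is a daemon, it does not keep the Python process alive once the main computation finishes.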

## Limitations

Only one process can be traced at a time. If several JAX processes enable tracking, they all write their memory profiles to the same file, which leads to conflicts.
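
One possible way to guard against this in your own script is an advisory lock file, so that only the first process enables tracking. The sketch below is an assumption-laden workaround, not a feature of `jax-smi`: the helper name `try_acquire_tracking_lock` and the lock path are hypothetical, and `fcntl` is available only on Unix-like systems.

```python
import fcntl
from typing import IO, Optional


def try_acquire_tracking_lock(
    lock_path: str = "/dev/shm/memory.prof.lock",  # hypothetical lock file
) -> Optional[IO[str]]:
    """Try to take an exclusive, non-blocking advisory lock.

    Returns the open file object on success (keep a reference to it for as
    long as tracking should stay reserved), or None if another holder
    already has the lock.
    """
    f = open(lock_path, "w")
    try:
        fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except BlockingIOError:
        f.close()
        return None
    return f


# Intended use in a JAX script (requires jax-smi to be installed):
#
#     lock = try_acquire_tracking_lock()
#     if lock is not None:
#         from jax_smi import initialise_tracking
#         initialise_tracking()
```

The lock is released when the file object is closed or the process exits, so a crashed process does not block later runs.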

The `jax-smi` command line tool cannot detect whether the memory profile file is out of date. Even if no JAX process is running, the tool still reads the old memory profile and reports outdated memory usage information.
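
Until the tool handles this itself, a staleness check can be done by hand by comparing the file's modification time with the current time. The following is a minimal sketch; the function name `profile_is_stale` and the 5-second threshold are assumptions chosen for illustration, not part of `jax-smi`.

```python
import os
import time
from typing import Optional


def profile_is_stale(
    path: str = "/dev/shm/memory.prof",
    max_age_seconds: float = 5.0,  # assumed threshold; the profile is written every second
    now: Optional[float] = None,
) -> bool:
    """Return True if the profile is missing or older than max_age_seconds."""
    if now is None:
        now = time.time()
    try:
        mtime = os.path.getmtime(path)
    except FileNotFoundError:
        # No profile at all: treat it the same as an outdated one.
        return True
    return (now - mtime) > max_age_seconds


if __name__ == "__main__":
    if profile_is_stale():
        print("memory.prof is missing or stale; is a tracked JAX process running?")
```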



            

        