dataclass-array

*   Name: dataclass-array
*   Version: 1.5.2
*   Summary: Dataclasses that behave like numpy arrays (with indexing, slicing, vectorization).
*   Upload time: 2024-03-19 15:25:17
*   Requires Python: >=3.11
*   Keywords: dataclass, dataclasses, numpy, jax, tensorflow, array

# Dataclass Array

[![Unittests](https://github.com/google-research/dataclass_array/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-research/dataclass_array/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/dataclass_array.svg)](https://badge.fury.io/py/dataclass_array)
[![Documentation Status](https://readthedocs.org/projects/dataclass-array/badge/?version=latest)](https://dataclass-array.readthedocs.io/en/latest/?badge=latest)


`DataclassArray`s are dataclasses that behave like numpy arrays (they can be
batched, reshaped, sliced, ...), compatible with Jax, TensorFlow, and numpy (with
torch support planned).

This reduces boilerplate and improves readability. See the
[motivating examples](#motivating-examples) section below.

To view an example of dataclass arrays used in practice, see
[visu3d](https://github.com/google-research/visu3d).

## Documentation

### Definition

To create a `dca.DataclassArray`, take a frozen dataclass and:

*   Inherit from `dca.DataclassArray`
*   Annotate the fields with `dataclass_array.typing` to specify the inner shape
    and dtype of the array (see below for static or nested dataclass fields).
    The array types are aliases of
    [`etils.array_types`](https://github.com/google/etils/blob/main/etils/array_types/README.md).

```python
import dataclass_array as dca
from dataclass_array.typing import FloatArray


class Ray(dca.DataclassArray):
  pos: FloatArray['*batch_shape 3']
  dir: FloatArray['*batch_shape 3']
```
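
To make the `'*batch_shape 3'` annotation concrete, here is a minimal sketch of how a batch shape can be derived from field shapes. It is plain Python, not the library's implementation: the helper `infer_batch_shape` and its arguments are hypothetical, and it assumes all fields must share exactly the same batch shape (the library itself may be more lenient).

```python
def infer_batch_shape(field_shapes, inner_shapes):
  """Strips each field's inner shape and returns the shared batch shape."""
  batch_shapes = set()
  for name, shape in field_shapes.items():
    inner = tuple(inner_shapes[name])
    n = len(shape) - len(inner)  # Number of leading batch dimensions
    if n < 0 or tuple(shape[n:]) != inner:
      raise ValueError(f'{name}: shape {shape} does not end with {inner}')
    batch_shapes.add(tuple(shape[:n]))
  if len(batch_shapes) != 1:
    raise ValueError(f'Inconsistent batch shapes: {sorted(batch_shapes)}')
  (batch_shape,) = batch_shapes
  return batch_shape


# Both `Ray` fields are annotated `'*batch_shape 3'`, so the inner shape is (3,)
inner = {'pos': (3,), 'dir': (3,)}
print(infer_batch_shape({'pos': (3, 3), 'dir': (3, 3)}, inner))  # (3,)
print(infer_batch_shape({'pos': (2, 4, 3), 'dir': (2, 4, 3)}, inner))  # (2, 4)
```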

### Usage

Afterwards, the dataclass can be used as a numpy array:

```python
import jax
import jax.numpy as jnp

ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))


ray.shape == (3,)  # 3 rays batched together
ray.pos.shape == (3, 3)  # Individual fields still available

# Numpy slicing/indexing/masking
ray = ray[..., 1:2]
ray = ray[jnp.linalg.norm(ray.dir, axis=-1) > 1e-7]

# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h')  # Native einops support
ray = ray.flatten()

# Stack multiple dataclass arrays together
ray = dca.stack([ray0, ray1, ...])

# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax()  # as_np(), as_tf()
ray.xnp == jax.numpy  # `numpy`, `jax.numpy`, `tf.experimental.numpy`

# Compatible with `jax.tree_util`, `jax.vmap`, ...
ray = jax.tree_util.tree_map(lambda x: x+1, ray)
```
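
As a rough mental model of the indexing and reshaping above, the following sketch applies each operation to the batch dimensions of every field while leaving the inner `3` dimension alone. `PlainRay` is a hypothetical numpy-only stand-in written for illustration; it is not how `dca.DataclassArray` is implemented.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class PlainRay:
  """Hypothetical stand-in for `Ray`: each field has shape (*batch_shape, 3)."""
  pos: np.ndarray
  dir: np.ndarray

  @property
  def shape(self):
    return self.pos.shape[:-1]  # Drop the inner (3,) dimension

  def __getitem__(self, idx):
    idx = idx if isinstance(idx, tuple) else (idx,)
    # With `...`, add a full slice so the inner axis is never indexed
    if any(i is Ellipsis for i in idx):
      idx = idx + (slice(None),)
    return PlainRay(pos=self.pos[idx], dir=self.dir[idx])

  def flatten(self):
    return PlainRay(pos=self.pos.reshape((-1, 3)), dir=self.dir.reshape((-1, 3)))


ray = PlainRay(pos=np.zeros((2, 4, 3)), dir=np.ones((2, 4, 3)))
print(ray.shape)  # (2, 4)
print(ray[0].shape)  # (4,)
print(ray[..., 1:2].shape)  # (2, 1)
print(ray.flatten().shape)  # (8,)
```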

A `DataclassArray` has 2 types of fields:

*   Array fields: Fields batched like numpy arrays, with reshape, slicing, ...
    They can be `xnp.ndarray` or nested `dca.DataclassArray`.
*   Static fields: All other (non-array) fields. They are not modified by
    reshaping, slicing, ... Static fields are also ignored by
    `jax.tree_util.tree_map`.

```python
from typing import Any

import numpy as np


class MyArray(dca.DataclassArray):
  # Array fields
  a: FloatArray['*batch_shape 3']  # Defined by `etils.array_types`
  b: FloatArray['*batch_shape _ _']  # Dynamic shape
  c: Ray  # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
  d: Ray['*batch_shape 6']

  # Array fields explicitly defined
  e: Any = dca.field(shape=(3,), dtype=np.float32)
  f: Any = dca.field(shape=(None, None), dtype=np.float32)  # Dynamic shape
  g: Ray = dca.field(shape=(3,), dtype=Ray)  # Nested DataclassArray

  # Static field (everything not defined as above)
  static0: float
  static1: np.ndarray
```
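
The difference between the two kinds of fields can be sketched with plain numpy. In this illustration (the helper `reshape_batch` and the dict-based layout are assumptions made for the example, not the library's API), array fields have their batch dimensions reshaped while static fields pass through unchanged.

```python
import numpy as np


def reshape_batch(fields, inner_ndims, new_batch_shape):
  """Reshapes the batch dims of array fields; static fields are kept as-is."""
  out = {}
  for name, value in fields.items():
    if name in inner_ndims:  # Array field
      n = inner_ndims[name]  # Number of trailing inner dimensions
      inner_shape = value.shape[value.ndim - n:]
      out[name] = value.reshape(tuple(new_batch_shape) + inner_shape)
    else:  # Static field: not touched by reshaping
      out[name] = value
  return out


fields = {
    'a': np.zeros((2, 3, 3)),  # (*batch_shape, 3)
    'b': np.zeros((2, 3, 4, 5)),  # (*batch_shape, _, _)
    'static0': 1.5,  # Static field
}
flat = reshape_batch(fields, inner_ndims={'a': 1, 'b': 2}, new_batch_shape=(6,))
print(flat['a'].shape)  # (6, 3)
print(flat['b'].shape)  # (6, 4, 5)
print(flat['static0'])  # 1.5
```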

### Vectorization

`@dca.vectorize_method` allows your dataclass methods to automatically support
batching:

1.  Implement the method as if `self.shape == ()`
2.  Decorate the method with `dca.vectorize_method`

```python
class Camera(dca.DataclassArray):
  K: FloatArray['*batch_shape 4 4']
  resolution: tuple[int, int]  # Static field: (h, w)

  @dca.vectorize_method
  def rays(self) -> Ray:
    # Inside `@dca.vectorize_method`, the shape is always guaranteed to be `()`
    assert self.shape == ()
    assert self.K.shape == (4, 4)

    # Compute the ray as if there was only a single camera
    return Ray(pos=..., dir=...)
```

Afterward, we can generate rays for multiple cameras batched together:

```python
cams = Camera(K=K, resolution=(h, w))  # K.shape == (num_cams, 4, 4)
rays = cams.rays()  # Generate the rays for all the cameras

cams.shape == (num_cams,)
rays.shape == (num_cams, h, w)
```
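
Conceptually, a batched call can be thought of as flattening the batch dimensions, applying the single-element method to each item, and restoring the batch shape on the result. The sketch below shows that idea with numpy and a plain Python loop. The wrapper `vectorize_over_batch` is hypothetical and is not how `dca.vectorize_method` is implemented (the library describes itself as closer to `jax.vmap`).

```python
import numpy as np


def vectorize_over_batch(fn, inner_ndim):
  """Wraps `fn` (written for a single element) to accept any batch shape."""

  def wrapped(x):
    batch_shape = x.shape[:x.ndim - inner_ndim]
    inner_shape = x.shape[x.ndim - inner_ndim:]
    flat = x.reshape((-1,) + inner_shape)  # Collapse all batch dims into one
    out = np.stack([fn(item) for item in flat])  # Apply `fn` per element
    return out.reshape(batch_shape + out.shape[1:])  # Restore the batch dims

  return wrapped


def diagonal(K):
  # Written as if there was only a single (4, 4) matrix
  assert K.shape == (4, 4)
  return np.diag(K)


K = np.zeros((5, 2, 4, 4)) + np.eye(4)  # Batch shape (5, 2)
out = vectorize_over_batch(diagonal, inner_ndim=2)(K)
print(out.shape)  # (5, 2, 4)
```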

`@dca.vectorize_method` is similar to `jax.vmap`, but:

*   It only works on `dca.DataclassArray` methods.
*   Instead of vectorizing a single axis, `@dca.vectorize_method` vectorizes
    over `*self.shape` (not just `self.shape[0]`), as if `vmap` were applied to
    `self.flatten()`.
*   When there are multiple arguments, axes with dimension `1` are broadcast.

For example, with `__matmul__(self, x: T) -> T`:

```python
() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
```
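
The shape rules listed above can be written as a small pure-Python function over shape tuples. This is only a sketch of the rules as stated in the table, under the assumption that the leading `len(self_shape)` dimensions of the argument are matched against `self.shape`. The helper `broadcast_batch_shape` is hypothetical and not part of the library.

```python
def broadcast_batch_shape(self_shape, arg_shape):
  """Returns the output shape for `self @ x` following the rules above."""
  n = len(self_shape)
  if len(arg_shape) < n:
    raise ValueError(f'Argument {arg_shape} has fewer than {n} batch dims')
  batch = []
  for a, b in zip(self_shape, arg_shape[:n]):
    if a == b or b == 1:
      batch.append(a)
    elif a == 1:
      batch.append(b)  # Dimension 1 on `self` is broadcast to `b`
    else:
      raise ValueError(f'Incompatible {a} != {b}')
  # The batch dims are followed by the argument's remaining inner dims (*x)
  return tuple(batch) + tuple(arg_shape[n:])


# Using b=4, h=2, w=3 and *x = (5,)
print(broadcast_batch_shape((), (5,)))  # (5,)
print(broadcast_batch_shape((4,), (1, 5)))  # (4, 5)
print(broadcast_batch_shape((1,), (4, 5)))  # (4, 5)
print(broadcast_batch_shape((1, 2, 3), (4, 1, 1, 5)))  # (4, 2, 3, 5)
```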

To test on Colab, see the `visu3d` dataclass
[Colab tutorial](https://colab.research.google.com/github/google-research/visu3d/blob/main/docs/dataclass.ipynb).

## Motivating examples

`dca.DataclassArray` improves readability by simplifying common patterns:

*   Reshaping all fields of a dataclass:

    Before (`rays` is a simple `dataclass`):

    ```python
    num_rays = math.prod(rays.origins.shape[:-1])
    rays = jax.tree_util.tree_map(lambda r: r.reshape((num_rays, -1)), rays)
    ```

    After (`rays` is `DataclassArray`):

    ```python
    rays = rays.flatten()  # (b, h, w) -> (b*h*w,)
    ```

*   Rendering a video:

    Before (`cams: list[Camera]`):

    ```python
    img = cams[0].render(scene)
    imgs = np.stack([cam.render(scene) for cam in cams[::2]])
    imgs = np.stack([cam.render(scene) for cam in cams])
    ```

    After (`cams: Camera` with `cams.shape == (num_cams,)`):

    ```python
    img = cams[0].render(scene)  # Render only the first camera (to debug)
    imgs = cams[::2].render(scene)  # Render every other frame (for quicker iteration)
    imgs = cams.render(scene)  # Render all cameras at once
    ```

## Installation

```sh
pip install dataclass_array
```

*This is not an official Google product*


            
