trainax


* Name: trainax
* Version: 0.0.2
* Summary: Training methodologies for Autoregressive Neural Emulators in JAX built on top of Equinox & Optax.
* Upload time: 2024-10-17 07:29:40
* Author: Felix Koehler
* Requires Python: ~=3.10
* Keywords: jax, sciml, deep-learning, pde, neural operator

<p align="center">
<b>Learning Methodologies for Autoregressive Neural Emulators.</b>
</p>

<p align="center">
<a href="https://pypi.org/project/trainax/">
  <img src="https://img.shields.io/pypi/v/trainax.svg" alt="PyPI">
</a>
<a href="https://github.com/ceyron/trainax/actions/workflows/test.yml">
  <img src="https://github.com/ceyron/trainax/actions/workflows/test.yml/badge.svg" alt="Tests">
</a>
<a href="https://fkoehler.site/trainax/">
  <img src="https://img.shields.io/badge/docs-latest-green" alt="docs-latest">
</a>
<a href="https://github.com/ceyron/trainax/releases">
  <img src="https://img.shields.io/github/v/release/ceyron/trainax?include_prereleases&label=changelog" alt="Changelog">
</a>
<a href="https://github.com/ceyron/trainax/blob/main/LICENSE.txt">
  <img src="https://img.shields.io/badge/license-MIT-blue" alt="License">
</a>
</p>

<p align="center">
  <a href="#installation">Installation</a> •
  <a href="#quickstart">Quickstart</a> •
  <a href="#background">Background</a> •
  <a href="#features">Features</a> •
  <a href="#a-taxonomy-of-training-methodologies">Taxonomy</a> •
  <a href="#license">License</a>
</p>

<p align="center">
    <img src="https://github.com/user-attachments/assets/99e054ea-cd79-4ba9-853a-9d74e26ce35e" width="400">
</p>

Convenience abstractions built on `optax` for training neural networks to
autoregressively emulate time-dependent problems. They take care of trajectory
subsampling and offer a wide range of training methodologies (varying the
unrolling length and optionally including differentiable physics).

## Installation

Install the package from PyPI with pip:
```bash
pip install trainax
```

Requires Python 3.10+ and JAX 0.4.13+. 👉 [JAX install guide](https://jax.readthedocs.io/en/latest/installation.html).

## Documentation

The documentation is available at [fkoehler.site/trainax](https://fkoehler.site/trainax/).

## Quickstart

Train a linear convolution with kernel size 2 (no bias) to become an emulator
for the 1D advection problem.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # pip install optax
import trainax as tx

# CFL number of the reference advection problem (negative here)
CFL = -0.75

# Reference trajectories for 1D advection on a periodic domain
ref_data = tx.sample_data.advection_1d_periodic(
    cfl=CFL,
    key=jax.random.PRNGKey(0),
)

# Linear convolution with kernel size 2, periodic padding, and no bias
linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73)
)

# One supervised trainer per unrolling length (1, 5, and 20 steps)
sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)

# Each trainer returns the trained model and its loss history
sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)

# Stencil of the numerical first-order upwind (FOU) scheme
FOU_STENCIL = jnp.array([1+CFL, -CFL])

# Distance between each learned kernel and the FOU stencil
print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))   # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))   # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017
```

Increasing the number of supervised unrolling steps during training brings the
learned stencil closer to the numerical FOU stencil: the distance drops from
0.033 (1 step) to 0.025 (5 steps) and 0.017 (20 steps).
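
For reference, the FOU (first-order upwind) target can be written out directly. The following is a minimal sketch in plain `jax.numpy`, assuming a periodic domain and a negative CFL number as in the quickstart; `fou_step` is a hypothetical helper for illustration and not part of the `trainax` API.

```python
import jax.numpy as jnp

CFL = -0.75


def fou_step(u, cfl=CFL):
    """One first-order upwind step for advection with negative CFL (periodic)."""
    # For cfl < 0 the upwind neighbor is the one to the right, u[i+1],
    # which jnp.roll(u, -1) provides with periodic wrap-around.
    return (1 + cfl) * u - cfl * jnp.roll(u, -1)


# Weights this sketch places on (u[i], u[i+1])
fou_stencil = jnp.array([1 + CFL, -CFL])

u = jnp.sin(jnp.linspace(0.0, 2 * jnp.pi, 16, endpoint=False))
u_next = fou_step(u)
```

The two weights sum to one, so the update preserves the mean of `u`; at `CFL = -1` it reduces to an exact shift by one grid cell.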

## Background

After the discretization of space and time, the simulation of a time-dependent
partial differential equation amounts to the repeated application of a
simulation operator $\mathcal{P}_h$. Here, we are interested in
imitating/emulating this physical/numerical operator with a neural network
$f_\theta$. This repository is concerned with an abstract implementation of all
ways we can frame a learning problem to inject "knowledge" from $\mathcal{P}_h$
into $f_\theta$.

Assume we have a distribution of initial conditions $\mathcal{Q}$ from which we
sample $S$ initial conditions, $u^{[0]} \sim \mathcal{Q}$. We can store them in
an array of shape $(S, C, *N)$ (with $C$ channels and an arbitrary number of
spatial axes of dimension $N$) and apply $\mathcal{P}_h$ repeatedly for $T$
steps to obtain the training trajectories of shape $(S, T+1, C, *N)$.
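
The data generation described here could look roughly as follows. This is a sketch under stated assumptions, not the `trainax` implementation: `reference_step` is a hypothetical stand-in for $\mathcal{P}_h$ (a first-order upwind advection step), the initial conditions are simply Gaussian noise, and `rollout` is an illustrative helper.

```python
import jax
import jax.numpy as jnp

CFL = -0.75


def reference_step(u):
    # Hypothetical stand-in for P_h: periodic first-order upwind advection
    return (1 + CFL) * u - CFL * jnp.roll(u, -1, axis=-1)


def rollout(u_0, num_steps):
    """Apply the stepper num_steps times; returns shape (num_steps + 1, C, N)."""

    def scan_fn(u, _):
        u_next = reference_step(u)
        return u_next, u_next

    _, later_states = jax.lax.scan(scan_fn, u_0, None, length=num_steps)
    # Prepend the initial condition so the trajectory has T + 1 time levels
    return jnp.concatenate([u_0[None], later_states], axis=0)


S, T, C, N = 4, 10, 1, 32
# Assumption: Gaussian noise plays the role of samples from Q
initial_conditions = jax.random.normal(jax.random.PRNGKey(0), (S, C, N))

# Vectorize the rollout over the S samples
trajectories = jax.vmap(lambda u_0: rollout(u_0, T))(initial_conditions)
print(trajectories.shape)
```

Using `jax.lax.scan` keeps the rollout compact to trace, and `jax.vmap` adds the leading sample axis without an explicit Python loop over initial conditions.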

For a one-step supervised learning task, we substack the training trajectories
into windows of size $2$ and merge the two leftover batch axes to get a data
array of shape $(S \cdot T, 2, C, *N)$ that can be used in a supervised learning
scenario

$$
L(\theta) = \mathbb{E}_{(u^{[0]}, u^{[1]}) \sim \mathcal{Q}} \left[ l\left( f_\theta(u^{[0]}), u^{[1]} \right) \right]
$$

where $l$ is a **time-level loss**. In the simplest case, $l = \text{MSE}$.
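
As a rough illustration of the substacking and the one-step loss, the sketch below uses plain JAX. All names are assumptions made for this example: `substack` is not the `trainax` substacker, `f_theta` is a hypothetical two-parameter linear emulator, and `reference_step` stands in for $\mathcal{P}_h$.

```python
import jax
import jax.numpy as jnp

CFL = -0.75


def reference_step(u):
    # Hypothetical stand-in for P_h: periodic first-order upwind advection
    return (1 + CFL) * u - CFL * jnp.roll(u, -1, axis=-1)


def substack(trajectories, window_size):
    """(S, T+1, C, N) -> (S * num_windows, window_size, C, N)."""
    num_windows = trajectories.shape[1] - window_size + 1
    windows = jnp.stack(
        [trajectories[:, i : i + window_size] for i in range(num_windows)],
        axis=1,
    )  # (S, num_windows, window_size, C, N)
    # Merge the sample axis and the window axis into one batch axis
    return windows.reshape((-1, window_size) + trajectories.shape[2:])


def f_theta(theta, u):
    # Hypothetical emulator: learnable weights on u[i] and u[i+1]
    return theta[0] * u + theta[1] * jnp.roll(u, -1, axis=-1)


def one_step_loss(theta, data):
    # MSE between the one-step prediction and the reference next state
    u_0, u_1 = data[:, 0], data[:, 1]
    return jnp.mean((f_theta(theta, u_0) - u_1) ** 2)


S, T, C, N = 4, 10, 1, 32
states = [jax.random.normal(jax.random.PRNGKey(0), (S, C, N))]
for _ in range(T):
    states.append(reference_step(states[-1]))
trajectories = jnp.stack(states, axis=1)  # (S, T+1, C, N)

data = substack(trajectories, window_size=2)  # (S*T, 2, C, N)
loss, grad = jax.value_and_grad(one_step_loss)(jnp.zeros(2), data)
```

With a window size of 2, each trajectory of $T+1$ states yields $T$ overlapping pairs, which is where the leading $S \cdot T$ comes from.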

`Trainax` supports much more than one-step supervised learning. Among other
options, it can train with unrolled steps, include the reference simulator
$\mathcal{P}_h$ in training, train on residuum conditions instead of resolved
reference states, and cut or modify the gradient flow.
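
To give a feel for two of these options, the sketch below writes an unrolled supervised loss by hand in plain JAX, with an optional cut of the backpropagation through time. It is an illustration under assumptions and does not use the `trainax` API: `f_theta`, `reference_step`, and `unrolled_loss` are hypothetical names, and summing the per-step MSEs is just one possible aggregation.

```python
import jax
import jax.numpy as jnp

CFL = -0.75


def reference_step(u):
    # Hypothetical stand-in for P_h: periodic first-order upwind advection
    return (1 + CFL) * u - CFL * jnp.roll(u, -1, axis=-1)


def f_theta(theta, u):
    # Hypothetical emulator: learnable weights on u[i] and u[i+1]
    return theta[0] * u + theta[1] * jnp.roll(u, -1, axis=-1)


def unrolled_loss(theta, windows, cut_bptt=False):
    """Sum of per-step MSEs over an autoregressive rollout.

    windows has shape (B, R + 1, C, N): an initial state plus R reference states.
    """
    u = windows[:, 0]
    loss = 0.0
    for t in range(1, windows.shape[1]):
        u = f_theta(theta, u)  # the network is fed its own prediction
        loss = loss + jnp.mean((u - windows[:, t]) ** 2)
        if cut_bptt:
            # Block gradients from flowing back through earlier steps
            u = jax.lax.stop_gradient(u)
    return loss


B, R, C, N = 8, 5, 1, 32
states = [jax.random.normal(jax.random.PRNGKey(0), (B, C, N))]
for _ in range(R):
    states.append(reference_step(states[-1]))
windows = jnp.stack(states, axis=1)  # (B, R+1, C, N)

theta = jnp.array([0.5, 0.5])
loss_full, grad_full = jax.value_and_grad(
    lambda th: unrolled_loss(th, windows)
)(theta)
loss_cut, grad_cut = jax.value_and_grad(
    lambda th: unrolled_loss(th, windows, cut_bptt=True)
)(theta)
```

Cutting the gradient leaves the loss value unchanged and only alters the gradient: each step's error is then attributed to the most recent network application alone.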

## Features

* Wide collection of unrolled training methodologies:
  * Supervised
  * Diverted Chain
  * Mix Chain
  * Residuum
* Based on [JAX](https://github.com/google/jax):
  * One of the best Automatic Differentiation engines (forward & reverse)
  * Automatic vectorization
  * Backend-agnostic code (run on CPU, GPU, and TPU)
* Built on top of and compatible with [Equinox](https://github.com/patrick-kidger/equinox)
* Batch-Parallel Training
* Collection of Callbacks
* Composability


<!-- ## A Taxonomy of Training Methodologies

The major axes that need to be chosen are:

* The unrolled length (how often the network is applied autoregressively on the
  input)
* The branch length (how long the reference goes alongside the network; we get
  full supervised if that is as long as the rollout length)
* Whether the physics is resolved (diverted-chain and supervised) or only given
  as a condition (residuum-based loss)

Additional axes are:

* The time level loss (how two states are compared, or a residuum state is reduced)
* The time level weights (if there is network rollout, shall states further away
  from the initial condition be weighted differently (like exponential
  discounting in reinforcement learning))
* If the main chain of network rollout is interleaved with a physics solver (-> mix chain)
* Modifications to the gradient flow:
    * Cutting the backpropagation through time in the main chain (after each
      step, or sparse)
    * Cutting the diverted physics
    * Cutting one or both levels of the inputs to a residuum function.

### Implementation details

There are three levels of hierarchy:

1. The `loss` submodule defines time-level-wise comparisons between two states.
   A state is either a tensor of shape `(num_channels, ...)` (with the ellipsis
   indicating an arbitrary number of spatial dimensions) or a tensor of shape
   `(num_batches, num_channels, ...)`. The time-level loss is implemented for
   the former but can additionally be vectorized and (mean-)aggregated over the
   latter. (In the schematic above, the time-level loss is the green circle.)
2. The `configuration` submodule defines how the neural time stepper $f_\theta$
   (denoted *NN* in the schematic) interacts with the numerical simulator
   $\mathcal{P}_h$. Like the time-level loss, a configuration is a callable
   PyTree that takes the neural stepper and some data when called. What this
   data contains depends on the concrete configuration. For supervised rollout
   training it is the batch of (sub-)trajectories to be considered. Other
   configurations might also require the reference stepper or a residuum
   function based on two consecutive time levels. Each configuration is essentially an
   abstract implementation of the major methodologies (supervised,
   diverted-chain, mix-chain, residuum). The most general diverted chain
   implementation contains supervised and branch-one diverted chain as special
   cases. All configurations allow setting additional constructor arguments to,
   e.g., cut the backpropagation through time (sparsely) or to supply time-level
   weightings (for example to exponentially discount contributions over long
   rollouts).
3. The `training` submodule combines a configuration together with stochastic
   minibatching on a set of reference trajectories. For each configuration,
   there is a corresponding trainer that is essentially syntactic sugar for
   combining the relevant configuration with the `GeneralTrainer` and a
   trajectory substacker.

You can find an overview of predictor learning setups
[here](https://fkoehler.site/predictor-learning-setups/). -->

## Acknowledgements

### Citation

This package was developed as part of the `APEBench paper` (accepted at NeurIPS 2024). We will add the citation here soon.

### Funding

The main author (Felix Koehler) is a PhD student in the group of [Prof. Thuerey at TUM](https://ge.in.tum.de/) and his research is funded by the [Munich Center for Machine Learning](https://mcml.ai/).

### License

MIT, see [here](https://github.com/Ceyron/trainax/blob/main/LICENSE.txt)

---

> [fkoehler.site](https://fkoehler.site/) &nbsp;&middot;&nbsp;
> GitHub [@ceyron](https://github.com/ceyron) &nbsp;&middot;&nbsp;
> X [@felix_m_koehler](https://twitter.com/felix_m_koehler) &nbsp;&middot;&nbsp;
> LinkedIn [Felix Köhler](https://www.linkedin.com/in/felix-koehler)

            
