jax-md

Name: jax-md
Version: 0.2.8
Home page: https://github.com/google/jax-md
Summary: Differentiable, Hardware Accelerated, Molecular Dynamics
Upload time: 2023-08-09 23:18:25
Author: Google
Requires Python: >=2.7
License: Apache 2.0
Requirements: No requirements were recorded.

# JAX, M.D.

### Accelerated, Differentiable, Molecular Dynamics
[**Quickstart**](#getting-started) | [**Reference docs**](https://jax-md.readthedocs.io/en/main/) | [**Paper**](https://arxiv.org/pdf/1912.04232.pdf) | [**NeurIPS 2020**](https://neurips.cc/virtual/2020/public/poster_83d3d4b6c9579515e1679aca8cbc8033.html)

![Build Status](https://github.com/google/jax-md/workflows/Build/badge.svg?branch=main) [![Coverage](https://codecov.io/gh/google/jax-md/branch/main/graph/badge.svg?token=JYQpbNyICv)](https://codecov.io/gh/google/jax-md) [![PyPI](https://img.shields.io/pypi/v/jax-md)](https://pypi.org/project/jax-md/) [![PyPI - License](https://img.shields.io/pypi/l/jax_md)](https://github.com/google/jax-md/blob/main/LICENSE)

Molecular dynamics is a workhorse of modern computational condensed matter physics. It is frequently used to simulate materials to observe how small-scale interactions can give rise to complex large-scale phenomenology. Most molecular dynamics packages (e.g. HOOMD Blue or LAMMPS) are complicated, specialized pieces of code that are many thousands of lines long. They typically involve significant code duplication to allow for running simulations on CPU and GPU. Additionally, large amounts of code are often devoted to taking derivatives of quantities to compute functions of interest (e.g. gradients of energies to compute forces).

However, recent work in machine learning has led to significant software developments that might make it possible to write more concise molecular dynamics simulations that offer a range of benefits. Here we target JAX, which allows us to write Python code that gets compiled to XLA and run on CPU, GPU, or TPU. Moreover, JAX allows us to take derivatives of Python code. Thus, not only is this molecular dynamics simulation automatically hardware accelerated, it is also __end-to-end__ differentiable. This should allow for some interesting experiments that we're excited to explore.

JAX MD is a research project that is currently under development. Expect sharp edges and possibly some API-breaking changes as we continue to support a broader set of simulations. JAX MD is a functional and data-driven library. Data is stored in arrays or tuples of arrays, and functions transform data from one state to another.

### Getting Started

For a video introducing JAX MD along with a [demo](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/talk_demo.ipynb), check out this talk from the Physics meets Machine Learning series:

[![Science Meets ML Talk](https://img.youtube.com/vi/Bkm8tGET7-w/0.jpg)](https://www.youtube.com/watch?v=Bkm8tGET7-w)

To get started playing around with JAX MD, check out the following Colab notebooks, which run on Google Cloud without needing to install anything. For a very simple introduction, I would recommend the Minimization example. For an example showcasing many of JAX MD's features, check out the JAX MD cookbook.

- [JAX MD Cookbook](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/jax_md_cookbook.ipynb)
- [Minimization](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/minimization.ipynb)
- [NVE Simulation](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/nve_simulation.ipynb)
- [NVT Simulation](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/nvt_simulation.ipynb)
- [NPT Simulation](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/npt_simulation.ipynb)
- [NVE with Neighbor Lists](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/nve_neighbor_list.ipynb)
- [Custom Potentials](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/customizing_potentials_cookbook.ipynb)
- [Neural Network Potentials](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/neural_networks.ipynb)
- [Flocking](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/flocking.ipynb)
- [Meta Optimization](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/meta_optimization.ipynb)
- [Swap Monte Carlo (Cargese Summer School)](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/cargese_swap_mc.ipynb)
- [Implicit Differentiation](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/implicit_differentiation.ipynb)
- [Athermal Linear Elasticity](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/athermal_linear_elasticity.ipynb)
- [Smash a Sand Castle](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/sand_castle.ipynb)

You can install JAX MD locally with pip,
```
pip install jax-md --upgrade
```
If you want the latest version, you can install it from head,
```
git clone https://github.com/google/jax-md
pip install -e jax-md
```

# Overview

We now summarize the main components of the library.

## Spaces ([`space.py`](https://jax-md.readthedocs.io/en/main/jax_md.space.html))

In general we must have a way of computing the pairwise distance between atoms. We must also have efficient strategies for moving atoms in some space that may or may not be globally isomorphic to R^N. For example, periodic boundary conditions are commonplace in simulations and must be respected. Spaces are defined as a pair of functions, `(displacement_fn, shift_fn)`. Given two points, `displacement_fn(R_1, R_2)` computes the displacement vector between them. If you would like to compute displacement vectors between all pairs of points in a given `(N, dim)` matrix, the function `space.map_product` appropriately vectorizes `displacement_fn`. It is often useful to define a metric instead of a displacement function, in which case you can use the helper function `space.metric` to convert a displacement function to a metric function. Given a point and a shift, `shift_fn(R, dR)` displaces the point `R` by an amount `dR`.

The following spaces are currently supported:
- [`space.free()`](https://jax-md.readthedocs.io/en/main/jax_md.space.html?highlight=free#jax_md.space.free) specifies a space with free boundary conditions.
- [`space.periodic(box_size)`](https://jax-md.readthedocs.io/en/main/jax_md.space.html?highlight=periodic#jax_md.space.periodic) specifies a space with periodic boundary conditions of side length `box_size`.
- [`space.periodic_general(box)`](https://jax-md.readthedocs.io/en/main/jax_md.space.html?highlight=periodic_general#jax_md.space.periodic_general) specifies a space as a periodic parallelepiped formed by transforming the unit cube by an affine transformation `box`.

Example:

```python
from jax_md import space

# A periodic box with side length 25. displacement_fn computes displacements
# under the minimum-image convention; shift_fn moves points and wraps them
# back into the box.
box_size = 25.0
displacement_fn, shift_fn = space.periodic(box_size)
```
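
To compute all pairwise displacements or distances, `displacement_fn` can be combined with the helpers described above. A minimal sketch, reusing `box_size` and `displacement_fn` from the example and generating some illustrative positions:

```python
from jax import random

# Some random positions inside the periodic box (for illustration only).
key = random.PRNGKey(0)
R = random.uniform(key, (128, 2), maxval=box_size)

# Displacement vectors between all pairs of points; shape (128, 128, 2).
pairwise_displacements = space.map_product(displacement_fn)(R, R)

# Convert the displacement function to a metric and compute all pairwise
# distances; shape (128, 128).
metric_fn = space.metric(displacement_fn)
pairwise_distances = space.map_product(metric_fn)(R, R)
```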

## Potential Energy ([`energy.py`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html))

In the simplest case, molecular dynamics calculations are often based on a pair potential that is defined by a user. This is then used to compute a total energy whose negative gradient gives the forces. One of the very nice things about JAX is that we get forces for free! The second part of the code is devoted to computing energies.

We provide the following classical potentials:
- [`energy.soft_sphere`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=soft_sphere#jax_md.energy.soft_sphere) a soft-sphere potential whose energy increases with the overlap of the spheres raised to some power, `alpha`.
- [`energy.lennard_jones`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=lennard_jones#jax_md.energy.lennard_jones) a standard 12-6 Lennard-Jones potential.
- [`energy.morse`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=morse#jax_md.energy.morse) a Morse potential.
- [`energy.tersoff`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=tersoff#jax_md.energy.tersoff) the Tersoff potential for simulating semiconducting materials. Can load parameters from LAMMPS files.
- [`energy.eam`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=eam#jax_md.energy.eam) an embedded atom model (EAM) potential that can load parameters from LAMMPS files.
- [`energy.stillinger_weber`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=stillinger_weber#jax_md.energy.stillinger_weber) used to model silicon-like systems.
- [`energy.bks`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=bks#jax_md.energy.bks) the van Beest-Kramer-van Santen (BKS) potential, used to model silica.
- [`energy.gupta`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=gupta#jax_md.energy.gupta) used to model gold nanoclusters.

We also provide the following neural network potentials:
- [`energy.behler_parrinello`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=behler_parrinello#jax_md.energy.behler_parrinello) a widely used fixed-feature neural network architecture for molecular systems.
- [`energy.graph_network`](https://jax-md.readthedocs.io/en/main/jax_md.energy.html?highlight=graph_network#jax_md.energy.graph_network) a deep graph neural network designed for energy fitting.

For finite-ranged potentials it is often useful to consider only interactions within a certain neighborhood. We include `_neighbor_list` variants of the above potentials that use a neighbor list (see below) for efficiency; a sketch of this usage follows the example below.

Example:

```python
import jax.numpy as np
from jax import random
from jax_md import energy, quantity

N = 1000
spatial_dimension = 2

# Random particle positions.
key = random.PRNGKey(0)
R = random.uniform(key, (N, spatial_dimension), minval=0.0, maxval=1.0)

# displacement_fn is defined in the Spaces example above.
energy_fn = energy.lennard_jones_pair(displacement_fn)
print('E = {}'.format(energy_fn(R)))

# Forces come for free as the negative gradient of the energy.
force_fn = quantity.force(energy_fn)
print('Total Squared Force = {}'.format(np.sum(force_fn(R) ** 2)))
```
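
As a sketch of the neighbor-list variants mentioned above (the exact signature of `energy.lennard_jones_neighbor_list` is best checked against the reference docs), one first allocates a neighbor list and then passes it to the energy function. This reuses `displacement_fn` and `box_size` from the Spaces example and `key`, `N`, and `spatial_dimension` from above:

```python
# Positions that span the periodic box (for illustration only).
R_box = random.uniform(key, (N, spatial_dimension), maxval=box_size)

# The neighbor-list variant returns a neighbor-list constructor together with
# an energy function that consumes the neighbor list.
neighbor_fn, nbr_energy_fn = energy.lennard_jones_neighbor_list(displacement_fn, box_size)

nbrs = neighbor_fn.allocate(R_box)  # Build the initial neighbor list.
print('E = {}'.format(nbr_energy_fn(R_box, neighbor=nbrs)))
```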

## Dynamics ([`simulate.py`](https://jax-md.readthedocs.io/en/main/jax_md.simulate.html), [`minimize.py`](https://jax-md.readthedocs.io/en/main/jax_md.minimize.html))

Given an energy function and a system, there are a number of dynamics that are useful to simulate. The simulation code is based on the structure of the optimizers found in JAX. In particular, each simulation function returns an initialization function and an update function. The initialization function takes a set of positions and creates the necessary dynamical state variables. The update function applies a single step of dynamics to the dynamical state variables and returns an updated state.

We include several different kinds of dynamics. However, there is certainly room to add more, e.g. constant-strain simulations.

It is often desirable to find an energy minimum of the system, and we provide two methods to do this: simple gradient descent minimization, which is included mostly for pedagogical purposes since it often performs poorly, and the FIRE algorithm, which often converges significantly faster (see the sketch after the example below). Moreover, a common experiment in molecular dynamics is to simulate a system at fixed volume and temperature.

We provide the following dynamics:
- [`simulate.nve`](https://jax-md.readthedocs.io/en/main/jax_md.simulate.html?highlight=nve#jax_md.simulate.nve) Constant energy simulation; numerically integrates Newton's laws directly.
- [`simulate.nvt_nose_hoover`](https://jax-md.readthedocs.io/en/main/jax_md.simulate.html?highlight=nvt_nose_hoover#jax_md.simulate.nvt_nose_hoover) Uses a Nose-Hoover chain to simulate a system at constant temperature.
- [`simulate.npt_nose_hoover`](https://jax-md.readthedocs.io/en/main/jax_md.simulate.html?highlight=npt_nose_hoover#jax_md.simulate.npt_nose_hoover) Uses Nose-Hoover chains to simulate a system at constant pressure and temperature.
- [`simulate.nvt_langevin`](https://jax-md.readthedocs.io/en/main/jax_md.simulate.html?highlight=nvt_langevin#jax_md.simulate.nvt_langevin) Simulates a system by numerically integrating the Langevin stochastic differential equation.
- [`simulate.hybrid_swap_mc`](https://jax-md.readthedocs.io/en/main/jax_md.simulate.html?highlight=hybrid_swap_mc#jax_md.simulate.hybrid_swap_mc) Alternates NVT dynamics with Monte Carlo swapping moves to generate low-energy glasses.
- [`simulate.brownian`](https://jax-md.readthedocs.io/en/main/jax_md.simulate.html?highlight=brownian#jax_md.simulate.brownian) Simulates Brownian motion.
- [`minimize.gradient_descent`](https://jax-md.readthedocs.io/en/main/jax_md.minimize.html?highlight=gradient_descent#jax_md.minimize.gradient_descent) Minimizes a system using gradient descent.
- [`minimize.fire_descent`](https://jax-md.readthedocs.io/en/main/jax_md.minimize.html?highlight=fire_descent#jax_md.minimize.fire_descent) Minimizes a system using the fast inertial relaxation engine.

Example:

```python
from jax_md import simulate

# Reuses energy_fn, shift_fn, key, and R from the examples above.
temperature = 1.0
dt = 1e-3
init, update = simulate.nvt_nose_hoover(energy_fn, shift_fn, dt, temperature)

state = init(key, R)
for _ in range(100):
  state = update(state)
R = state.position
```
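
Minimization follows the same init/apply pattern. A minimal sketch using FIRE, reusing `energy_fn`, `shift_fn`, and `R` from above and wrapping the update step in `jax.jit` for speed (the exact defaults of `minimize.fire_descent` are documented in the reference docs):

```python
from jax import jit
from jax_md import minimize

fire_init, fire_apply = minimize.fire_descent(energy_fn, shift_fn)
fire_apply = jit(fire_apply)  # Compile the update step once; reuse it every iteration.

fire_state = fire_init(R)
for _ in range(200):
  fire_state = fire_apply(fire_state)

print('E after minimization = {}'.format(energy_fn(fire_state.position)))
```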

## Spatial Partitioning ([`partition.py`](https://jax-md.readthedocs.io/en/main/jax_md.partition.html))

In many applications, it is useful to construct spatial partitions of particles / objects in a simulation.

We provide the following methods:
- [`partition.cell_list`](https://jax-md.readthedocs.io/en/main/jax_md.partition.html?highlight=cell_list#jax_md.partition.cell_list) Partitions objects (and metadata) into a grid of cells.
- [`partition.neighbor_list`](https://jax-md.readthedocs.io/en/main/jax_md.partition.html?highlight=neighbor_list#jax_md.partition.neighbor_list) Constructs a set of neighbors within some cutoff distance for each object in a simulation.

Cell List Example:
```python
from jax_md import partition

cell_size = 5.0
capacity = 10
cell_list_fn = partition.cell_list(box_size, cell_size, capacity)
cell_list_data = cell_list_fn.allocate(R)
```

Neighbor List Example:
```python
from jax_md import partition

neighbor_list_fn = partition.neighbor_list(displacement_fn, box_size, cell_size)  # Third argument is the interaction cutoff distance.
neighbors = neighbor_list_fn.allocate(R) # Create a new neighbor list.

# Do some simulating....

neighbors = neighbors.update(R)  # Update the neighbor list without resizing.
if neighbors.did_buffer_overflow:  # Couldn't fit all the neighbors into the list.
  neighbors = neighbor_list_fn.allocate(R)  # So create a new neighbor list.
```

Three different neighbor list formats are supported: `Dense`, `Sparse`, and `OrderedSparse`. `Dense` neighbor lists store neighbors in a `(particle_count, neighbors_per_particle)` array; `Sparse` neighbor lists store neighbors in a `(2, total_neighbors)` array of pairs; `OrderedSparse` neighbor lists are like `Sparse` neighbor lists, but only store pairs such that `i < j`.
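
The format can be selected when constructing a neighbor list. A sketch, assuming the `partition.NeighborListFormat` enum from the reference docs and reusing `displacement_fn`, `box_size`, and `R` from above; `r_cutoff` here is a hypothetical cutoff distance for illustration:

```python
r_cutoff = 2.5  # Hypothetical interaction cutoff distance.
sparse_neighbor_fn = partition.neighbor_list(
    displacement_fn, box_size, r_cutoff,
    format=partition.NeighborListFormat.Sparse)

nbrs = sparse_neighbor_fn.allocate(R)
# In the Sparse format, nbrs.idx has shape (2, max_pairs): sender and receiver
# indices for every neighbor pair.
print(nbrs.idx.shape)
```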

# Development

JAX MD is under active development. We have very limited development resources, so we typically focus on adding features that will have high impact for researchers using JAX MD (including us). Please don't hesitate to open feature requests to help us guide development. We more than welcome contributions!

## Technical gotchas

### GPU

You must follow [JAX's](https://www.github.com/google/jax/) GPU installation instructions to enable GPU support.


### 64-bit precision
To enable 64-bit precision, set the respective JAX flag _before_ importing `jax_md` (see the JAX [guide](https://colab.research.google.com/github/google/jax/blob/main/notebooks/Common_Gotchas_in_JAX.ipynb#scrollTo=YTktlwTTMgFl)), for example:

```python
from jax.config import config
config.update("jax_enable_x64", True)
```
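
On newer JAX versions the `jax.config` import path above may be deprecated; as an alternative (assuming a recent JAX installation), the same flag can be set through the top-level `jax` namespace:

```python
import jax

jax.config.update("jax_enable_x64", True)
```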

# Publications

JAX MD has been used in the following publications. If you don't see your paper on the list but you used JAX MD, let us know and we'll add it!

1. [A Differentiable Neural-Network Force Field for Ionic Liquids. (J. Chem. Inf. Model. 2022)](https://pubs.acs.org/doi/abs/10.1021/acs.jcim.1c01380)<br> H. Montes-Campos, J. Carrete, S. Bichelmaier, L. M. Varela, and G. K. H. Madsen
2. [Correlation Tracking: Using simulations to interpolate highly correlated particle tracks. (Phys. Rev. E. 2022)](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.105.044608?ft=1)<br> E. M. King, Z. Wang, D. A. Weitz, F. Spaepen, and M. P. Brenner
3. [Optimal Control of Nonequilibrium Systems Through Automatic Differentiation.](https://arxiv.org/abs/2201.00098)<br> M. C. Engel, J. A. Smith, and M. P. Brenner
4. [Graph Neural Networks Accelerated Molecular Dynamics. (J. Chem. Phys. 2022)](https://aip.scitation.org/doi/10.1063/5.0083060)<br> Z. Li, K. Meidani, P. Yadav, and A. B. Farimani
5. [Gradients are Not All You Need.](https://arxiv.org/abs/2111.05803)<br> L. Metz, C. D. Freeman, S. S. Schoenholz, and T. Kachman
6. [Lagrangian Neural Network with Differential Symmetries and Relational Inductive Bias.](https://arxiv.org/abs/2110.03266)<br> R. Bhattoo, S. Ranu, and N. M. A. Krishnan
7. [Efficient and Modular Implicit Differentiation.](https://arxiv.org/abs/2105.15183)<br> M. Blondel, Q. Berthet, M. Cuturi, R. Frostig, S. Hoyer, F. Llinares-López, F. Pedregosa, and J.-P. Vert
8. [Learning neural network potentials from experimental data via Differentiable Trajectory Reweighting.<br>(Nature Communications 2021)](https://www.nature.com/articles/s41467-021-27241-4)<br> S. Thaler and J. Zavadlav
9. [Learn2Hop: Learned Optimization on Rough Landscapes. (ICML 2021)](http://proceedings.mlr.press/v139/merchant21a.html)<br> A. Merchant, L. Metz, S. S. Schoenholz, and E. D. Cubuk
10. [Designing self-assembling kinetics with differentiable statistical physics models. (PNAS 2021)](https://www.pnas.org/content/118/10/e2024083118.short)<br> C. P. Goodrich, E. M. King, S. S. Schoenholz, E. D. Cubuk, and  M. P. Brenner

# Citation

If you use the code in a publication, please cite the repo using the following BibTeX entry:

```
@inproceedings{jaxmd2020,
 author = {Schoenholz, Samuel S. and Cubuk, Ekin D.},
 booktitle = {Advances in Neural Information Processing Systems},
 publisher = {Curran Associates, Inc.},
 title = {JAX M.D. A Framework for Differentiable Physics},
 url = {https://papers.nips.cc/paper/2020/file/83d3d4b6c9579515e1679aca8cbc8033-Paper.pdf},
 volume = {33},
 year = {2020}
}
```

            
