jax-fem


- Name: jax-fem
- Version: 0.0.4
- Summary: GPU accelerated Finite element analysis package in JAX.
- Upload time: 2024-04-20 08:12:16
- Author: Tianju Xue
- Requires Python: >=3.9
- License: GPL-3.0 License
- Keywords: jax, gpu, python, finite element analysis, differentiable programming

A GPU-accelerated differentiable finite element analysis package based on [JAX](https://github.com/google/jax). It was formerly part of [JAX-AM](https://github.com/tianjuxue/jax-am), a suite of open-source Python packages for Additive Manufacturing (AM) research.

## Finite Element Method (FEM)
![Github Star](https://img.shields.io/github/stars/deepmodeling/jax-fem)
![Github Fork](https://img.shields.io/github/forks/deepmodeling/jax-fem)
![License](https://img.shields.io/github/license/deepmodeling/jax-fem)

FEM is a powerful tool. This package supports the following features:

- 2D quadrilateral/triangle elements
- 3D hexahedron/tetrahedron elements
- First and second order elements
- Dirichlet/Neumann/Robin boundary conditions
- Linear and nonlinear analysis including
  - Heat equation
  - Linear elasticity
  - Hyperelasticity
  - Plasticity (macro and crystal plasticity)
- Differentiable simulation for solving inverse/design problems __without__ deriving sensitivities by hand, e.g.,
  - Topology optimization
  - Optimal thermal control
- Integration with PETSc for solver choices
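
The differentiable-simulation item above can be illustrated without jax-fem itself. The following is a minimal sketch, assuming a 1D bar with one conductivity value per element; `solve_bar` and `compliance` are hypothetical names introduced here and are not part of the jax-fem API. It uses `jax.grad` to obtain the sensitivity of an objective to the material parameters, with no hand-derived adjoint.

```python
import jax
import jax.numpy as jnp

# Double precision makes the gradient easy to compare against finite differences.
jax.config.update("jax_enable_x64", True)

def solve_bar(k):
    """Solve -(k u')' = 1 on (0, 1) with u(0) = u(1) = 0 using linear elements.

    k holds one conductivity value per element (hypothetical toy setup);
    the return value holds the solution at the interior nodes.
    """
    n = k.shape[0]                       # number of elements
    h = 1.0 / n                          # uniform element size
    main = (k[:-1] + k[1:]) / h          # diagonal of the stiffness matrix
    off = -k[1:-1] / h                   # coupling between neighbouring nodes
    K = jnp.diag(main) + jnp.diag(off, 1) + jnp.diag(off, -1)
    f = jnp.full(n - 1, h)               # load vector for a unit source term
    return jnp.linalg.solve(K, f)

def compliance(k):
    """Scalar objective J(k) = f^T u(k)."""
    h = 1.0 / k.shape[0]
    return h * jnp.sum(solve_bar(k))

k0 = jnp.ones(8)
dJ_dk = jax.grad(compliance)(k0)         # sensitivities via automatic differentiation
print(dJ_dk)
```

Each entry of `dJ_dk` is negative in this setup: raising the conductivity of any element lowers the objective. A design loop such as topology optimization would feed these gradients to an optimizer.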

**Updates** (Dec 11, 2023):

- Multi-physics problems are now supported: multiple variables can be solved monolithically. For an example, run `python -m applications.stokes.example`.
- The weak form is now defined through volume integrals and surface integrals. Body forces, the "mass kernel" and the "Laplace kernel" are treated uniformly through the volume integral, and "Neumann B.C." and "Robin B.C." are treated uniformly through the surface integral.
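
As an illustration of the volume/surface split, here is one way to write the residual form for a scalar unknown. The symbols `f`, `h`, and `g` are our own notation, not taken from the package documentation:

```latex
% Residual form for a scalar unknown u and test function v (our notation):
\int_{\Omega} f(\nabla u) \cdot \nabla v \, dx    % "Laplace kernel" (volume integral)
+ \int_{\Omega} h(u, x) \, v \, dx                % "mass kernel" or body force (volume integral)
+ \int_{\Gamma} g(u, x) \, v \, ds                % Neumann or Robin term (surface integral)
= 0 \quad \text{for all admissible test functions } v .
```

A body force corresponds to an `h` that depends only on `x`, a Neumann condition to a `g` that depends only on `x`, and a Robin condition to a `g` that also depends on `u`. Read this way, `get_tensor_map`, `get_mass_map`, and `get_surface_maps` in the Poisson `example.py` under Option 2 would supply `f`, `h`, and `g` respectively. That mapping is our reading of the example, which is consistent with the negative signs on its source and flux terms.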

<p align="middle">
  <img src="docs/ded.gif" width="600" />
</p>
<p align="middle">
    <em >Thermal profile in direct energy deposition.</em>
</p>

<p align="middle">
  <img src="docs/von_mises.png" width="400" />
</p>
<p align="middle">
    <em >Linear static analysis of a bracket.</em>
</p>

<p align="middle">
  <img src="docs/polycrystal_grain.gif" width="360" />
  <img src="docs/polycrystal_stress.gif" width="360" />
</p>
<p align="middle">
    <em >Crystal plasticity: grain structure (left) and stress-xx (right).</em>
</p>

<p align="middle">
  <img src="docs/stokes_u.png" width="360" />
  <img src="docs/stokes_p.png" width="360" />
</p>
<p align="middle">
    <em >Stokes flow: velocity (left) and pressure (right).</em>
</p>

<p align="middle">
  <img src="docs/to.gif" width="600" />
</p>
<p align="middle">
    <em >Topology optimization with differentiable simulation.</em>
</p>

## Installation

Create a conda environment from the given [`environment.yml`](https://github.com/deepmodeling/jax-fem/blob/main/environment.yml) file and activate it:

```bash
conda env create -f environment.yml
conda activate jax-fem-env
```

Install JAX:
- See the JAX installation [instructions](https://jax.readthedocs.io/en/latest/installation.html#). Depending on your hardware, you may install the CPU or GPU version of JAX. Both work, but the GPU version usually gives better performance.
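
After installing JAX, it can help to confirm which backend it actually picked up. A minimal check, using only the public `jax.default_backend()` and `jax.devices()` calls:

```python
import jax

backend = jax.default_backend()   # platform name, e.g. "cpu" or "gpu"
devices = jax.devices()           # devices visible to the default backend
print(f"JAX {jax.__version__} is using backend '{backend}' with devices: {devices}")
```

If this reports a CPU backend on a machine with a GPU, the GPU build of JAX is probably not the one installed.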


Then there are two options to continue:

### Option 1

Clone the repository:

```bash
git clone https://github.com/deepmodeling/jax-fem.git
cd jax-fem
```

and install the package locally:

```bash
pip install -e .
```

**Quick tests**: You can check `demos/` for a variety of FEM cases. For example, run

```bash
python -m demos.hyperelasticity.example
```

for hyperelasticity. 

Also, 

```bash
python -m tests.benchmarks
```

will execute a set of test cases.


### Option 2

Install the package from the [PyPI release](https://pypi.org/project/jax-fem/) directly:

```bash
pip install jax-fem
```

**Quick tests**: Save the Python script shown after the following command as `example.py`, then run it:

```bash
python example.py
```

```python
import jax
import jax.numpy as np
import os

from jax_fem.problem import Problem
from jax_fem.solver import solver
from jax_fem.utils import save_sol
from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh

class Poisson(Problem):
    def get_tensor_map(self):
        return lambda x: x

    def get_mass_map(self):
        def mass_map(u, x):
            val = -np.array([10*np.exp(-(np.power(x[0] - 0.5, 2) + np.power(x[1] - 0.5, 2)) / 0.02)])
            return val
        return mass_map

    def get_surface_maps(self):
        def surface_map(u, x):
            return -np.array([np.sin(5.*x[0])])

        return [surface_map, surface_map]

ele_type = 'QUAD4'
cell_type = get_meshio_cell_type(ele_type)
Lx, Ly = 1., 1.
meshio_mesh = rectangle_mesh(Nx=32, Ny=32, domain_x=Lx, domain_y=Ly)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])

def left(point):
    return np.isclose(point[0], 0., atol=1e-5)

def right(point):
    return np.isclose(point[0], Lx, atol=1e-5)

def bottom(point):
    return np.isclose(point[1], 0., atol=1e-5)

def top(point):
    return np.isclose(point[1], Ly, atol=1e-5)

def dirichlet_val_left(point):
    return 0.

def dirichlet_val_right(point):
    return 0.

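# Dirichlet B.C.: u = 0 on the left and right edges.
location_fns = [left, right]
value_fns = [dirichlet_val_left, dirichlet_val_right]
vecs = [0, 0]  # solution component constrained by each condition
dirichlet_bc_info = [location_fns, vecs, value_fns]

# Edges on which the surface maps from get_surface_maps are applied.
location_fns = [bottom, top]

problem = Poisson(mesh=mesh, vec=1, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns)
# Solve the linear problem through the PETSc integration.
sol = solver(problem, linear=True, use_petsc=True)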

# Write the solution to data/vtk/u.vtu next to this script.
data_dir = os.path.join(os.path.dirname(__file__), 'data')
vtk_path = os.path.join(data_dir, 'vtk/u.vtu')
save_sol(problem.fes[0], sol[0], vtk_path)
```
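
The `dirichlet_bc_info` triple in the script pairs each location function with a component index and a value function. As a rough illustration of that pattern only (not jax-fem's internal implementation), the plain-Python sketch below applies such a triple to a small grid of points. `collect_dirichlet_dofs`, `zero`, and the 3 x 3 grid are hypothetical stand-ins introduced here.

```python
import math

def left(point):
    return math.isclose(point[0], 0.0, abs_tol=1e-5)

def right(point):
    return math.isclose(point[0], 1.0, abs_tol=1e-5)

def zero(point):
    return 0.0

def collect_dirichlet_dofs(points, bc_info):
    """Return {(node_index, component): value} for every node a location function selects."""
    location_fns, vecs, value_fns = bc_info
    constrained = {}
    for loc, vec, val in zip(location_fns, vecs, value_fns):
        for i, p in enumerate(points):
            if loc(p):
                constrained[(i, vec)] = val(p)
    return constrained

# 3 x 3 grid of nodes on the unit square (hypothetical stand-in for a mesh).
points = [(i / 2, j / 2) for j in range(3) for i in range(3)]
bcs = collect_dirichlet_dofs(points, [[left, right], [0, 0], [zero, zero]])
print(sorted(bcs))
```

The three lists are matched by position, so the n-th location function, component index, and value function together describe one boundary condition.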


## License

This project is licensed under the GNU General Public License v3 - see the [LICENSE](https://www.gnu.org/licenses/) for details.

## Citations

If you find this library useful in academic or industry work, please consider supporting it by 1) starring the project on GitHub, and 2) citing the relevant paper:

```bibtex
@article{xue2023jax,
  title={JAX-FEM: A differentiable GPU-accelerated 3D finite element solver for automatic inverse design and mechanistic data science},
  author={Xue, Tianju and Liao, Shuheng and Gan, Zhengtao and Park, Chanwook and Xie, Xiaoyu and Liu, Wing Kam and Cao, Jian},
  journal={Computer Physics Communications},
  pages={108802},
  year={2023},
  publisher={Elsevier}
}
```

            

Raw data

            {
    "_id": null,
    "home_page": null,
    "name": "jax-fem",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": null,
    "keywords": "JAX, GPU, Python, Finite element analysis, Differentiable programming",
    "author": "Tianju Xue",
    "author_email": "cetxue@ust.hk",
    "download_url": "https://files.pythonhosted.org/packages/d0/af/0345128bb180c974b7a1f107f0bf22b9f39c2565122b5d210811ebe59fdd/jax_fem-0.0.4.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "GPL-3.0 License",
    "summary": "GPU accelerated Finite element analysis package in JAX.",
    "version": "0.0.4",
    "project_urls": null,
    "split_keywords": [
        "jax",
        " gpu",
        " python",
        " finite element analysis",
        " differentiable programming"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "aedda76cf30c3bfafb1b17a1600fbd21d4e909c6b273d0c5d77033fb98d62255",
                "md5": "c5a6d19378adbb2f654e45ca027943c8",
                "sha256": "799156faf4ac69a37728c916b51cf357a375b533ef2327bc42abc7d917e8410c"
            },
            "downloads": -1,
            "filename": "jax_fem-0.0.4-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "c5a6d19378adbb2f654e45ca027943c8",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 34939,
            "upload_time": "2024-04-20T08:12:13",
            "upload_time_iso_8601": "2024-04-20T08:12:13.574600Z",
            "url": "https://files.pythonhosted.org/packages/ae/dd/a76cf30c3bfafb1b17a1600fbd21d4e909c6b273d0c5d77033fb98d62255/jax_fem-0.0.4-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "d0af0345128bb180c974b7a1f107f0bf22b9f39c2565122b5d210811ebe59fdd",
                "md5": "4fee89306710eff9d820d5b8c5e93c30",
                "sha256": "763a8477b662ac1a13cc02a0af07c484f14a1e72512e56a0c61ed98fbdb59ed1"
            },
            "downloads": -1,
            "filename": "jax_fem-0.0.4.tar.gz",
            "has_sig": false,
            "md5_digest": "4fee89306710eff9d820d5b8c5e93c30",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 47739,
            "upload_time": "2024-04-20T08:12:16",
            "upload_time_iso_8601": "2024-04-20T08:12:16.444306Z",
            "url": "https://files.pythonhosted.org/packages/d0/af/0345128bb180c974b7a1f107f0bf22b9f39c2565122b5d210811ebe59fdd/jax_fem-0.0.4.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-04-20 08:12:16",
    "github": false,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "lcname": "jax-fem"
}
        