# tridiax
`tridiax` implements solvers for tridiagonal systems in JAX. All solvers support CPU and GPU, are compatible with `jit` compilation, and can be differentiated with `grad`.
### Implemented solvers
- [Thomas algorithm](http://www.industrial-maths.com/ms6021_thomas.pdf)
- [Divide and conquer](https://courses.engr.illinois.edu/cs554/fa2013/notes/09_tridiagonal_8up.pdf)
- [Stone's algorithm](https://dl.acm.org/doi/pdf/10.1145/321738.321741)
Generally, the Thomas algorithm is faster on CPU, whereas the divide and conquer
algorithm and Stone's algorithm are faster on GPU.
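For orientation, the Thomas algorithm is a forward elimination sweep followed by back substitution. The plain-Python sketch below is a reference illustration only, not the `tridiax` implementation. The function name `thomas_reference` is hypothetical, and the argument order simply mirrors the `thomas_solve(lower, diag, upper, solve)` call shown under Usage. It assumes that `lower[i]` sits at row `i + 1`, column `i`, that `upper[i]` sits at row `i`, column `i + 1`, and that no pivoting is needed.

```python
def thomas_reference(lower, diag, upper, rhs):
    """Solve a tridiagonal system with the Thomas algorithm (no pivoting)."""
    n = len(diag)
    c = [0.0] * n  # modified superdiagonal
    d = [0.0] * n  # modified right-hand side

    # Forward sweep: eliminate the subdiagonal row by row.
    c[0] = upper[0] / diag[0] if n > 1 else 0.0
    d[0] = rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - lower[i - 1] * c[i - 1]
        if i < n - 1:
            c[i] = upper[i] / denom
        d[i] = (rhs[i] - lower[i - 1] * d[i - 1]) / denom

    # Back substitution: recover the unknowns from last to first.
    x = [0.0] * n
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


# 3x3 example: [[2, 1, 0], [1, 2, 1], [0, 1, 2]] @ x = [4, 8, 8]
x = thomas_reference([1.0, 1.0], [2.0, 2.0, 2.0], [1.0, 1.0], [4.0, 8.0, 8.0])
```

The two sweeps are sequential in `i`, which is consistent with this method suiting CPUs, while the other two algorithms expose more parallelism.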
### Known limitations
Currently, the `divide_conquer` solver only supports systems whose dimensionality is a power of `2`.
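One possible workaround, sketched here under assumptions, is to pad a system up to the next power of two with decoupled identity rows (diagonal `1`, off-diagonals `0`, right-hand side `0`). Mathematically, the first `dim` entries of the padded solution equal the solution of the original system. The helper names `next_power_of_two` and `pad_to_power_of_two` are hypothetical and not part of `tridiax`, and whether the padded system behaves well with `divide_conquer_solve` has not been verified here.

```python
def next_power_of_two(n):
    """Return the smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def pad_to_power_of_two(lower, diag, upper, rhs):
    """Append decoupled identity rows until the dimension is a power of two."""
    n = len(diag)
    extra = next_power_of_two(n) - n
    return (
        list(lower) + [0.0] * extra,  # no coupling into the padded rows
        list(diag) + [1.0] * extra,   # identity on the padded rows
        list(upper) + [0.0] * extra,  # no coupling out of the padded rows
        list(rhs) + [0.0] * extra,    # padded unknowns solve to zero
    )


# A system of dimension 5 is padded to dimension 8.
lower_p, diag_p, upper_p, rhs_p = pad_to_power_of_two(
    [1.0] * 4, [4.0] * 5, [1.0] * 4, [1.0] * 5
)
```

The padded lists would still need to be converted with `jnp.asarray` before being passed to a solver, and the solution truncated back to the original dimension.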
### Usage
```python
import jax.numpy as jnp
import numpy as np
from tridiax import thomas_solve, divide_conquer_solve, stone_solve

dim = 1024
diag = jnp.asarray(np.random.randn(dim))  # main diagonal
upper = jnp.asarray(np.random.randn(dim - 1))  # superdiagonal
lower = jnp.asarray(np.random.randn(dim - 1))  # subdiagonal
solve = jnp.asarray(np.random.randn(dim))  # right-hand side
solution = thomas_solve(lower, diag, upper, solve)
```
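A quick way to sanity-check any returned solution is to multiply it back through the tridiagonal matrix and compare with the right-hand side. The sketch below does this in plain Python on lists. The helper `tridiag_residual` is hypothetical and not part of `tridiax`, and it assumes the same diagonal layout as the example above (`lower[i]` at row `i + 1`, `upper[i]` at row `i`).

```python
def tridiag_residual(lower, diag, upper, rhs, x):
    """Return the largest absolute entry of A @ x - rhs for tridiagonal A."""
    n = len(diag)
    worst = 0.0
    for i in range(n):
        row = diag[i] * x[i]
        if i > 0:
            row += lower[i - 1] * x[i - 1]  # subdiagonal contribution
        if i < n - 1:
            row += upper[i] * x[i + 1]  # superdiagonal contribution
        worst = max(worst, abs(row - rhs[i]))
    return worst


# [[2, 1, 0], [1, 2, 1], [0, 1, 2]] @ [1, 2, 3] = [4, 8, 8]
res = tridiag_residual(
    [1.0, 1.0], [2.0, 2.0, 2.0], [1.0, 1.0], [4.0, 8.0, 8.0], [1.0, 2.0, 3.0]
)
```

A residual that is large relative to the magnitude of the right-hand side suggests an ill-conditioned system or mismatched argument order.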
If you solve many systems of the same size with the divide and conquer algorithm, it helps to precompute the reordering indices once and reuse them:
```python
import jax.numpy as jnp
import numpy as np
from tridiax import divide_conquer_solve, divide_conquer_index

dim = 1024  # must be a power of 2 for this solver
diag = jnp.asarray(np.random.randn(dim))
upper = jnp.asarray(np.random.randn(dim - 1))
lower = jnp.asarray(np.random.randn(dim - 1))
solve = jnp.asarray(np.random.randn(dim))

# Compute the reordering indices once; reuse them for every system of this size.
indexing = divide_conquer_index(dim)
solution = divide_conquer_solve(lower, diag, upper, solve, indexing=indexing)
```
### Installation
`tridiax` is available on [`pypi`](https://pypi.org/project/tridiax/):
```sh
pip install tridiax
```
This installs `tridiax` with CPU support. For GPU support, additionally install `JAX` with GPU support by following the instructions in the [`JAX` GitHub repository](https://github.com/google/jax). For example, for NVIDIA GPUs, run
```sh
pip install -U "jax[cuda12]"
```