# traceax
``traceax`` is a Python library to perform stochastic trace estimation for linear operators. Namely,
given a square linear operator <i>`A`</i>, ``traceax`` provides flexible routines that estimate,
$$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$
using only matrix-vector products. ``traceax`` is heavily inspired by
[lineax](https://github.com/patrick-kidger/lineax) as well as
[XTrace](https://github.com/eepperly/XTrace).
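The core idea behind matrix-vector-only trace estimation can be sketched in a few lines of plain JAX. The block below is a minimal illustration of the classic Hutchinson estimator. It is not `traceax`'s implementation, and `hutchinson_sketch` and its `matvec` argument are hypothetical names chosen for this sketch. The matrix is accessed only through `matvec`. The sketch also reports a simple standard error over the probes, which is an assumption of this example and need not match how `traceax` computes any statistics it returns.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_sketch(key, matvec, n, k):
    """Hypothetical plain-JAX Hutchinson estimate of trace(A) from k matvecs.

    `matvec` is any function v -> A @ v; A itself is never formed.
    Returns (estimate, standard error of the mean over the k probes).
    """
    # Rademacher probes: entries are -1 or +1 with equal probability,
    # so E[z z^T] = I and therefore E[z^T A z] = trace(A).
    Z = jnp.where(rdm.bernoulli(key, 0.5, (k, n)), 1.0, -1.0)
    AZ = jax.vmap(matvec)(Z)           # k matrix-vector products, one per row
    samples = jnp.sum(Z * AZ, axis=1)  # z_i^T A z_i for each probe
    est = jnp.mean(samples)
    stderr = jnp.std(samples, ddof=1) / jnp.sqrt(k)
    return est, stderr
```

For a diagonal matrix every Rademacher probe returns the exact trace, so the standard error is zero. For a general symmetric matrix the per-probe variance is $2\sum_{i \neq j} \mathbf{A}_{ij}^2$, which is what variance-reduced estimators such as Hutch++ and XTrace improve on.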
[**Installation**](#installation)
| [**Example**](#get-started-with-example)
| [**Documentation**](#documentation)
| [**Citation**](#citation)
| [**Notes**](#notes)
| [**Support**](#support)
| [**Other Software**](#other-software)
------------------
## Installation
Users can install `traceax` directly from PyPI with `pip`:
``` bash
pip install traceax
```
Alternatively, users can clone the latest version of the repository and install it with `pip`:
``` bash
git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .
```
## Get Started with Example
```python
import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tx

# simulate a simple symmetric matrix with exponentially decaying eigenvalues
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)

X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)  # Q is a random orthogonal matrix
U = jnp.power(0.7, jnp.arange(N))  # eigenvalues 0.7^i
A = (Q * U) @ Q.T  # A = Q diag(U) Q^T

# these should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# set up the linear operator
operator = lx.MatrixLinearOperator(A)

# number of matrix-vector products
k = 25

# split the key for each estimator
key, key1, key2, key3, key4 = rdm.split(key, 5)

# Hutchinson estimator; samples Rademacher {-1, +1} probes by default
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.4099615, dtype=float32), {})

# Hutch++ estimator; samples Rademacher {-1, +1} probes by default
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.3033807, dtype=float32), {})

# XTrace estimator; samples probes uniformly on the sphere by default
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.3271673, dtype=float32), {'std.err': Array(0.01717775, dtype=float32)})

# XNysTrace estimator; improved performance for PSD/NSD operators,
# so first tag the operator as positive semidefinite
operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag)
nt = tx.XNysTraceEstimator()
print(nt.estimate(key4, operator, k))  # (Array(3.3297246, dtype=float32), {'std.err': Array(0.00042093, dtype=float32)})
```
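Hutch++ spends part of its matrix-vector budget on a low-rank sketch, computes the trace of that part exactly, and applies Hutchinson only to the remainder. This pays off for matrices like the one above, where a few leading eigenvalues ($0.7^i$) carry most of the trace. The block below is a minimal dense-matrix sketch under stated assumptions: Rademacher probes and an even three-way split of the `k` products, as in the original Hutch++ description. `hutchpp_sketch` is a hypothetical name and this is not `traceax`'s internal code.

```python
import jax.numpy as jnp
import jax.random as rdm


def hutchpp_sketch(key, A, k):
    """Hypothetical plain-JAX Hutch++ estimate of trace(A) using k products with A."""
    n = A.shape[0]
    s = k // 3         # columns in the low-rank sketch
    g = k - 2 * s      # probes left over for the residual
    skey, gkey = rdm.split(key)
    S = jnp.where(rdm.bernoulli(skey, 0.5, (n, s)), 1.0, -1.0)
    G = jnp.where(rdm.bernoulli(gkey, 0.5, (n, g)), 1.0, -1.0)
    # Orthonormal basis for the range of A @ S (s products with A).
    Q, _ = jnp.linalg.qr(A @ S)
    # Exact trace of A restricted to span(Q) (s more products).
    low_rank = jnp.trace(Q.T @ (A @ Q))
    # Hutchinson on the part of A orthogonal to span(Q) (g products).
    G_perp = G - Q @ (Q.T @ G)
    residual = jnp.trace(G_perp.T @ (A @ G_perp)) / g
    return low_rank + residual
```

If the rank of `A` is at most `k // 3`, the sketch captures the whole range of `A`, the residual term vanishes, and the estimate equals the trace up to rounding.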
## Documentation
Documentation is available [here](https://mancusolab.github.io/traceax/).
## Citation
If you use `traceax` in your work, please cite:
> Nahid, A.A., Serafin, L., Mancuso, N. (2025). <i>traceax</i>: a JAX-based framework for stochastic trace estimation. bioRxiv (https://doi.org/10.1101/2025.07.14.662216)
## Notes
- `traceax` uses [JAX](https://github.com/google/jax) with [just-in-time
(JIT)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)
compilation to achieve high-speed computation. However, there are
some [known issues](https://github.com/google/jax/issues/5501) with JAX
on Apple M1 chips. To work around them, users need to set up conda using
[miniforge](https://github.com/conda-forge/miniforge) and then
install `traceax` with `pip` in the desired environment.
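Because JAX compiles a function for fixed array shapes, the number of probes sets the shape of the probe matrix and has to be a static argument when an estimator is wrapped in `jax.jit`. The block below is a minimal plain-JAX sketch of that pattern. The name `hutchinson_jit` is hypothetical and unrelated to `traceax`'s API; each new value of `k` triggers a fresh compilation, while repeated calls with the same `k` reuse the compiled code.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as rdm


@partial(jax.jit, static_argnums=2)
def hutchinson_jit(key, A, k):
    # k fixes the shape of Z, so it must be known at compile time (static).
    # A.shape[0] is also static, since JAX traces with concrete shapes.
    Z = jnp.where(rdm.bernoulli(key, 0.5, (A.shape[0], k)), 1.0, -1.0)
    # Sum of z_i^T A z_i over the k probe columns, averaged.
    return jnp.sum(Z * (A @ Z)) / k
```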
## Support
Please report any bugs or feature requests in the [Issue
Tracker](https://github.com/mancusolab/traceax/issues). If users have
any questions or comments, please contact Abdullah Al Nahid (<alnahid@usc.edu>) or
Nicholas Mancuso (<nmancuso@usc.edu>).
## Other Software
Feel free to use other software developed by [Mancuso
Lab](https://www.mancusolab.com/):
- [SuShiE](https://github.com/mancusolab/sushie): a Bayesian
fine-mapping framework for molecular QTL data across multiple
ancestries.
- [MA-FOCUS](https://github.com/mancusolab/ma-focus): a Bayesian
fine-mapping framework using
[TWAS](https://www.nature.com/articles/ng.3506) statistics across
multiple ancestries to identify the causal genes for complex traits.
- [SuSiE-PCA](https://github.com/mancusolab/susiepca): a scalable
Bayesian variable selection technique for sparse principal component
analysis.
- [twas_sim](https://github.com/mancusolab/twas_sim): Python
software to simulate [TWAS](https://www.nature.com/articles/ng.3506)
statistics.
- [FactorGo](https://github.com/mancusolab/factorgo): a scalable
variational factor analysis model that learns pleiotropic factors
from GWAS summary statistics.
- [HAMSTA](https://github.com/tszfungc/hamsta): Python software to
estimate heritability explained by local ancestry data from
admixture mapping summary statistics.
------------------------------------------------------------------------
``traceax`` is distributed under the terms of the
[Apache-2.0 license](https://spdx.org/licenses/Apache-2.0.html).
------------------------------------------------------------------------
This project has been set up using Hatch. For details and usage
information on Hatch see <https://github.com/pypa/hatch>.
Raw data
{
"_id": null,
"home_page": null,
"name": "traceax",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": null,
"keywords": "jax, machine-learning, statistics, trace-estimation",
"author": null,
"author_email": "Abdullah Al Nahid <alnahid@usc.edu>, Linda Serafin <lserafin@usc.edu>, Nicholas Mancuso <nmancuso@usc.edu>",
"download_url": "https://files.pythonhosted.org/packages/5b/4a/da55f878a0a59e87459343bf5449809706ae9d10ccb5400b940aba4494c7/traceax-1.0.2.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "Apache2.0",
"summary": "Stochastic trace estimation in JAX, Lineax, and Equinox",
"version": "1.0.2",
"project_urls": {
"Documentation": "https://github.com/mancusolab/traceax#readme",
"Issues": "https://github.com/mancusolab/traceax/issues",
"Source": "https://github.com/mancusolab/traceax"
},
"split_keywords": [
"jax",
" machine-learning",
" statistics",
" trace-estimation"
],
"urls": [
{
"comment_text": null,
"digests": {
"blake2b_256": "1f24a4c64549a549b872f87cbb4fa435cfa1c7b9f4c9f25f77939db81808eddb",
"md5": "72a4f60a8ae11c7b9bfd603d68f66340",
"sha256": "b11ac99b8f8fd5f7103d2e3ffb6cf213d7ea1742cbb181a52569ec2ac8161cca"
},
"downloads": -1,
"filename": "traceax-1.0.2-py3-none-any.whl",
"has_sig": false,
"md5_digest": "72a4f60a8ae11c7b9bfd603d68f66340",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 14299,
"upload_time": "2025-07-24T04:40:20",
"upload_time_iso_8601": "2025-07-24T04:40:20.997128Z",
"url": "https://files.pythonhosted.org/packages/1f/24/a4c64549a549b872f87cbb4fa435cfa1c7b9f4c9f25f77939db81808eddb/traceax-1.0.2-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": null,
"digests": {
"blake2b_256": "5b4ada55f878a0a59e87459343bf5449809706ae9d10ccb5400b940aba4494c7",
"md5": "322cb7cff8852b9901f3e6db296ce6d6",
"sha256": "6fa0f319dcea7e2560773055e519b8bf587156ec240243f0a703ce02d3a1c03a"
},
"downloads": -1,
"filename": "traceax-1.0.2.tar.gz",
"has_sig": false,
"md5_digest": "322cb7cff8852b9901f3e6db296ce6d6",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 30365,
"upload_time": "2025-07-24T04:40:21",
"upload_time_iso_8601": "2025-07-24T04:40:21.980084Z",
"url": "https://files.pythonhosted.org/packages/5b/4a/da55f878a0a59e87459343bf5449809706ae9d10ccb5400b940aba4494c7/traceax-1.0.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2025-07-24 04:40:21",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "mancusolab",
"github_project": "traceax#readme",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"lcname": "traceax"
}