<p align="center">
<img src="samplex_logo.png" alt="samplex Logo" width="50%" />
</p>
# samplex
A package of useful sampling algorithms written in MLX. We plan to explore how unified memory (exploiting the GPU and CPU together) and automatic differentiation can be combined to achieve highly efficient and robust sampling locally on your Mac.
Please get in touch if you're interested in contributing (tedwards2412@gmail.com and nash.sabti@gmail.com)!
# Installation
```bash
pip install samplex
```
# Basic Usage
For a full example, please see the examples folder. Here is the basic structure for fitting a quadratic model (a regression that is linear in its parameters):
```python
import mlx.core as mx

from samplex.samplex import samplex
from samplex.samplers import MH_Gaussian_sampler

# Example "true" parameters used to generate mock data
m_true, c_true, b_true = 2.0, -1.0, 0.5

# First, let's generate some noisy data from a quadratic
x = mx.linspace(-5, 5, 20)
sigma = mx.ones_like(x)  # known noise level (standard deviation) of each data point
y = b_true * x**2 + m_true * x + c_true + sigma * mx.random.normal(x.shape)


# Our model is a quadratic in x; the log target is a Gaussian log-likelihood
def log_target_distribution(theta, data):
    m, c, b = theta
    x, y, sigma = data
    model = b * x**2 + m * x + c
    residual = y - model
    return mx.sum(-0.5 * (residual**2 / sigma**2))


# The sampler expects a target distribution that takes a single parameter vector theta
logtarget = lambda theta: log_target_distribution(theta, (x, y, sigma))

# Sampler settings
Nwalkers = 32
Ndim = 3
Nsteps = 10_000
cov_matrix = mx.array([0.01, 0.01, 0.01])  # proposal covariance, one value per parameter
jumping_factor = 1.0

# Draw the initial walker positions uniformly within illustrative bounds
m_min, c_min, b_min = -10.0, -10.0, -10.0
m_max, c_max, b_max = 10.0, 10.0, 10.0
theta0_array = mx.random.uniform(
    mx.array([m_min, c_min, b_min]),
    mx.array([m_max, c_max, b_max]),
    (Nwalkers, Ndim),
)

# First instantiate the sampler and wrap it in a samplex object, then run!
sampler = MH_Gaussian_sampler(logtarget)
sam = samplex(sampler, Nwalkers)
sam.run(Nsteps, theta0_array, cov_matrix, jumping_factor)
```
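Judging by its name, `MH_Gaussian_sampler` performs Metropolis-Hastings with a Gaussian proposal. The sketch below shows what one random-walk Metropolis-Hastings chain of that kind does, written in plain Python so it runs without MLX. It is not samplex's implementation: `metropolis_hastings` is a hypothetical helper, and treating `cov_matrix` as per-parameter proposal variances and `jumping_factor` as a multiplier on the proposal standard deviations are assumptions.

```python
import math
import random


def metropolis_hastings(log_target, theta0, proposal_variances, n_steps,
                        jumping_factor=1.0, seed=None):
    """Random-walk Metropolis-Hastings with an axis-aligned Gaussian proposal.

    Hypothetical helper for illustration only (not samplex's API).
    proposal_variances: one variance per parameter (assumed meaning of cov_matrix).
    jumping_factor: multiplies the proposal standard deviations (assumption).
    Returns the chain (a list of parameter lists) and the acceptance rate.
    """
    rng = random.Random(seed)
    scales = [jumping_factor * math.sqrt(v) for v in proposal_variances]
    theta = list(theta0)
    logp = log_target(theta)
    chain, n_accepted = [], 0
    for _ in range(n_steps):
        # Symmetric Gaussian jump around the current position
        proposal = [t + s * rng.gauss(0.0, 1.0) for t, s in zip(theta, scales)]
        logp_new = log_target(proposal)
        # The proposal is symmetric, so accept with probability min(1, p_new / p_old);
        # 1 - rng.random() lies in (0, 1], so the log is always defined
        if math.log(1.0 - rng.random()) < logp_new - logp:
            theta, logp = proposal, logp_new
            n_accepted += 1
        chain.append(list(theta))
    return chain, n_accepted / n_steps


if __name__ == "__main__":
    # Sample a 1D standard normal: log p(theta) = -theta^2 / 2 + const
    chain, acc = metropolis_hastings(lambda th: -0.5 * th[0] ** 2, [0.0], [1.0], 20_000, seed=0)
    samples = [s[0] for s in chain[2_000:]]  # discard burn-in
    print(f"acceptance rate: {acc:.2f}, sample mean: {sum(samples) / len(samples):.3f}")
```

Because the proposal is symmetric, the Hastings correction cancels and only the ratio of target densities is needed. For this 1D example (proposal standard deviation 1 on a unit normal target), theory predicts an acceptance rate of about 0.70. Larger proposal scales lower the acceptance rate, and smaller ones make the chain move more slowly.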
# Next Steps
- Get NUTS/HMC running
- Get an ensemble sampler running (as in emcee)
- Refine plotting
- Add helper functions for a variety of priors
- Treat parameters with different update speeds
- Add a file of priors and include them in the target distribution
- Include an autocorrelation calculation for the chain steps
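For the autocorrelation item, one common choice (used, for example, by emcee) is the integrated autocorrelation time. It is defined as τ(M) = 1 + 2 Σ_{t=1}^{M} ρ(t), with the sum truncated at the first window M satisfying M ≥ c·τ(M), where typically c = 5 (Sokal's automatic windowing). The following plain-Python sketch is illustrative only. `integrated_autocorr_time` is a hypothetical helper, not part of samplex, and it uses a direct sum rather than the FFT that a production version would likely use.

```python
import random


def integrated_autocorr_time(chain, c=5.0):
    """Integrated autocorrelation time of a 1D chain with Sokal's automatic window.

    Hypothetical helper (not part of samplex). Computes
    tau(M) = 1 + 2 * sum_{t=1}^{M} rho(t) and stops at the first M with M >= c * tau(M).
    """
    n = len(chain)
    mean = sum(chain) / n
    centered = [v - mean for v in chain]
    var = sum(d * d for d in centered) / n
    if var == 0.0:
        raise ValueError("chain has zero variance")
    tau = 1.0
    for lag in range(1, n):
        # Normalised autocorrelation at this lag (biased estimator, divides by n)
        rho = sum(centered[i] * centered[i + lag] for i in range(n - lag)) / (n * var)
        tau += 2.0 * rho
        if lag >= c * tau:
            return tau
    # Window criterion never met: the chain is probably too short to trust tau
    return tau


if __name__ == "__main__":
    # AR(1) process x_t = phi * x_{t-1} + noise has tau = (1 + phi) / (1 - phi) = 3 for phi = 0.5
    rng = random.Random(1)
    phi, x = 0.5, [0.0]
    for _ in range(4_999):
        x.append(phi * x[-1] + rng.gauss(0.0, 1.0))
    print(f"estimated tau: {integrated_autocorr_time(x):.2f} (exact: 3)")
```

Dividing the number of post-burn-in steps by τ gives a rough effective sample size for the chain.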
Raw data
{
"_id": null,
"home_page": "https://github.com/tedwards2412/samplex",
"name": "samplex",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.9",
"maintainer_email": "",
"keywords": "sampling,inference,machine learning,autodiff,mlx",
"author": "Thomas Edwards, Nashwan Sabti",
"author_email": "tedwards2412@gmail.com",
"download_url": "https://files.pythonhosted.org/packages/1d/6e/a4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c/samplex-0.0.2.tar.gz",
"platform": null,
"bugtrack_url": null,
"license": "MIT",
"summary": "Samplers in MLX",
"version": "0.0.2",
"project_urls": {
"Homepage": "https://github.com/tedwards2412/samplex"
},
"split_keywords": [
"sampling",
"inference",
"machine learning",
"autodiff",
"mlx"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "87e1d2c426aea56ea9a170e4434b0c3a8cb05360e4ff54e4324ae983a705f8e6",
"md5": "aaa555bb8852bf2bd5e2cc119cbf51b1",
"sha256": "9247912cacd545bef373ab9f380dfec619da15d2323333bf66f6cfd6a7f48b4d"
},
"downloads": -1,
"filename": "samplex-0.0.2-py3-none-any.whl",
"has_sig": false,
"md5_digest": "aaa555bb8852bf2bd5e2cc119cbf51b1",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3.9",
"size": 9102,
"upload_time": "2024-02-02T21:44:40",
"upload_time_iso_8601": "2024-02-02T21:44:40.625802Z",
"url": "https://files.pythonhosted.org/packages/87/e1/d2c426aea56ea9a170e4434b0c3a8cb05360e4ff54e4324ae983a705f8e6/samplex-0.0.2-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "1d6ea4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c",
"md5": "b95197f26491ec3eaca55dda997d66a1",
"sha256": "903112a33dbcab71d3f1c37a3cb81a4dd023aa126f35e43ca78a9660e2e7281c"
},
"downloads": -1,
"filename": "samplex-0.0.2.tar.gz",
"has_sig": false,
"md5_digest": "b95197f26491ec3eaca55dda997d66a1",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.9",
"size": 9396,
"upload_time": "2024-02-02T21:44:41",
"upload_time_iso_8601": "2024-02-02T21:44:41.686542Z",
"url": "https://files.pythonhosted.org/packages/1d/6e/a4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c/samplex-0.0.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-02-02 21:44:41",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "tedwards2412",
"github_project": "samplex",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "samplex"
}