samplex


- Name: samplex
- Version: 0.0.2
- Home page: https://github.com/tedwards2412/samplex
- Summary: Samplers in MLX
- Upload time: 2024-02-02 21:44:41
- Docs URL: None
- Author: Thomas Edwards, Nashwan Sabti
- Requires Python: >=3.9
- License: MIT
- Keywords: sampling, inference, machine learning, autodiff, mlx
- Requirements: No requirements were recorded.
- Travis-CI: No Travis.
- Coveralls test coverage: No coveralls.

<p align="center">
  <img src="samplex_logo.png" alt="samplex Logo" width="50%" />
</p>

# samplex
samplex is a package of sampling algorithms written in MLX. We plan to explore how unified memory (which lets the CPU and GPU work on the same data) combined with automatic differentiation can deliver efficient, robust sampling locally on your Mac.

Please get in touch if you're interested in contributing (tedwards2412@gmail.com and nash.sabti@gmail.com)!

# Installation

```bash
pip install samplex
```

# Basic Usage

For a full example, please see the examples folder. Here is the basic structure for fitting a quadratic model to noisy data (a regression that is linear in its parameters):

```python
import mlx.core as mx

from samplex.samplex import samplex
from samplex.samplers import MH_Gaussian_sampler

# First, let's generate some data from a quadratic model with unit Gaussian noise.
# The true parameter values below are illustrative.
m_true, c_true, b_true = 2.0, 1.0, 0.5
x = mx.linspace(-5, 5, 20)
sigma = mx.ones(x.shape)  # noise standard deviation of each data point
err = sigma * mx.random.normal(x.shape)
y = b_true * x**2 + m_true * x + c_true + err


# Our model is a quadratic in x; the log target is the Gaussian log-likelihood
def log_target_distribution(theta, data):
    m, c, b = theta
    x, y, sigma = data
    model = b * x**2 + m * x + c
    residual = y - model
    return mx.sum(-0.5 * (residual**2 / sigma**2))

# The sampler assumes it gets a target distribution with a single input vector theta
logtarget = lambda theta: log_target_distribution(theta, (x, y, sigma))

# Here are the sampler settings
Nwalkers = 32
Ndim = 3
Nsteps = 10_000
cov_matrix = mx.array([0.01, 0.01, 0.01])
jumping_factor = 1.0

# Draw the starting points uniformly within illustrative bounds on (m, c, b)
m_min, c_min, b_min = 0.0, 0.0, 0.0
m_max, c_max, b_max = 5.0, 5.0, 5.0
theta0_array = mx.random.uniform(
    mx.array([m_min, c_min, b_min]),
    mx.array([m_max, c_max, b_max]),
    (Nwalkers, Ndim),
)

# First we instantiate the sampler and wrap it in a samplex object, then run!
sampler = MH_Gaussian_sampler(logtarget)
sam = samplex(sampler, Nwalkers)
sam.run(Nsteps, theta0_array, cov_matrix, jumping_factor)
```
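
For readers new to the method, here is a minimal sketch of random-walk Metropolis-Hastings with a Gaussian proposal, the kind of algorithm the name `MH_Gaussian_sampler` suggests. It is written in plain Python for a single walker and is not samplex's implementation; the function name `metropolis_hastings` and its arguments are hypothetical and exist only for illustration.

```python
import math
import random


def metropolis_hastings(log_target, theta0, step_sizes, n_steps, seed=0):
    """Random-walk Metropolis-Hastings with an independent Gaussian jump per dimension."""
    rng = random.Random(seed)
    theta = list(theta0)
    logp = log_target(theta)
    chain = [list(theta)]
    n_accepted = 0
    for _ in range(n_steps):
        # Propose a symmetric Gaussian jump around the current position.
        proposal = [t + rng.gauss(0.0, s) for t, s in zip(theta, step_sizes)]
        logp_new = log_target(proposal)
        # Accept with probability min(1, p(proposal) / p(theta)).
        if rng.random() < math.exp(min(0.0, logp_new - logp)):
            theta, logp = proposal, logp_new
            n_accepted += 1
        # A rejected proposal repeats the current position in the chain.
        chain.append(list(theta))
    return chain, n_accepted / n_steps


# Example: sample a 1D standard normal, log p(theta) = -theta^2 / 2 up to a constant.
chain, acceptance_rate = metropolis_hastings(
    lambda theta: -0.5 * theta[0] ** 2, theta0=[0.0], step_sizes=[1.0], n_steps=5000
)
samples = [theta[0] for theta in chain[1000:]]  # discard the first 1000 steps as burn-in
sample_mean = sum(samples) / len(samples)
```

Because the proposal is symmetric, the acceptance ratio only needs the difference of log targets, which is why the target can be supplied unnormalised.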

# Next Steps:

- Get NUTS/HMC running
- Get an ensemble sampler running (emcee)
- Refine plotting
- Add helper functions for a variety of priors
- Support parameters with different update speeds
- Add a file of priors and include them in the target distribution
- Include an autocorrelation calculation for steps
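
Until prior helpers are available, a prior can be folded into the target function by hand. The sketch below shows one way to do that with independent uniform priors, using plain Python floats rather than MLX arrays. The names `log_uniform_prior` and `make_log_posterior` are hypothetical and are not part of samplex.

```python
import math


def log_uniform_prior(theta, lower, upper):
    """Log-density of independent uniform priors; -inf outside the bounds."""
    log_density = 0.0
    for value, lo, hi in zip(theta, lower, upper):
        if not lo <= value <= hi:
            return -math.inf
        log_density -= math.log(hi - lo)
    return log_density


def make_log_posterior(log_likelihood, lower, upper):
    """Return a function of theta only: log prior + log likelihood."""

    def log_posterior(theta):
        log_prior = log_uniform_prior(theta, lower, upper)
        if math.isinf(log_prior):
            return log_prior  # skip the likelihood outside the prior support
        return log_prior + log_likelihood(theta)

    return log_posterior


# Example: a unit Gaussian likelihood with a uniform prior on [-2, 2].
log_post = make_log_posterior(lambda theta: -0.5 * theta[0] ** 2, [-2.0], [2.0])
```

Returning a closure keeps the single-argument signature that the sampler expects of its target distribution.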

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/tedwards2412/samplex",
    "name": "samplex",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": "",
    "keywords": "sampling,inference,machine learning,autodiff,mlx",
    "author": "Thomas Edwards, Nashwan Sabti",
    "author_email": "tedwards2412@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/1d/6e/a4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c/samplex-0.0.2.tar.gz",
    "platform": null,
    "description": "<p align=\"center\">\n  <img src=\"samplex_logo.png\" alt=\"samplex Logo\" width=\"50%\" />\n</p>\n\n# samplex\nPackage of useful sampling algorithms written in MLX. We plan on exploring how a combination of unified memory (by exploiting GPU and CPU together) and auto-diff can be used to get highly efficient and robust sampling locally on your Mac.\n\nPlease get in touch if you're interested in contributing (tedwards2412@gmail.com and nash.sabti@gmail.com)!\n\n# Installation\n\n```python\npip install samplex\n```\n\n# Basic Usage\n\nFor a full example, please see the examples folder. Here is the basic structure for linear regression:\n\n```python\nfrom samplex.samplex import samplex\nfrom samplex.samplers import MH_Gaussian_sampler\n\n# First lets generate some data\nx = mx.linspace(-5, 5, 20)\nerr = mx.random.normal(x.shape)\ny = b_true * x**2 + m_true * x + c_true + err\n\n\n# Our target distribution is just a line\ndef log_target_distribution(theta, data):\n    m, c, b = theta\n    x, y, sigma = data\n    model = b * x**2 + m * x + c\n    residual = y - model\n    return sum(-0.5 * (residual**2 / sigma**2))\n\n# The sampler assumes it gets a target distribution with a single input vector theta\nlogtarget = lambda theta: log_target_distribution(theta, (x, y, err))\n\n# Here are the sampler settings\nNwalkers = 32\nNdim = 3\nNsteps = 10_000\ncov_matrix = mx.array([0.01, 0.01, 0.01])\njumping_factor = 1.0\n\ntheta0_array = mx.random.uniform(\n    mx.array([m_min, c_min, b_min]),\n    mx.array([m_max, c_max, b_max]),\n    (Nwalkers, Ndim),\n)\n\n# Firstly we instantiate a samplex class and then run!\nsampler = MH_Gaussian_sampler(logtarget)\nsam = samplex(sampler, Nwalkers)\nsam.run(Nsteps, theta0_array, cov_matrix, jumping_factor)\n```\n\n# Next Steps:\n\n- Get NUTs/HMC running\n- Get Ensemble sampler running (emcee)\n- Refine plotting\n- Add helper functions for variety of priors\n- Treating parameters with different update speeds\n- Add file 
of priors and include in target distribution\n- Include autocorrelation calculation for steps\n",
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "Samplers in MLX",
    "version": "0.0.2",
    "project_urls": {
        "Homepage": "https://github.com/tedwards2412/samplex"
    },
    "split_keywords": [
        "sampling",
        "inference",
        "machine learning",
        "autodiff",
        "mlx"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "87e1d2c426aea56ea9a170e4434b0c3a8cb05360e4ff54e4324ae983a705f8e6",
                "md5": "aaa555bb8852bf2bd5e2cc119cbf51b1",
                "sha256": "9247912cacd545bef373ab9f380dfec619da15d2323333bf66f6cfd6a7f48b4d"
            },
            "downloads": -1,
            "filename": "samplex-0.0.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "aaa555bb8852bf2bd5e2cc119cbf51b1",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 9102,
            "upload_time": "2024-02-02T21:44:40",
            "upload_time_iso_8601": "2024-02-02T21:44:40.625802Z",
            "url": "https://files.pythonhosted.org/packages/87/e1/d2c426aea56ea9a170e4434b0c3a8cb05360e4ff54e4324ae983a705f8e6/samplex-0.0.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "1d6ea4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c",
                "md5": "b95197f26491ec3eaca55dda997d66a1",
                "sha256": "903112a33dbcab71d3f1c37a3cb81a4dd023aa126f35e43ca78a9660e2e7281c"
            },
            "downloads": -1,
            "filename": "samplex-0.0.2.tar.gz",
            "has_sig": false,
            "md5_digest": "b95197f26491ec3eaca55dda997d66a1",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 9396,
            "upload_time": "2024-02-02T21:44:41",
            "upload_time_iso_8601": "2024-02-02T21:44:41.686542Z",
            "url": "https://files.pythonhosted.org/packages/1d/6e/a4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c/samplex-0.0.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-02-02 21:44:41",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "tedwards2412",
    "github_project": "samplex",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "lcname": "samplex"
}
        