scikit-jax

- Name: scikit-jax
- Version: 0.0.2
- Home page: https://github.com/LiibanMo/scikit-jax
- Summary: Classical machine learning algorithms on the GPU/TPU.
- Upload time: 2024-09-15 16:44:39
- Author: Liiban Mohamud
- Requires Python: >=3.9
- Keywords: jax classical machine learning
- Requirements: none recorded
<p align="center">
  <img src="assets/logo.png" alt="Scikit-JAX logo"/>
</p>

# Scikit-JAX: Classical Machine Learning on the GPU

**Scikit-JAX** is a library of classical machine learning algorithms implemented in JAX, so they can run on GPUs and TPUs as well as on CPUs. It aims to provide fast, easy-to-use implementations of the techniques listed below.

## Features

- **Linear Regression**: Implemented with options for different weight initialization methods and dropout regularization.
- **KMeans**: Clustering algorithm that partitions data points into a chosen number of clusters.
- **Principal Component Analysis (PCA)**: Dimensionality reduction technique that projects data onto the directions of greatest variance.
- **Multinomial Naive Bayes**: Classifier suitable for discrete data, such as text classification tasks.
- **Gaussian Naive Bayes**: Classifier for continuous data with a normal distribution assumption.
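PCA and Multinomial Naive Bayes are listed above but have no usage example in this README. As a rough illustration of what each technique computes, here are plain-JAX sketches. The function names (`pca`, `multinomial_nb_fit`, `multinomial_nb_predict`) are hypothetical, and this is not Scikit-JAX's own code or API. PCA is computed from an SVD of the centered data, and the Naive Bayes sketch uses additive (Laplace) smoothing with `alpha=1.0`; both are common choices, not anything the library documents.

```python
import jax
import jax.numpy as jnp


def pca(X, n_components):
    """Project X onto its top principal axes via SVD (hypothetical helper)."""
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by decreasing singular value
    _, S, Vt = jnp.linalg.svd(X_centered, full_matrices=False)
    components = Vt[:n_components]                                  # (k, d)
    # Variance of the data along each kept axis
    explained_variance = S[:n_components] ** 2 / (X.shape[0] - 1)
    return X_centered @ components.T, components, explained_variance


def multinomial_nb_fit(X, y, num_classes, alpha=1.0):
    """X holds non-negative counts (e.g. word counts); alpha is additive smoothing."""
    one_hot = jax.nn.one_hot(y, num_classes)                         # (n, c)
    class_counts = one_hot.sum(axis=0)
    log_priors = jnp.log(class_counts / class_counts.sum())
    # Total count of each feature within each class, plus the smoothing term
    feature_counts = one_hot.T @ X + alpha                           # (c, d)
    log_probs = jnp.log(feature_counts / feature_counts.sum(axis=1, keepdims=True))
    return log_priors, log_probs


def multinomial_nb_predict(params, X):
    log_priors, log_probs = params
    # log P(class) + sum_j x_j * log P(feature j | class), up to a shared constant
    return jnp.argmax(X @ log_probs.T + log_priors, axis=1)


# Usage
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
X_reduced, components, explained_variance = pca(X, n_components=2)

counts = jnp.array([[5.0, 0.0, 1.0], [4.0, 1.0, 0.0], [0.0, 6.0, 1.0], [1.0, 5.0, 0.0]])
labels = jnp.array([0, 0, 1, 1])
nb_params = multinomial_nb_fit(counts, labels, num_classes=2)
predicted = multinomial_nb_predict(nb_params, counts)
```

Both sketches are vectorized with array operations rather than Python loops over samples, which is the style that lets JAX run them efficiently on an accelerator.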

## Installation

Scikit-JAX is available on PyPI and can be installed with pip:

```bash
pip install scikit-jax
```
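JAX uses a GPU or TPU only when the installed `jax` package supports that accelerator; a plain `pip install jax` gives a CPU-only build, and the JAX installation guide describes the GPU/TPU variants. The package metadata records no requirements, so you may need to install `jax` yourself. As a quick sanity check, this snippet lists the devices JAX can see. It uses only standard JAX calls, not Scikit-JAX.

```python
import jax

# Devices JAX can use; with a working GPU/TPU install these are GPU or TPU
# devices rather than only CPU devices.
devices = jax.devices()
print(devices)

# Name of the backend JAX places arrays on by default
print("Default backend:", jax.default_backend())
```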

## Usage

The snippets below show how to use the main estimators. They assume `X_train`, `y_train`, and `X_test` are arrays you have already prepared.
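If you only want to try the examples, one way to create toy arrays is with JAX's random module. The data below is made up for illustration, and the names `y_train_cls` and `y_test_cls` are hypothetical. The sketch also assumes the estimators accept JAX arrays, which this README does not state explicitly.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: 200 samples with 3 features and a noisy linear target
key = jax.random.PRNGKey(0)
k_x, k_w, k_noise = jax.random.split(key, 3)
X = jax.random.normal(k_x, (200, 3))
true_w = jax.random.normal(k_w, (3,))
y = X @ true_w + 0.1 * jax.random.normal(k_noise, (200,))

# 80/20 train/test split (the rows are already in random order)
X_train, X_test = X[:160], X[160:]
y_train, y_test = y[:160], y[160:]

# Integer class labels for the classifiers, e.g. Gaussian Naive Bayes
y_train_cls = (y_train > 0).astype(jnp.int32)
y_test_cls = (y_test > 0).astype(jnp.int32)
```

Use `y_train`/`y_test` for regression and `y_train_cls`/`y_test_cls` for classification.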

### Linear Regression
```python
from skjax.linear_model import LinearRegression

# Initialize the model
model = LinearRegression(weights_init='xavier', epochs=100, learning_rate=0.01)

# Fit the model
model.fit(X_train, y_train)

# Make predictions
predictions = model.predict(X_test)

# Plot losses
model.plot_losses()
```
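To make the `weights_init`, `epochs`, and `learning_rate` options concrete, here is a minimal plain-JAX sketch of linear regression trained by full-batch gradient descent with Xavier (Glorot) uniform initialization. It is not Scikit-JAX's implementation: the function names are hypothetical, it assumes `weights_init='xavier'` means Glorot uniform, and it omits the dropout regularization mentioned under Features.

```python
import jax
import jax.numpy as jnp


def xavier_init(key, n_in, n_out=1):
    # Glorot/Xavier uniform: U(-limit, limit) with limit = sqrt(6 / (n_in + n_out))
    limit = jnp.sqrt(6.0 / (n_in + n_out))
    return jax.random.uniform(key, (n_in,), minval=-limit, maxval=limit)


def mse_loss(params, X, y):
    w, b = params
    return jnp.mean((X @ w + b - y) ** 2)


@jax.jit
def gd_step(params, X, y, learning_rate):
    # One full-batch gradient descent step on the mean squared error
    loss, grads = jax.value_and_grad(mse_loss)(params, X, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss


def fit_linear_regression(X, y, epochs=100, learning_rate=0.01, seed=0):
    """Hypothetical helper: returns (weights, bias) and the per-epoch losses."""
    key = jax.random.PRNGKey(seed)
    params = (xavier_init(key, X.shape[1]), jnp.array(0.0))
    losses = []
    for _ in range(epochs):
        params, loss = gd_step(params, X, y, learning_rate)
        losses.append(float(loss))  # recorded loss is before this epoch's update
    return params, losses


# Usage on noise-free toy data with known coefficients
X = jax.random.normal(jax.random.PRNGKey(1), (200, 3))
y = X @ jnp.array([1.0, -2.0, 0.5]) + 3.0
(w, b), losses = fit_linear_regression(X, y, epochs=500, learning_rate=0.1)
```

Recording one loss per epoch is also what a method like `plot_losses()` would need to draw a training curve.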

### K-Means
```python
from skjax.clustering import KMeans

# Initialize the model
kmeans = KMeans(num_clusters=3)

# Fit the model
kmeans.fit(X_train)
```
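For intuition about what fitting K-Means involves, here is a sketch of Lloyd's algorithm in plain JAX: assign each point to its nearest centroid, move each centroid to the mean of its points, and repeat until the centroids stop moving. The `kmeans` function and its random-point initialization are assumptions for illustration; Scikit-JAX's `KMeans` may initialize and stop differently.

```python
import jax
import jax.numpy as jnp


def kmeans(X, num_clusters, num_iters=100, seed=0):
    """Lloyd's algorithm with random-point initialization (hypothetical helper)."""
    key = jax.random.PRNGKey(seed)
    # Start from num_clusters distinct data points chosen at random
    idx = jax.random.choice(key, X.shape[0], (num_clusters,), replace=False)
    centroids = X[idx]

    def assign(centroids):
        # Squared Euclidean distance from every point to every centroid: (n, k)
        dists = jnp.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        return jnp.argmin(dists, axis=1)

    for _ in range(num_iters):
        labels = assign(centroids)
        one_hot = jax.nn.one_hot(labels, num_clusters)   # (n, k)
        counts = one_hot.sum(axis=0)                     # points per cluster
        sums = one_hot.T @ X                             # (k, d)
        # Move each centroid to the mean of its points; keep it if the cluster is empty
        new_centroids = jnp.where(
            counts[:, None] > 0, sums / jnp.maximum(counts, 1.0)[:, None], centroids
        )
        converged = bool(jnp.allclose(new_centroids, centroids))
        centroids = new_centroids
        if converged:
            break
    return centroids, assign(centroids)


# Usage on three well-separated toy blobs
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(42), 3)
X = jnp.concatenate([
    jax.random.normal(k1, (50, 2)) + jnp.array([0.0, 0.0]),
    jax.random.normal(k2, (50, 2)) + jnp.array([8.0, 0.0]),
    jax.random.normal(k3, (50, 2)) + jnp.array([0.0, 8.0]),
])
centroids, labels = kmeans(X, num_clusters=3)
```

Lloyd's algorithm only finds a local optimum, so results can depend on the initial centroids; running with several seeds and keeping the lowest within-cluster distance is a common remedy.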

### Gaussian Naive Bayes
```python
from skjax.naive_bayes import GaussianNaiveBayes

# Initialize the model
nb = GaussianNaiveBayes()

# Fit the model
nb.fit(X_train, y_train)

# Make predictions
predictions = nb.predict(X_test)
```
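Gaussian Naive Bayes fits, for each class, a prior probability plus a mean and variance per feature, then predicts the class with the highest log posterior under the assumption that features are independent given the class. The plain-JAX sketch below illustrates that computation; `gaussian_nb_fit`, `gaussian_nb_predict`, and the `eps` variance floor are hypothetical and not taken from Scikit-JAX.

```python
import jax
import jax.numpy as jnp


def gaussian_nb_fit(X, y, num_classes, eps=1e-9):
    """Estimate class priors and per-class feature means/variances (hypothetical helper)."""
    priors = jnp.array([jnp.mean(y == c) for c in range(num_classes)])
    means = jnp.stack([X[y == c].mean(axis=0) for c in range(num_classes)])
    # eps is a small constant that keeps zero-variance features from dividing by zero
    variances = jnp.stack([X[y == c].var(axis=0) for c in range(num_classes)]) + eps
    return priors, means, variances


def gaussian_nb_predict(params, X):
    priors, means, variances = params
    # Log Gaussian density summed over features (the "naive" independence
    # assumption): shape (n_samples, num_classes)
    log_lik = -0.5 * jnp.sum(
        jnp.log(2.0 * jnp.pi * variances)[None, :, :]
        + (X[:, None, :] - means[None, :, :]) ** 2 / variances[None, :, :],
        axis=-1,
    )
    # Class with the largest log posterior (up to a constant shared by all classes)
    return jnp.argmax(jnp.log(priors)[None, :] + log_lik, axis=1)


# Usage on two well-separated toy classes
k0, k1 = jax.random.split(jax.random.PRNGKey(0))
X = jnp.concatenate([
    jax.random.normal(k0, (100, 2)) - 3.0,
    jax.random.normal(k1, (100, 2)) + 3.0,
])
y = jnp.concatenate([jnp.zeros(100, dtype=jnp.int32), jnp.ones(100, dtype=jnp.int32)])
params = gaussian_nb_fit(X, y, num_classes=2)
predictions = gaussian_nb_predict(params, X)
accuracy = jnp.mean(predictions == y)
```

Working in log space avoids underflow when many small per-feature densities would otherwise be multiplied together.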

## License

Scikit-JAX is licensed under the [MIT License](LICENSE.txt).

            

Raw data

            {
    "_id": null,
    "home_page": "https://github.com/LiibanMo/scikit-jax",
    "name": "scikit-jax",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.9",
    "maintainer_email": null,
    "keywords": "jax classical machine learning",
    "author": "Liiban Mohamud",
    "author_email": "liibanmohamud12@gmail.com",
    "download_url": "https://files.pythonhosted.org/packages/90/5e/58e5b2a5622f46cfc102dcdf994570d1a47d952a4c46a2966bbe7cdf413e/scikit_jax-0.0.2.tar.gz",
    "platform": null,
    "description": "<p align=\"center\">\n  <img src=\"assets/logo.png\" alt=\"Alt text\"/>\n</p>\n\n# Scikit-JAX: Classical Machine Learning on the GPU\n\nWelcome to **Scikit-JAX**, a machine learning library designed to leverage the power of GPUs through JAX for efficient and scalable classical machine learning algorithms. Our library provides implementations for a variety of classical machine learning techniques, optimized for performance and ease of use.\n\n## Features\n\n- **Linear Regression**: Implemented with options for different weight initialization methods and dropout regularization.\n- **KMeans**: Clustering algorithm to group data points into clusters.\n- **Principal Component Analysis (PCA)**: Dimensionality reduction technique to simplify data while preserving essential features.\n- **Multinomial Naive Bayes**: Classifier suitable for discrete data, such as text classification tasks.\n- **Gaussian Naive Bayes**: Classifier for continuous data with a normal distribution assumption.\n\n## Installation\n\nTo install Scikit-JAX, you can use pip. 
The package is available on PyPI:\n\n```python\npip install scikit-jax\n```\n\n## Usage\n\nHere is a quick guide on how to use the key components of Scikit-JAX.\n\n### Linear Regression\n```py\nfrom skjax.linear_model import LinearRegression\n\n# Initialize the model\nmodel = LinearRegression(weights_init='xavier', epochs=100, learning_rate=0.01)\n\n# Fit the model\nmodel.fit(X_train, y_train)\n\n# Make predictions\npredictions = model.predict(X_test)\n\n# Plot losses\nmodel.plot_losses()\n```\n\n### K-Means\n```python\nfrom skjax.clustering import KMeans\n\n# Initialize the model\nkmeans = KMeans(num_clusters=3)\n\n# Fit the model\nkmeans.fit(X_train)\n```\n\n### Gaussian Naive Bayes\n```python\nfrom skjax.naive_bayes import GaussianNaiveBayes\n\n# Initialize the model\nnb = GaussianNaiveBayes()\n\n# Fit the model\nnb.fit(X_train, y_train)\n\n# Make predictions\npredictions = nb.predict(X_test)\n```\n\n### License\n\nScikit-JAX is licensed under the [MIT License](LICENSE.txt).\n",
    "bugtrack_url": null,
    "license": null,
    "summary": "Classical machine learning algorithms on the GPU/TPU.",
    "version": "0.0.2",
    "project_urls": {
        "Homepage": "https://github.com/LiibanMo/scikit-jax"
    },
    "split_keywords": [
        "jax",
        "classical",
        "machine",
        "learning"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "c66498e0f28a2fb083e3efbd4b124bfe80cb15ce90761d30dd1ac5e9d291d9a1",
                "md5": "2ade95056621c4ed4fac75b8137bec84",
                "sha256": "b136f34db60826b544007d0423e9c47c17bd6272523cdbd3e9783d63a206169f"
            },
            "downloads": -1,
            "filename": "scikit_jax-0.0.2-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "2ade95056621c4ed4fac75b8137bec84",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.9",
            "size": 19505,
            "upload_time": "2024-09-15T16:44:38",
            "upload_time_iso_8601": "2024-09-15T16:44:38.250021Z",
            "url": "https://files.pythonhosted.org/packages/c6/64/98e0f28a2fb083e3efbd4b124bfe80cb15ce90761d30dd1ac5e9d291d9a1/scikit_jax-0.0.2-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "905e58e5b2a5622f46cfc102dcdf994570d1a47d952a4c46a2966bbe7cdf413e",
                "md5": "b6a6c9436eb01be4069c9089a361e99d",
                "sha256": "333e1c6ec3680a803afb5ae0561dfa05973d265b30151429f72729ce45582d37"
            },
            "downloads": -1,
            "filename": "scikit_jax-0.0.2.tar.gz",
            "has_sig": false,
            "md5_digest": "b6a6c9436eb01be4069c9089a361e99d",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.9",
            "size": 13717,
            "upload_time": "2024-09-15T16:44:39",
            "upload_time_iso_8601": "2024-09-15T16:44:39.879863Z",
            "url": "https://files.pythonhosted.org/packages/90/5e/58e5b2a5622f46cfc102dcdf994570d1a47d952a4c46a2966bbe7cdf413e/scikit_jax-0.0.2.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-09-15 16:44:39",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "LiibanMo",
    "github_project": "scikit-jax",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "requirements": [],
    "lcname": "scikit-jax"
}
        