tf-kan-latest

  * **Name**: tf-kan-latest
  * **Version**: 1.0.9
  * **Home page**: https://github.com/sathyasubrahmanya/tf-kan
  * **Summary**: A Keras-native implementation of Kolmogorov-Arnold Networks (KANs) for TensorFlow.
  * **Upload time**: 2025-10-24 12:28:04
  * **Author**: Sathyasubrahmanya v S
  * **Requires Python**: >=3.8
  * **License**: not specified in the package metadata
  * **Keywords**: tensorflow, keras, kan, kolmogorov-arnold, neural-networks, machine-learning
  * **Requirements**: none recorded

# TF-KAN: Kolmogorov-Arnold Networks for TensorFlow

A Keras-native, high-performance implementation of **Kolmogorov-Arnold Networks (KANs)** for **TensorFlow 2.19+**.

This library provides easy-to-use Keras layers that replace the fixed scalar weights of standard linear transformations with learnable B-spline functions, which can make models more expressive and easier to interpret.

-----

## Key Features

  * **🧠 Learnable Activations**: Instead of applying a fixed activation such as ReLU or SiLU at each neuron, a KAN learns a data-driven activation function on each connection.
  * **🧩 Seamless Keras Integration**: Use `DenseKAN` and `Conv*DKAN` layers (such as `Conv1DKAN` and `Conv2DKAN`) as drop-in replacements for the corresponding standard Keras layers.
  * **⚡ High Performance**: Core mathematical operations are compiled into static graphs with `@tf.function`, avoiding Python overhead on each call.
  * **🔄 Adaptive Grids**: Update and refine spline grids from data samples, so the model can allocate its spline resolution where the inputs actually fall.
  * **💾 Modern Serialization**: Save and load models containing KAN layers with `model.save()` and `tf.keras.models.load_model()`, with no `custom_objects` needed.

-----

## Installation

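The package is published on PyPI as `tf-kan-latest`; the examples below import it as `tfkan`.

```bash
pip install tf-kan-latest
```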

-----

## Core Concepts

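In a traditional neural network, each connection is a single weight (`w`), and a fixed activation function is applied at each neuron. In a KAN, each connection is instead a learnable 1D function (a **B-spline**), like a smart dimmer switch that can apply a complex curve to the input signal; each neuron then simply sums its incoming signals.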

You control these functions with two hyperparameters:

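  * **`grid_size`**: The number of intervals in the spline's grid, i.e. its resolution. A larger size allows more complex, "wiggly" functions, at the cost of more parameters per connection.
  * **`spline_order`**: The polynomial degree of each spline piece, which sets the function's smoothness. An order of 3 (cubic) is recommended for smooth curves.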
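
To make `grid_size` and `spline_order` concrete, here is a minimal pure-Python sketch of one KAN connection and one KAN layer. It assumes a uniform grid on [-1, 1] extended by `spline_order` knots on each side (similar to the layout used in the original KAN paper's reference code) and evaluates the basis with the Cox-de Boor recursion. The helpers `make_knots`, `bspline_basis`, `edge_function`, and `kan_layer_forward` are hypothetical: this illustrates the idea and is not `tfkan`'s internal implementation, which may also add a base activation term or scaling weights.

```python
def make_knots(grid_size, spline_order, x_min=-1.0, x_max=1.0):
    """Uniform knots on [x_min, x_max], extended by spline_order knots on each side."""
    h = (x_max - x_min) / grid_size
    return [x_min + h * i for i in range(-spline_order, grid_size + spline_order + 1)]


def bspline_basis(x, knots, i, k):
    """Cox-de Boor recursion: value of the i-th B-spline basis of degree k at x."""
    if k == 0:
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    left = (x - knots[i]) / (knots[i + k] - knots[i])
    right = (knots[i + k + 1] - x) / (knots[i + k + 1] - knots[i + 1])
    return left * bspline_basis(x, knots, i, k - 1) + right * bspline_basis(x, knots, i + 1, k - 1)


def edge_function(x, coeffs, grid_size, spline_order):
    """One learnable connection: phi(x) = sum_i c_i * B_i(x).

    A spline with grid_size intervals and degree spline_order has
    grid_size + spline_order basis functions, hence that many coefficients.
    """
    if len(coeffs) != grid_size + spline_order:
        raise ValueError("expected grid_size + spline_order coefficients")
    knots = make_knots(grid_size, spline_order)
    return sum(c * bspline_basis(x, knots, i, spline_order) for i, c in enumerate(coeffs))


def kan_layer_forward(inputs, coeff_table, grid_size, spline_order):
    """KAN layer: output_j = sum_i phi_{j,i}(inputs[i]).

    coeff_table[j][i] holds the spline coefficients of the edge from input i to output j.
    """
    return [
        sum(edge_function(x, coeff_table[j][i], grid_size, spline_order)
            for i, x in enumerate(inputs))
        for j in range(len(coeff_table))
    ]


# Usage: with every coefficient set to 1, each edge returns 1 inside the grid
# range (B-spline bases sum to one there), so each output equals the input count.
grid_size, spline_order = 5, 3
ones = [1.0] * (grid_size + spline_order)
print(kan_layer_forward([0.2, -0.7], [[ones, ones]], grid_size, spline_order))
```

Because each basis function is nonzero on only `spline_order + 1` grid intervals, a training step on one input value only updates a few nearby coefficients, which is what makes the learned curves local and inspectable.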

-----

## Examples

Here are several examples demonstrating how to use `tfkan` for different tasks.

### 1\. Basic Regression

This example builds a simple model to learn a 1D function, showcasing the `DenseKAN` layer.

```python
import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Generate some synthetic data for y = sin(pi*x)
x_train = np.linspace(-1, 1, 100)[:, np.newaxis]
y_train = np.sin(np.pi * x_train)

# 2. Build the KAN model
# A small model is enough to learn this simple function
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    DenseKAN(units=16, grid_size=8, spline_order=3, name='kan_layer_1'),
    DenseKAN(units=1, name='kan_output')
])

# 3. Compile and train
model.compile(optimizer='adam', loss='mean_squared_error')
print("--- Training a simple regressor ---")
model.fit(x_train, y_train, epochs=50, verbose=0)

# 4. Test the model
print("--- Prediction ---")
test_input = tf.constant([[0.5]]) # sin(pi * 0.5) = 1.0
prediction = model.predict(test_input)
print(f"Model prediction for input 0.5: {prediction[0][0]:.4f}")
model.summary()
```

### 2\. Image Classification (Hybrid CNN)

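Standard Keras layers and KAN layers can be mixed in one model. This example builds a hybrid CIFAR-10 classifier: a standard `Conv2D` block extracts low-level features, a `Conv2DKAN` block and a `DenseKAN` layer process them, and a standard `Dense` layer outputs the class logits.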

```python
import tensorflow as tf
from tfkan.layers import Conv2DKAN, DenseKAN

# 1. Load CIFAR-10 (downloaded on first use) and scale pixel values to [0, 1]
(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0

# 2. Build the hybrid model
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),

    # Standard Conv block
    tf.keras.layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),

    # KAN Conv block with specific KAN arguments
    Conv2DKAN(
        filters=64,
        kernel_size=3,
        padding='same',
        name='kan_conv',
        kan_kwargs={'grid_size': 5, 'spline_order': 3}
    ),
    tf.keras.layers.GlobalAveragePooling2D(),

    # KAN dense layer for final classification
    DenseKAN(units=128, grid_size=8, name='kan_dense'),
    tf.keras.layers.Dense(units=10, name='output_logits') # Standard output layer
])

# 3. Compile and train
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print("\n--- Training a hybrid CNN for image classification ---")
# model.fit(x_train, y_train, epochs=1, batch_size=64) # Uncomment to train
model.summary()
```
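
The hybrid model's shapes can be checked by hand. The sketch below traces the spatial size through the stack using the standard Keras output-size rules (`'same'` padding gives `ceil(size / stride)`, `'valid'` gives `(size - kernel) // stride + 1`). It assumes `Conv2DKAN` follows the same padding semantics as `tf.keras.layers.Conv2D`; `conv_out` is a hypothetical helper, not part of `tfkan`.

```python
import math


def conv_out(size, kernel_size, stride=1, padding="valid"):
    """Output length along one spatial axis for a Keras-style conv or pooling layer."""
    if padding == "same":
        return math.ceil(size / stride)
    return (size - kernel_size) // stride + 1


# Trace height/width through the Example 2 model (channels noted in comments).
size = 32                                  # Input: 32 x 32 x 3
size = conv_out(size, 3, padding="same")   # Conv2D(32, 3, 'same'): 32 x 32 x 32
size = conv_out(size, 2, stride=2)         # MaxPooling2D(): default 2x2 pool, stride 2 -> 16 x 16 x 32
size = conv_out(size, 3, padding="same")   # Conv2DKAN(64, 3, 'same'): 16 x 16 x 64 (assumed Keras semantics)
pooled_features = 64                       # GlobalAveragePooling2D averages over 16 x 16, keeping 64 channels
print(f"Feature map before pooling: {size} x {size} x 64")
print(f"DenseKAN input width: {pooled_features}, logits: 10")
```

Global average pooling is what keeps the KAN part small here: `DenseKAN` sees 64 features rather than a flattened 16 x 16 x 64 map.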

### 3\. Advanced Usage: Adaptive Grid Updates

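KAN layers can adapt their spline grids to the data. `update_grid_from_samples` recomputes the knot locations so the grid better matches the distribution of a batch of inputs, and `extend_grid_from_samples` refines the grid to a higher resolution. This is useful for refining a pre-trained model.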

```python
import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Create a model and a sample data batch
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    DenseKAN(16, grid_size=5, name='my_kan_layer'),
])
sample_data = np.random.randn(100, 32).astype('float32')

# 2. Get the KAN layer from the model
kan_layer = model.get_layer('my_kan_layer')
print(f"Initial grid size: {kan_layer.grid_size}")

# 3. Update the grid based on the sample data
# This re-calculates knot locations to better cover the data's features
print("Updating grid from samples...")
kan_layer.update_grid_from_samples(sample_data)
print("Grid updated successfully.")

# 4. You can also extend the grid to a higher resolution
print("Extending grid to a larger size...")
try:
    kan_layer.extend_grid_from_samples(sample_data, extend_grid_size=10)
    print(f"Grid extended successfully. New grid size: {kan_layer.grid_size}")
except Exception as e:
    print(f"Error during extension: {e}")

```
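
The idea behind updating a grid from samples can be sketched in a few lines. The version below is loosely modeled on the original KAN paper's reference implementation (pykan): it places `grid_size + 1` knots at evenly spaced quantiles of the sorted samples and blends them with a uniform grid over the sample range using a mixing weight `grid_eps`. The helper `adaptive_grid` and the default `grid_eps=0.02` are assumptions for illustration; `tfkan`'s `update_grid_from_samples` may work differently, and a full update also refits the spline coefficients so the learned function is preserved on the new grid.

```python
def adaptive_grid(samples, grid_size, grid_eps=0.02):
    """Return grid_size + 1 knot locations for one input feature (hypothetical helper).

    grid_eps=1.0 gives a uniform grid over the sample range; grid_eps=0.0
    puts knots at evenly spaced sample quantiles, so dense regions get more knots.
    """
    xs = sorted(samples)
    n = len(xs)
    lo, hi = xs[0], xs[-1]
    # Quantile-based knots: evenly spaced positions in the sorted samples
    quantile_knots = [xs[round(j * (n - 1) / grid_size)] for j in range(grid_size + 1)]
    # Uniform knots spanning the same range
    uniform_knots = [lo + j * (hi - lo) / grid_size for j in range(grid_size + 1)]
    # Blend the two; both lists are sorted, so the blend is sorted too
    return [grid_eps * u + (1.0 - grid_eps) * q for u, q in zip(uniform_knots, quantile_knots)]


# Skewed data: most samples lie near 0, a few reach out to 10
samples = [0.01 * i for i in range(90)] + [float(i) for i in range(1, 11)]
print("uniform :", [round(k, 2) for k in adaptive_grid(samples, 5, grid_eps=1.0)])
print("adaptive:", [round(k, 2) for k in adaptive_grid(samples, 5)])
```

With the uniform grid, most knots fall where there is almost no data; the quantile-weighted grid spends its resolution on the dense region near 0, which is the effect grid updates aim for.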

### 4\. Time Series Classification

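Use `Conv1DKAN` to extract temporal patterns from sequential data. This example builds a sequence classifier that maps a window of `lookback_window` past time steps, each with `num_features` features, to one of `num_classes` classes.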

```python
import tensorflow as tf
from tfkan.layers import Conv1DKAN, DenseKAN

# 1. Define model parameters for a time series task
lookback_window = 20  # Number of past time steps to use as input
num_features = 5      # Number of features at each time step
num_classes = 3       # Number of output classes

# 2. Build a model for sequence classification
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(lookback_window, num_features)),
    
    # 1D KAN convolution to extract temporal features
    Conv1DKAN(
        filters=32,
        kernel_size=3,
        kan_kwargs={'grid_size': 8}
    ),
    tf.keras.layers.GlobalAveragePooling1D(),
    
    # Dense KAN layers for classification
    DenseKAN(64),
    tf.keras.layers.Dense(num_classes)
])

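# 3. Compile the model. The final Dense layer outputs raw logits, so the loss
# applies softmax itself (from_logits=True); labels must be one-hot encoded.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
)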
print("\n--- Time Series Model ---")
model.summary()
```
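
The model expects inputs of shape `(lookback_window, num_features)`, and categorical cross-entropy expects one-hot labels. A minimal pure-Python sketch of preparing such data from a multivariate series follows; `make_windows` and `one_hot` are hypothetical helpers, and the labeling rule in the usage lines is a placeholder, since how each window is labeled depends on your task.

```python
def make_windows(series, lookback_window):
    """Slice a series into overlapping windows.

    series: list of time steps, each a list of num_features floats.
    Returns len(series) - lookback_window + 1 windows, each lookback_window steps long.
    """
    return [series[t:t + lookback_window] for t in range(len(series) - lookback_window + 1)]


def one_hot(labels, num_classes):
    """Encode integer class labels as one-hot vectors."""
    return [[1.0 if c == label else 0.0 for c in range(num_classes)] for label in labels]


# Synthetic example: 30 time steps with 5 features each
lookback_window, num_features, num_classes = 20, 5, 3
series = [[float(t + f) for f in range(num_features)] for t in range(30)]
windows = make_windows(series, lookback_window)
# Placeholder labeling rule for illustration only: window index modulo num_classes
labels = one_hot([i % num_classes for i in range(len(windows))], num_classes)
print(len(windows), len(windows[0]), len(windows[0][0]))
print(labels[:3])
```

Converting the results with `np.asarray` gives arrays of shape `(num_windows, lookback_window, num_features)` and `(num_windows, num_classes)`, matching the model's input layer and loss.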

-----

## Contributing

Contributions are welcome\! Please feel free to submit a pull request or open an issue.

## License

This project is licensed under the MIT License.

            
