tversky-nn

| Field | Value |
|-------|-------|
| Name | tversky-nn |
| Version | 1.0.1 |
| Summary | My faithful reproduction of Tversky Neural Networks (TNNs) |
| Upload time | 2025-08-15 17:38:06 |
| Requires Python | >=3.10 |
| License | MIT |
| Keywords | deep-learning, pytorch, tversky, neural-networks, xai, ml |
| Requirements | No requirements were recorded. |

# Tversky Neural Networks (TNN)

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-3100/)

A PyTorch implementation of **Tversky Neural Networks (TNNs)**, an architecture that replaces the traditional linear classification layer with a Tversky similarity-based projection layer. This implementation reproduces the key concepts from the original paper and provides ResNet-based and XOR models for both research and practical applications.

## 🚀 What are Tversky Neural Networks?

Tversky Neural Networks classify inputs by computing **Tversky similarity** between input features and learned prototypes instead of relying on traditional dot-product operations. The key component is the **Tversky Projection Layer**, which:

- **Replaces linear layers** with learnable prototype-based similarity computations
- **Uses asymmetric similarity** through the Tversky index (α, β parameters)
- **Provides interpretable representations** through learned prototypes
- **Maintains competitive accuracy** while offering explainable decision boundaries

### Core Mathematical Foundation

The Tversky Projection Layer computes similarities between input features and learned prototypes using:

```
S_Ω,α,β,θ(x, π_k) = |x ∩ π_k|_Ω / (|x ∩ π_k|_Ω + α|x \ π_k|_Ω + β|π_k \ x|_Ω + θ)
```

Where:
- `x` is the input feature vector
- `π_k` are learned prototypes  
- `Ω` is a learned feature bank
- `α, β` control asymmetric similarity weighting
- `θ` provides numerical stability
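
To make the ratio concrete, here is a minimal sketch that evaluates it on ordinary finite sets, with set cardinality standing in for the learned measure `|·|_Ω`. The function name `tversky_ratio` and the toy feature sets are illustrative assumptions and are not part of the package API.

```python
def tversky_ratio(x: set, prototype: set, alpha: float = 0.5,
                  beta: float = 0.5, theta: float = 1e-7) -> float:
    """Tversky ratio on finite sets (illustrative helper, not the package API)."""
    common = len(x & prototype)      # |x ∩ π_k|
    only_x = len(x - prototype)      # |x \ π_k|
    only_p = len(prototype - x)      # |π_k \ x|
    # theta keeps the denominator non-zero when both sets are empty
    return common / (common + alpha * only_x + beta * only_p + theta)


x = {"wings", "beak", "flies"}
prototype = {"wings", "beak", "swims", "black-white"}

# With alpha != beta the similarity depends on the direction of comparison.
forward = tversky_ratio(x, prototype, alpha=0.8, beta=0.2)
backward = tversky_ratio(prototype, x, alpha=0.8, beta=0.2)
print(f"S(x, prototype) = {forward:.4f}")
print(f"S(prototype, x) = {backward:.4f}")
```

With `alpha = beta = 1` the ratio reduces to the Jaccard index and with `alpha = beta = 0.5` to the Dice coefficient (up to `theta`); unequal weights are what make the similarity asymmetric.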

## 📦 Installation

### From PyPI (Recommended)

```bash
pip install tversky-nn
```

### From Source

```bash
git clone https://github.com/akshathmangudi/tnn.git
cd tnn
pip install -e .
```

### Dependencies

- Python 3.10+
- PyTorch 2.0+
- torchvision 0.15+
- numpy
- scikit-learn
- tqdm
- pillow

## 🎯 Quick Start

### Basic Usage

```python
import torch
from tnn.models import get_resnet_model
from tnn.datasets import get_mnist_loaders

# Create a TverskyResNet model
model = get_resnet_model(
    architecture='resnet18',
    num_classes=10,
    use_tversky=True,
    num_prototypes=8,
    alpha=0.5,
    beta=0.5
)

# Load MNIST dataset
train_loader, val_loader, test_loader = get_mnist_loaders(
    data_dir='./data',
    batch_size=64
)

# Use the model
x = torch.randn(32, 3, 224, 224)  # Batch of images
outputs = model(x)  # Shape: (32, 10)
```

### XOR Toy Problem

Demonstrate TNN capabilities on the classic XOR problem:

```python
from tnn.models.xor import TverskyXORNet
import torch

# Create XOR model
model = TverskyXORNet(
    hidden_dim=8,
    num_prototypes=4,
    alpha=0.5,
    beta=0.5
)

# XOR data
x = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = torch.tensor([0, 1, 1, 0])

# Forward pass
predictions = model(x)
```

## 🏃‍♂️ Training Models

### MNIST Classification

Train a TverskyResNet on MNIST:

```bash
# Train with Tversky layer (recommended)
python train_resnet.py --dataset mnist --architecture resnet18 --epochs 50 --lr 0.01

# Train baseline (linear layer)
python train_resnet.py --dataset mnist --architecture resnet18 --use-linear --epochs 50 --lr 0.01

# Quick test (2 epochs)
python train_resnet.py --dataset mnist --epochs 2 --lr 0.01
```
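
For readers who want to see how these flags could be wired up, the following is a hypothetical `argparse` sketch that mirrors only the options used in the commands above. The defaults and the `build_parser` helper are assumptions; the actual `train_resnet.py` may define more options or different defaults.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags shown in the commands above."""
    parser = argparse.ArgumentParser(
        description="Train a ResNet with a Tversky or linear classifier head"
    )
    parser.add_argument("--dataset", default="mnist")
    parser.add_argument("--architecture", default="resnet18")  # assumed default
    parser.add_argument("--epochs", type=int, default=50)      # assumed default
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--use-linear", action="store_true",
                        help="use the linear baseline instead of the Tversky layer")
    return parser


# Equivalent of the "quick test" command above
args = build_parser().parse_args(["--dataset", "mnist", "--epochs", "2", "--lr", "0.01"])
use_tversky = not args.use_linear
print(args.architecture, args.epochs, args.lr, use_tversky)
```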

### XOR Toy Problem

```bash
python train_xor.py
```

### Advanced Training Options

TO BE UPDATED

## 📊 Results

Results so far cover MNIST classification and the XOR toy problem:

### MNIST Classification Results

| Configuration | Architecture | Classifier | Val Accuracy | Train Accuracy | Training Time |
|---------------|-------------|------------|--------------|----------------|---------------|
| **Optimized TNN** | ResNet18 | Tversky (8 prototypes) | **98.88%** | **98.81%** | ~32 min (2 epochs) |
| Baseline | ResNet18 | Linear | - | - | - |

**Key Training Metrics:**
- **Epoch 1:** Training Acc: 89.81%
- **Epoch 2:** Training Acc: 98.81%, Validation Acc: 98.88%
- **Model Size:** 11.18M parameters (4,608 in Tversky classifier)
- **Convergence:** Reaches 98.81% training accuracy within 2 epochs at lr=0.01

### XOR Toy Problem Results

| Metric | Value |
|--------|-------|
| **Final Test Accuracy** | **93.00%** |
| Class 0 Accuracy | 95.40% |
| Class 1 Accuracy | 91.15% |
| Training Epochs | 500 |
| Convergence | Smooth, interpretable decision boundary |

**Visual Results:**
- Clear non-linear decision boundary
- Interpretable learned prototypes
- Smooth training curves
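
The per-class figures in the table can be computed from predictions as in the sketch below. The helper `per_class_accuracy` and the toy labels are illustrative assumptions; the repository's own evaluation code may differ.

```python
from collections import Counter


def per_class_accuracy(y_true, y_pred):
    """Return {label: accuracy} computed separately for each true label."""
    correct = Counter()
    total = Counter()
    for truth, pred in zip(y_true, y_pred):
        total[truth] += 1
        if truth == pred:
            correct[truth] += 1
    return {label: correct[label] / total[label] for label in total}


# Toy labels for illustration only (not the README's experiment)
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 0, 1, 1, 0, 1]

acc = per_class_accuracy(y_true, y_pred)
overall = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(acc, overall)
```

Overall accuracy is the class-frequency-weighted average of the per-class values. If the reported 93.00% is computed over all test samples, it is consistent with class 0 making up roughly 44% of the test set (0.435 × 95.40 + 0.565 × 91.15 ≈ 93.0).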

## 🔬 Key Features

### ✅ What Works Well

1. **Fast Convergence**: With proper hyperparameters (lr=0.01), TNNs converge quickly
2. **High Accuracy**: Achieves 98.88% validation accuracy on MNIST
3. **Interpretability**: Learned prototypes provide insight into model decisions
4. **Flexibility**: Support for multiple ResNet architectures
5. **Stability**: Robust training with mixed precision and proper initialization

### 🏗️ Architecture Highlights

- **Modular Design**: Easy to swap Tversky layers for linear layers
- **Multiple Architectures**: ResNet18/50/101/152 support
- **Pretrained Weights**: ImageNet initialization available
- **Mixed Precision**: Automatic mixed precision training
- **Comprehensive Logging**: Detailed metrics and checkpointing

### 🎛️ Configurable Hyperparameters

```python
# Tversky similarity parameters
alpha: float = 0.5              # Weight on features in x but not in the prototype (x \ π_k)
beta: float = 0.5               # Weight on features in the prototype but not in x (π_k \ x)
num_prototypes: int = 8         # Number of learned prototypes
theta: float = 1e-7             # Numerical stability constant

# Architecture options
intersection_reduction = "product"        # or "mean"
difference_reduction = "subtractmatch"    # or "ignorematch"
feature_bank_init = "xavier"             # Feature bank initialization
prototype_init = "xavier"                # Prototype initialization
```
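
The reduction options are easier to follow with a small worked example. The sketch below is plain Python on lists rather than the package's tensor implementation, and its reading of `"product"`, `"mean"`, `"subtractmatch"`, and `"ignorematch"` is an assumption inferred from the option names: a feature counts as present when its dot product with the object is positive, `"ignorematch"` counts only features present in one object and absent in the other, and `"subtractmatch"` additionally counts the surplus on shared features. Check the package source for the exact definitions.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def tversky_similarity(x, prototype, feature_bank, alpha=0.5, beta=0.5, theta=1e-7,
                       intersection_reduction="product",
                       difference_reduction="subtractmatch"):
    """Illustrative continuous Tversky ratio (assumed semantics, not the package API)."""
    if intersection_reduction not in ("product", "mean"):
        raise ValueError(f"unknown intersection_reduction: {intersection_reduction}")
    if difference_reduction not in ("subtractmatch", "ignorematch"):
        raise ValueError(f"unknown difference_reduction: {difference_reduction}")

    # Score of each object against every feature in the bank
    xs = [dot(x, f) for f in feature_bank]
    ps = [dot(prototype, f) for f in feature_bank]

    def combine(a, b):
        return a * b if intersection_reduction == "product" else (a + b) / 2

    def difference(a_scores, b_scores):
        total = 0.0
        for a, b in zip(a_scores, b_scores):
            if a > 0 and b <= 0:
                total += a          # feature present in a only
            elif difference_reduction == "subtractmatch" and a > b > 0:
                total += a - b      # shared feature, stronger in a
        return total

    common = sum(combine(a, b) for a, b in zip(xs, ps) if a > 0 and b > 0)
    x_only = difference(xs, ps)
    p_only = difference(ps, xs)
    return common / (common + alpha * x_only + beta * p_only + theta)


feature_bank = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
x = [1.0, -0.5]
prototype = [0.5, 1.0]

sim_subtract = tversky_similarity(x, prototype, feature_bank)
sim_ignore = tversky_similarity(x, prototype, feature_bank,
                                difference_reduction="ignorematch")
print(f"subtractmatch: {sim_subtract:.4f}, ignorematch: {sim_ignore:.4f}")
```

In this toy case `"ignorematch"` yields the higher similarity because it does not penalize shared features that differ only in strength.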

## 🚧 Current Limitations & Future Work

### Known Issues Resolved ✅

- **Double Classification Layer**: Removed a duplicate classification layer that was causing convergence issues
- **Softmax Placement**: Set `apply_softmax=False` in the Tversky layer
- **Learning Rate**: Raised the default learning rate from 0.001 to 0.01
- **Initialization**: Improved prototype and feature bank initialization
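
The softmax fix matters because cross-entropy losses such as PyTorch's `nn.CrossEntropyLoss` expect raw scores and apply log-softmax internally. The plain-Python sketch below compares the loss on raw scores with the loss when softmax has already been applied. It assumes the training loss is cross-entropy, which this README does not state, and `softmax` and `cross_entropy` are illustrative helpers rather than package functions.

```python
import math


def softmax(values):
    shift = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(scores, target):
    """Cross-entropy that applies softmax internally, as typical losses do."""
    return -math.log(softmax(scores)[target])


scores = [4.0, 1.0, 0.0]  # a confident, correct prediction for class 0

loss_raw = cross_entropy(scores, target=0)              # layer outputs raw scores
loss_double = cross_entropy(softmax(scores), target=0)  # layer already applied softmax

print(f"raw scores:     {loss_raw:.4f}")
print(f"double softmax: {loss_double:.4f}")
```

Because already-normalized inputs lie in [0, 1], the loss for a 3-class problem cannot fall below about 0.55 in the double-softmax case, however confident the model is, which flattens the training signal.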

### Future Enhancements 🔮

1. **Extended Datasets**: Support for CIFAR-10/100, ImageNet
2. **Additional Architectures**: Vision Transformers, EfficientNets
3. **Advanced Features**: 
   - Prototype visualization tools
   - Attention mechanisms
   - Multi-modal support
4. **Optimization**: 
   - Further convergence improvements
   - Memory optimization for large models
5. **Research Extensions**:
   - Adaptive α, β parameters
   - Hierarchical prototypes
   - Ensemble methods

## 📈 Performance Optimizations Applied

Our implementation includes several key optimizations discovered during development:

1. **Architectural Fixes**:
   - Removed double classification layer causing gradient flow issues
   - Set `apply_softmax=False` in Tversky layer for better optimization
   - Improved linear layer initialization with Xavier uniform

2. **Training Optimizations**:
   - Increased learning rate to 0.01 for faster convergence
   - Mixed precision training for memory efficiency
   - Cosine annealing scheduler for better convergence

3. **Numerical Stability**:
   - Proper theta parameter (1e-7) for numerical stability
   - Xavier initialization for all learnable parameters
   - Gradient clipping and proper loss scaling
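
Two of the items above have simple closed forms. The sketch below shows cosine annealing of the learning rate and clipping by global gradient norm in plain Python. It is a simplified illustration under assumed settings (`min_lr=0.0`, one schedule step per epoch), not the repository's training code; in PyTorch the corresponding utilities are `torch.optim.lr_scheduler.CosineAnnealingLR` and `torch.nn.utils.clip_grad_norm_`.

```python
import math


def cosine_annealing_lr(step, total_steps, base_lr=0.01, min_lr=0.0):
    """Learning rate that decays from base_lr to min_lr along a half cosine."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm):
    """Rescale a flat list of gradients so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


for epoch in (0, 25, 50):
    print(epoch, round(cosine_annealing_lr(epoch, total_steps=50), 5))

clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(clipped)
```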

## 🤝 Contributing

We welcome contributions! Areas where help is needed:

- Additional dataset implementations
- New architecture support  
- Performance optimizations
- Documentation improvements
- Bug fixes and testing

## To add
- [ ] Include a GPT-2 implementation and benchmarks.
- [ ] Run ResNet18 benchmarks on the NABirds dataset.
- [ ] Add benchmarks across different datasets and weight distributions.
- [ ] Unify training configuration instead of keeping separate training files for different models.
- [ ] Add type checking and other software engineering standards to maintain robustness.

## 📝 Citation

If you use this implementation in your research, please cite:

```bibtex
@software{tnn_pytorch,
  author = {Akshath Mangudi},
  title = {TNN: A PyTorch Implementation of Tversky Neural Networks},
  year = {2025},
  url = {https://github.com/akshathmangudi/tnn}
}
```

For the original Tversky Neural Networks paper, please cite:
```bibtex
@article{tversky_neural_networks,
  title={Tversky Neural Networks},
  author={[Original Authors]},
  journal={[Journal Name]},
  year={[Year]},
  url={[Paper URL]}
}
```

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

- Original Tversky Neural Networks paper authors
- PyTorch team for the excellent deep learning framework
- torchvision for pretrained models and datasets

---

**Built with ❤️ and PyTorch** | **Ready for production use** | **Optimized for research**

            

Raw data

{
    "_id": null,
    "home_page": null,
    "name": "tversky-nn",
    "maintainer": null,
    "docs_url": null,
    "requires_python": ">=3.10",
    "maintainer_email": null,
    "keywords": "deep-learning, pytorch, tversky, neural-networks, xai, ml",
    "author": null,
    "author_email": "Akshath Mangudi <akshathmangudi@gmail.com>",
    "download_url": "https://files.pythonhosted.org/packages/ff/a2/6dc8faca6b4d2edbb03eae6af2cc5aa323cecdbe5790b17639917a82137a/tversky_nn-1.0.1.tar.gz",
    "platform": null,
    "bugtrack_url": null,
    "license": "MIT",
    "summary": "My faithful reproduction of Tversky Neural Networks (TNNs)",
    "version": "1.0.1",
    "project_urls": {
        "Homepage": "https://github.com/akshathmangudi/tnn",
        "Issues": "https://github.com/akshathmangudi/tnn/issues",
        "Repository": "https://github.com/akshathmangudi/tnn"
    },
    "split_keywords": [
        "deep-learning",
        " pytorch",
        " tversky",
        " neural-networks",
        " xai",
        " ml"
    ],
    "urls": [
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "6134deaf3d4eae91466a61df8c598d45a062c9d94d184dfb28b644cab1992ca9",
                "md5": "709c2d88c973d6fb0a0d2933e239b726",
                "sha256": "fb37db76896cb01ab4edb345fb18838a96982a8ddf3a61985da6783261a5a989"
            },
            "downloads": -1,
            "filename": "tversky_nn-1.0.1-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "709c2d88c973d6fb0a0d2933e239b726",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.10",
            "size": 30727,
            "upload_time": "2025-08-15T17:38:04",
            "upload_time_iso_8601": "2025-08-15T17:38:04.614073Z",
            "url": "https://files.pythonhosted.org/packages/61/34/deaf3d4eae91466a61df8c598d45a062c9d94d184dfb28b644cab1992ca9/tversky_nn-1.0.1-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": null,
            "digests": {
                "blake2b_256": "ffa26dc8faca6b4d2edbb03eae6af2cc5aa323cecdbe5790b17639917a82137a",
                "md5": "8237170ef3be5d0d79334dc3eb3b4196",
                "sha256": "78197ee4b27f7c535c21d67d69a7df46ffc412df6b2b64c341887aaceee51180"
            },
            "downloads": -1,
            "filename": "tversky_nn-1.0.1.tar.gz",
            "has_sig": false,
            "md5_digest": "8237170ef3be5d0d79334dc3eb3b4196",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.10",
            "size": 30562,
            "upload_time": "2025-08-15T17:38:06",
            "upload_time_iso_8601": "2025-08-15T17:38:06.410407Z",
            "url": "https://files.pythonhosted.org/packages/ff/a2/6dc8faca6b4d2edbb03eae6af2cc5aa323cecdbe5790b17639917a82137a/tversky_nn-1.0.1.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2025-08-15 17:38:06",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "akshathmangudi",
    "github_project": "tnn",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": false,
    "lcname": "tversky-nn"
}
        