# Tversky Neural Networks (TNN)
[MIT License](https://opensource.org/licenses/MIT)
[Python 3.10+](https://www.python.org/downloads/release/python-3100/)
A PyTorch implementation of **Tversky Neural Networks (TNNs)**, a novel architecture that replaces traditional linear classification layers with Tversky similarity-based projection layers. This implementation faithfully reproduces the key concepts from the original paper and provides optimized, production-ready models for both research and practical applications.
## 🚀 What are Tversky Neural Networks?
Tversky Neural Networks introduce a fundamentally different approach to neural network classification by leveraging **Tversky similarity functions** instead of traditional dot-product operations. The key innovation is the **Tversky Projection Layer**, which:
- **Replaces linear layers** with learnable prototype-based similarity computations
- **Uses asymmetric similarity** through Tversky index (α, β parameters)
- **Provides interpretable representations** through learned prototypes
- **Maintains competitive accuracy** while offering explainable decision boundaries
### Core Mathematical Foundation
The Tversky Projection Layer computes similarities between input features and learned prototypes using:
```
S_Ω,α,β,θ(x, π_k) = |x ∩ π_k|_Ω / (|x ∩ π_k|_Ω + α|x \ π_k|_Ω + β|π_k \ x|_Ω + θ)
```
Where:
- `x` is the input feature vector
- `π_k` are learned prototypes
- `Ω` is a learned feature bank
- `α, β` control asymmetric similarity weighting
- `θ` provides numerical stability
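For intuition, here is a minimal, self-contained PyTorch sketch of this index. It is illustrative rather than the package's internal implementation: set measures over the feature bank are approximated by memberships, with an element-wise `min` for the intersection and rectified membership surpluses for the set differences (the `intersection_reduction` and `difference_reduction` options described later configure the reductions the package actually uses).

```python
import torch

def tversky_similarity(x, prototypes, omega, alpha=0.5, beta=0.5, theta=1e-7):
    """Illustrative Tversky index between inputs and prototypes.

    x: (B, D) inputs, prototypes: (K, D), omega: (F, D) feature bank.
    Set measures |.|_Ω are approximated by feature-bank memberships.
    """
    a = (x @ omega.T).unsqueeze(1)           # (B, 1, F) memberships of x
    b = (prototypes @ omega.T).unsqueeze(0)  # (1, K, F) memberships of π_k
    common = torch.minimum(a, b).sum(-1)     # |x ∩ π_k|_Ω
    x_only = torch.relu(a - b).sum(-1)       # |x \ π_k|_Ω
    p_only = torch.relu(b - a).sum(-1)       # |π_k \ x|_Ω
    return common / (common + alpha * x_only + beta * p_only + theta)

scores = tversky_similarity(torch.randn(32, 64), torch.randn(10, 64), torch.randn(128, 64))
print(scores.shape)  # torch.Size([32, 10])
```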
## 📦 Installation
### From PyPI (Recommended)
```bash
pip install tversky-nn
```
### From Source
```bash
git clone https://github.com/akshathmangudi/tnn.git
cd tnn
pip install -e .
```
### Dependencies
- Python 3.10+
- PyTorch 2.0+
- torchvision 0.15+
- numpy
- scikit-learn
- tqdm
- pillow
## 🎯 Quick Start
### Basic Usage
```python
import torch
from tnn.models import get_resnet_model
from tnn.datasets import get_mnist_loaders

# Create a TverskyResNet model
model = get_resnet_model(
    architecture='resnet18',
    num_classes=10,
    use_tversky=True,
    num_prototypes=8,
    alpha=0.5,
    beta=0.5
)

# Load MNIST dataset
train_loader, val_loader, test_loader = get_mnist_loaders(
    data_dir='./data',
    batch_size=64
)

# Use the model
x = torch.randn(32, 3, 224, 224)  # Batch of images
outputs = model(x)  # Shape: (32, 10)
```
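To go from here to a trained model, the usual PyTorch loop applies. A minimal sketch, assuming the `model` and loaders created above (the loss, optimizer, and omitted device handling are ordinary choices, not prescribed by the package):

```python
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()  # the Tversky head emits logits (apply_softmax=False)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

model.train()
for epoch in range(2):
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
```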
### XOR Toy Problem
Demonstrate TNN capabilities on the classic XOR problem:
```python
from tnn.models.xor import TverskyXORNet
import torch

# Create XOR model
model = TverskyXORNet(
    hidden_dim=8,
    num_prototypes=4,
    alpha=0.5,
    beta=0.5
)

# XOR data
x = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = torch.tensor([0, 1, 1, 0])

# Forward pass
predictions = model(x)
```
## 🏃‍♂️ Training Models
### MNIST Classification
Train a TverskyResNet on MNIST:
```bash
# Train with Tversky layer (recommended)
python train_resnet.py --dataset mnist --architecture resnet18 --epochs 50 --lr 0.01
# Train baseline (linear layer)
python train_resnet.py --dataset mnist --architecture resnet18 --use-linear --epochs 50 --lr 0.01
# Quick test (2 epochs)
python train_resnet.py --dataset mnist --epochs 2 --lr 0.01
```
### XOR Toy Problem
```bash
python train_xor.py
```
### Advanced Training Options
TO BE UPDATED
## 📊 Results
Our implementation achieves strong performance across different tasks:
### MNIST Classification Results
| Configuration | Architecture | Classifier | Val Accuracy | Train Accuracy | Training Time |
|---------------|-------------|------------|--------------|----------------|---------------|
| **Optimized TNN** | ResNet18 | Tversky (8 prototypes) | **98.88%** | **98.81%** | ~32 min (2 epochs) |
| Baseline | ResNet18 | Linear | - | - | - |
**Key Training Metrics:**
- **Epoch 1:** Training Acc: 89.81%
- **Epoch 2:** Training Acc: 98.81%, Validation Acc: 98.88%
- **Model Size:** 11.18M parameters (4,608 in Tversky classifier)
- **Convergence:** Fast and stable with proper hyperparameters
### XOR Toy Problem Results
| Metric | Value |
|--------|-------|
| **Final Test Accuracy** | **93.00%** |
| Class 0 Accuracy | 95.40% |
| Class 1 Accuracy | 91.15% |
| Training Epochs | 500 |
| Convergence | Smooth, interpretable decision boundary |
**Visual Results:**
- Clear non-linear decision boundary
- Interpretable learned prototypes
- Smooth training curves
## 🔬 Key Features
### ✅ What Works Well
1. **Fast Convergence**: With proper hyperparameters (lr=0.01), TNNs converge quickly
2. **High Accuracy**: Achieves 98.88% validation accuracy on MNIST
3. **Interpretability**: Learned prototypes provide insight into model decisions
4. **Flexibility**: Support for multiple ResNet architectures
5. **Stability**: Robust training with mixed precision and proper initialization
### 🏗️ Architecture Highlights
- **Modular Design**: Easy to swap Tversky layers for linear layers (see the sketch after this list)
- **Multiple Architectures**: ResNet18/50/101/152 support
- **Pretrained Weights**: ImageNet initialization available
- **Mixed Precision**: Automatic mixed precision training
- **Comprehensive Logging**: Detailed metrics and checkpointing
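As a sketch of that swap, the factory from Quick Start exposes a `use_tversky` flag; assuming `use_tversky=False` selects the linear baseline (mirroring the `--use-linear` training flag), the two variants share a backbone and differ only in the classifier head:

```python
from tnn.models import get_resnet_model

# Same ResNet18 backbone, two heads. Treating use_tversky=False as the
# linear baseline is an assumption mirroring the --use-linear CLI flag.
tversky_model = get_resnet_model(architecture='resnet18', num_classes=10, use_tversky=True)
linear_model = get_resnet_model(architecture='resnet18', num_classes=10, use_tversky=False)
```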
### 🎛️ Configurable Hyperparameters
```python
# Tversky similarity parameters
alpha: float = 0.5       # Weight on features of x missing from the prototype (|x \ π_k|_Ω)
beta: float = 0.5        # Weight on prototype features missing from x (|π_k \ x|_Ω)
num_prototypes: int = 8  # Number of learned prototypes
theta: float = 1e-7      # Numerical stability constant

# Architecture options
intersection_reduction = "product"      # or "mean"
difference_reduction = "subtractmatch"  # or "ignorematch"
feature_bank_init = "xavier"            # Feature bank initialization
prototype_init = "xavier"               # Prototype initialization
```
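Because `alpha` and `beta` weight the two set differences independently, the index is asymmetric whenever α ≠ β: S(x, π) ≠ S(π, x). A small standalone example, using the same membership reading as the earlier sketch:

```python
import torch

a = torch.tensor([1., 1., 0.])  # membership profile of x
b = torch.tensor([1., 0., 0.])  # membership profile of a prototype
common = torch.minimum(a, b).sum()  # 1.0
x_only = torch.relu(a - b).sum()    # 1.0
p_only = torch.relu(b - a).sum()    # 0.0

alpha, beta, theta = 0.9, 0.1, 1e-7
s_xp = common / (common + alpha * x_only + beta * p_only + theta)  # ≈ 0.526
s_px = common / (common + alpha * p_only + beta * x_only + theta)  # ≈ 0.909
```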
## 🚧 Current Limitations & Future Work
### Known Issues Resolved ✅
- **Double Classification Layer**: Removed a duplicated classification layer that was causing convergence issues
- **Softmax Placement**: Set `apply_softmax=False` so the Tversky layer outputs logits rather than probabilities
- **Learning Rate**: Raised the default learning rate from 0.001 to 0.01
- **Initialization**: Improved prototype and feature bank initialization
### Future Enhancements 🔮
1. **Extended Datasets**: Support for CIFAR-10/100, ImageNet
2. **Additional Architectures**: Vision Transformers, EfficientNets
3. **Advanced Features**:
   - Prototype visualization tools
   - Attention mechanisms
   - Multi-modal support
4. **Optimization**:
   - Further convergence improvements
   - Memory optimization for large models
5. **Research Extensions**:
   - Adaptive α, β parameters
   - Hierarchical prototypes
   - Ensemble methods
## 📈 Performance Optimizations Applied
Our implementation includes several key optimizations discovered during development:
1. **Architectural Fixes**:
   - Removed the double classification layer that was impeding gradient flow
   - Set `apply_softmax=False` in the Tversky layer for better optimization
   - Improved linear layer initialization with Xavier uniform
2. **Training Optimizations**:
   - Raised the learning rate to 0.01 for faster convergence
   - Mixed precision training for memory efficiency
   - Cosine annealing scheduler for better convergence
3. **Numerical Stability**:
   - Proper theta parameter (1e-7) for numerical stability
   - Xavier initialization for all learnable parameters
   - Gradient clipping and proper loss scaling (see the sketch below)
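The mixed-precision and clipping recipe is the standard `torch.cuda.amp` pattern; a hedged sketch of one optimizer step (the `max_norm` value here is illustrative, not necessarily the training scripts' setting):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for images, labels in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(images), labels)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
```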
## 🤝 Contributing
We welcome contributions! Areas where help is needed:
- Additional dataset implementations
- New architecture support
- Performance optimizations
- Documentation improvements
- Bug fixes and testing
## To add:
- [ ] Include GPT-2 implementation and benchmarks.
- [ ] Run ResNet18 benchmarks on NABirds Dataset.
- [ ] Add benchmarks for different datasets for different weight distributions.
- [ ] Unify training configuration instead of keeping several training files for different models.
- [ ] Include type checking and other software development process standards to maintain robustness.
## 📝 Citation
If you use this implementation in your research, please cite:
```bibtex
@software{tnn_pytorch,
  author = {Akshath Mangudi},
  title = {TNN: A PyTorch Implementation of Tversky Neural Networks},
  year = {2025},
  url = {https://github.com/akshathmangudi/tnn}
}
```
For the original Tversky Neural Networks paper, please cite:
```bibtex
@article{tversky_neural_networks,
  title={Tversky Neural Networks},
  author={[Original Authors]},
  journal={[Journal Name]},
  year={[Year]},
  url={[Paper URL]}
}
```
## 📄 License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## 🙏 Acknowledgments
- Original Tversky Neural Networks paper authors
- PyTorch team for the excellent deep learning framework
- torchvision for pretrained models and datasets
---
**Built with ❤️ and PyTorch** | **Ready for production use** | **Optimized for research**