pytorch-frame


* **Name:** pytorch-frame
* **Version:** 0.2.3
* **Summary:** Tabular Deep Learning Library for PyTorch
* **Upload time:** 2024-07-08 22:04:33
* **Requires Python:** >=3.8
* **Keywords:** deep-learning, pytorch, tabular-learning, data-frame
[testing-image]: https://github.com/pyg-team/pytorch-frame/actions/workflows/testing.yml/badge.svg
[testing-url]: https://github.com/pyg-team/pytorch-frame/actions/workflows/testing.yml
[contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat
[contributing-url]: https://github.com/pyg-team/pytorch-frame/blob/master/.github/CONTRIBUTING.md
[slack-image]: https://img.shields.io/badge/slack-pyf-brightgreen
[slack-url]: https://data.pyg.org/slack.html
[pypi-image]: https://badge.fury.io/py/pytorch-frame.svg
[pypi-url]: https://pypi.python.org/pypi/pytorch-frame
[docs-image]: https://readthedocs.org/projects/pytorch-frame/badge/?version=latest
[docs-url]: https://pytorch-frame.readthedocs.io/en/latest

<div align="center">

<img height="175" src="https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/master/pyg_sphinx_theme/static/img/pytorch_frame_logo_text.png?sanitize=true" />

<br>
<br>

**A modular deep learning framework for building neural network models on heterogeneous tabular data.**

--------------------------------------------------------------------------------

[![PyPI Version][pypi-image]][pypi-url]
[![Testing Status][testing-image]][testing-url]
[![Docs Status][docs-image]][docs-url]
[![Contributing][contributing-image]][contributing-url]
[![Slack][slack-image]][slack-url]

</div>

**[Documentation](https://pytorch-frame.readthedocs.io)** | **[Paper](https://arxiv.org/abs/2404.00776)**

PyTorch Frame is a deep learning extension for [PyTorch](https://pytorch.org/), designed for heterogeneous tabular data with different column types, including numerical, categorical, time, text, and images. It offers a modular framework for implementing existing and future methods. The library provides implementations of state-of-the-art models, user-friendly mini-batch loaders, benchmark datasets, and interfaces for integrating custom data.

PyTorch Frame democratizes deep learning research for tabular data, catering to novices and experts alike. Our goals are:

1. **Facilitate Deep Learning for Tabular Data:** Historically, tree-based models (e.g., GBDTs) have excelled at tabular learning but have notable limitations, such as difficulty integrating with downstream models and handling complex column types like text, sequences, and embeddings. Deep tabular models promise to resolve these limitations. We aim to facilitate deep learning research on tabular data by modularizing its implementation and supporting diverse column types.

2. **Integrate with Diverse Model Architectures like Large Language Models:** PyTorch Frame supports integration with a variety of architectures, including LLMs. With any downloaded model or embedding API endpoint, you can encode your text data into embeddings and train deep learning models on them alongside other complex semantic types. Supported options include (but are not limited to):
<table>
  <tr>
    <td align="center">
      <a href="https://platform.openai.com/docs/guides/embeddings">
        <img src="docs/source/_figures/OpenAI_Logo.png" alt="OpenAI" width="100px"/>
      </a>
      <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">OpenAI Embedding Code Example</a>
    </td>
    <td align="center">
      <a href="https://cohere.com/embeddings">
        <img src="docs/source/_figures/cohere-logo.png" alt="Cohere" width="100px"/>
      </a>
      <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">Cohere Embed v3 Code Example</a>
    </td>
    <td align="center">
      <a href="https://huggingface.co/">
        <img src="docs/source/_figures/hf-logo-with-title.png" alt="Hugging Face" width="100px"/>
      </a>
      <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/transformers_text.py">Hugging Face Code Example</a>
    </td>
      <td align="center">
      <a href="https://www.voyageai.com/">
        <img src="docs/source/_figures/voyageai.webp" alt="Voyage AI" width="100px"/>
      </a>
      <br /><a href="https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py">Voyage AI Code Example</a>
    </td>
  </tr>
</table>
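
Text columns reach the model as embeddings produced by a user-supplied callable. The sketch below assumes that callable maps a list of strings to a 2-D float tensor of shape `[num_texts, embedding_dim]`, which is the contract the linked examples rely on. `HashingTextEmbedder` is hypothetical: it swaps in a bag-of-words hashing trick for a real embedding API so the example runs offline; in practice you would call OpenAI, Cohere, Voyage AI, or a Hugging Face model inside `__call__`.

```python
import zlib
from typing import List

import torch
from torch import Tensor


class HashingTextEmbedder:
    """Hypothetical offline stand-in for an embedding API.

    Each whitespace token is hashed into one of `dim` buckets; the bucket
    counts are then L2-normalized per sentence.
    """

    def __init__(self, dim: int = 64):
        self.dim = dim

    def _bucket(self, token: str) -> int:
        # crc32 is stable across processes, unlike Python's salted hash().
        return zlib.crc32(token.encode("utf-8")) % self.dim

    def __call__(self, sentences: List[str]) -> Tensor:
        out = torch.zeros(len(sentences), self.dim)
        for i, sentence in enumerate(sentences):
            for token in sentence.lower().split():
                out[i, self._bucket(token)] += 1.0
        # Clamp avoids division by zero for empty strings (row stays zero).
        norms = out.norm(dim=1, keepdim=True).clamp(min=1e-12)
        return out / norms


embedder = HashingTextEmbedder(dim=64)
emb = embedder(["Crisp apple and citrus notes.", "Bold tannins, dark fruit."])
print(emb.shape)  # torch.Size([2, 64])
```

Assuming the `TextEmbedderConfig` wrapper used in the linked examples, such a callable would be passed per text column, e.g. `TextEmbedderConfig(text_embedder=embedder, batch_size=8)`; check those examples for the exact wiring.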

<hr style="border: 0.5px solid #ccc;">

* [Library Highlights](#library-highlights)
* [Architecture Overview](#architecture-overview)
* [Quick Tour](#quick-tour)
* [Implemented Deep Tabular Models](#implemented-deep-tabular-models)
* [Benchmark](#benchmark)
* [Installation](#installation)

## Library Highlights

PyTorch Frame builds directly upon PyTorch, ensuring a smooth transition for existing PyTorch users. Key features include:

* **Diverse column types**:
  PyTorch Frame supports learning across various column types: `numerical`, `categorical`, `multicategorical`, `text_embedded`, `text_tokenized`, `timestamp`, `image_embedded`, and `embedding`. See the [tutorial on handling heterogeneous stypes](https://pytorch-frame.readthedocs.io/en/latest/handling_advanced_stypes/handle_heterogeneous_stypes.html) for details.
* **Modular model design**:
  Enables modular deep learning model implementations, promoting reusability, clear code, and flexible experimentation. See the [architecture overview](#architecture-overview) for details.
* **Models**:
  Implements many [state-of-the-art deep tabular models](#implemented-deep-tabular-models) as well as strong GBDTs (XGBoost, CatBoost, and LightGBM) with hyperparameter tuning.
* **Datasets**:
  Comes with a collection of ready-to-use tabular datasets and also supports custom datasets for your own problems.
  We [benchmark](https://github.com/pyg-team/pytorch-frame/blob/master/benchmark) deep tabular models against GBDTs.
* **PyTorch integration**:
  Integrates effortlessly with other PyTorch libraries, facilitating end-to-end training of PyTorch Frame with downstream PyTorch models. For example, by integrating with [PyG](https://pyg.org/), a PyTorch library for GNNs, we can perform deep learning over relational databases. Learn more in [RelBench](https://relbench.stanford.edu/) and [example code (WIP)](https://github.com/snap-stanford/relbench/blob/main/examples/gnn.py).

## Architecture Overview

Models in PyTorch Frame follow a modular design of `FeatureEncoder`, `TableConv`, and `Decoder`, as shown in the figure below:

<p align="center">
  <img width="50%" src="https://raw.githubusercontent.com/pyg-team/pytorch-frame/master/docs/source/_figures/architecture.png" />
</p>

This modular setup makes it easy to experiment with a wide range of architectures:

* `Materialization` converts the raw pandas `DataFrame` into a `TensorFrame` that is amenable to PyTorch-based training and modeling.
* `FeatureEncoder` encodes `TensorFrame` into hidden column embeddings of size `[batch_size, num_cols, channels]`.
* `TableConv` models column-wise interactions over the hidden embeddings.
* `Decoder` generates an embedding or prediction for each row.
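
To make the tensor shapes of this pipeline concrete, here is a sketch in plain PyTorch that mimics the three stages. It is not PyTorch Frame code: a per-column affine map stands in for a `FeatureEncoder` on numerical columns, `torch.nn.TransformerEncoderLayer` stands in for a `TableConv`, and mean pooling followed by a `Linear` layer plays the `Decoder`.

```python
import torch
from torch import nn

batch_size, num_cols, channels, out_channels = 16, 5, 32, 3

# Stand-in for materialized numerical features: one scalar per cell.
x_num = torch.randn(batch_size, num_cols)

# "FeatureEncoder": a separate affine map per column lifts each scalar to a
# `channels`-dimensional embedding -> [batch_size, num_cols, channels].
weight = nn.Parameter(torch.randn(num_cols, channels))
bias = nn.Parameter(torch.zeros(num_cols, channels))
x = x_num.unsqueeze(-1) * weight + bias

# "TableConv": self-attention across the column axis; the shape is preserved.
conv = nn.TransformerEncoderLayer(d_model=channels, nhead=4, batch_first=True)
x = conv(x)

# "Decoder": pool over columns, then map to one prediction vector per row.
decoder = nn.Linear(channels, out_channels)
out = decoder(x.mean(dim=1))

print(x.shape, out.shape)  # torch.Size([16, 5, 32]) torch.Size([16, 3])
```

Because each stage only agrees on the `[batch_size, num_cols, channels]` interface, any one of them can be swapped without touching the others.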


## Quick Tour

In this quick tour, we showcase the ease of creating and training a deep tabular model with only a few lines of code.

### Build and train your own deep tabular model

As an example, we implement a simple `ExampleTransformer` following the modular architecture of PyTorch Frame.
In the example below:
* `self.encoder` maps an input `TensorFrame` to an embedding of size `[batch_size, num_cols, channels]`.
* `self.convs` iteratively transforms the embedding of size `[batch_size, num_cols, channels]` into an embedding of the same size.
* `self.decoder` maps the column-wise mean of that embedding, of size `[batch_size, channels]`, to an output of size `[batch_size, out_channels]`.

```python
from torch import Tensor
from torch.nn import Linear, Module, ModuleList

from torch_frame import TensorFrame, stype
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder import (
    EmbeddingEncoder,
    LinearEncoder,
    StypeWiseFeatureEncoder,
)

class ExampleTransformer(Module):
    def __init__(
        self,
        channels, out_channels, num_layers, num_heads,
        col_stats, col_names_dict,
    ):
        super().__init__()
        self.encoder = StypeWiseFeatureEncoder(
            out_channels=channels,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict={
                stype.categorical: EmbeddingEncoder(),
                stype.numerical: LinearEncoder()
            },
        )
        self.convs = ModuleList([
            TabTransformerConv(
                channels=channels,
                num_heads=num_heads,
            ) for _ in range(num_layers)
        ])
        self.decoder = Linear(channels, out_channels)

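    def forward(self, tf: TensorFrame) -> Tensor:
        # Encode every column: [batch_size, num_cols, channels].
        # The second return value is not needed here.
        x, _ = self.encoder(tf)
        # Model column interactions; the shape is preserved.
        for conv in self.convs:
            x = conv(x)
        # Mean-pool over columns, then project to [batch_size, out_channels].
        out = self.decoder(x.mean(dim=1))
        return out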
```

To prepare the data, we can quickly instantiate a pre-defined dataset and create a
PyTorch-compatible data loader as follows:

```python
from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader

dataset = Yandex(root='/tmp/adult', name='adult')
dataset.materialize()
train_dataset = dataset[:0.8]
train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128,
                          shuffle=True)
```

Then, we just follow the <a href="https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html#full-implementation">standard PyTorch training procedure</a> to optimize the
model parameters. That's it!

```python
import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExampleTransformer(
    channels=32,
    out_channels=dataset.num_classes,
    num_layers=2,
    num_heads=8,
    col_stats=train_dataset.col_stats,
    col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(device)

optimizer = torch.optim.Adam(model.parameters())

model.train()
for epoch in range(50):
    for tf in train_loader:
        tf = tf.to(device)
        pred = model(tf)
        loss = F.cross_entropy(pred, tf.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # apply the parameter update
```
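
The loop above only trains. A minimal evaluation sketch, assuming (as in the training loop) that each batch supports `.to(device)` and exposes its labels as `.y`; `evaluate_accuracy` is a hypothetical helper, not part of PyTorch Frame.

```python
import torch
from torch import nn


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, loader, device: torch.device) -> float:
    """Fraction of rows whose arg-max prediction matches the label."""
    was_training = model.training
    model.eval()  # disable dropout and similar train-only behavior
    correct, total = 0, 0
    for batch in loader:
        batch = batch.to(device)
        pred = model(batch).argmax(dim=-1)
        correct += int((pred == batch.y).sum())
        total += batch.y.numel()
    model.train(was_training)  # restore the caller's mode
    return correct / max(total, 1)
```

You would call it with a loader over a held-out split, e.g. `evaluate_accuracy(model, test_loader, device)`, where `test_loader` is built the same way as `train_loader`.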

## Implemented Deep Tabular Models

PyTorch Frame currently implements the following deep tabular models:

* **[Trompt](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.Trompt.html)** from Chen *et al.*: [Trompt: Towards a Better Deep Neural Network for Tabular Data](https://arxiv.org/abs/2305.18446) (ICML 2023) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/trompt.py)]
* **[FTTransformer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.FTTransformer.html)** from Gorishniy *et al.*: [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959) (NeurIPS 2021) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/revisiting.py)]
* **[ResNet](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.ResNet.html)** from Gorishniy *et al.*: [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959) (NeurIPS 2021) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/revisiting.py)]
* **[TabNet](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.TabNet.html)** from Arık *et al.*: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) (AAAI 2021) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tabnet.py)]
* **[ExcelFormer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.ExcelFormer.html)** from Chen *et al.*: [ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data](https://arxiv.org/abs/2301.02819) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/excelformer.py)]
* **[TabTransformer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.TabTransformer.html)** from Huang *et al.*: [TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/abs/2012.06678) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tab_transformer.py)]

In addition, we provide `XGBoost`, `CatBoost`, and `LightGBM` [examples](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tuned_gbdt.py) with hyperparameter tuning using [Optuna](https://optuna.org/) for users who want to compare their models against GBDTs.


## Benchmark

We benchmark recent tabular deep learning models against GBDTs over diverse public datasets with different sizes and task types.

The following table shows the performance of various models on 13 small regression datasets, with one row per model and one column per dataset index. Lower scores are better, bold marks the best result on each dataset, and OOM indicates that the model ran out of memory. For results on classification and larger datasets, see the [benchmark documentation](https://github.com/pyg-team/pytorch-frame/blob/master/benchmark).

| Model Name          | dataset_0       | dataset_1       | dataset_2       | dataset_3       | dataset_4       | dataset_5       | dataset_6       | dataset_7       | dataset_8       | dataset_9       | dataset_10      | dataset_11      | dataset_12      |
|:--------------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|:----------------|
| XGBoost             | **0.247±0.000** | 0.077±0.000     | 0.167±0.000     | 1.119±0.000     | 0.328±0.000     | 1.024±0.000     | **0.292±0.000** | 0.606±0.000     | **0.876±0.000** | 0.023±0.000     | **0.697±0.000** | 0.865±0.000     | 0.435±0.000     |
| CatBoost            | 0.265±0.000     | 0.062±0.000     | 0.128±0.000     | 0.336±0.000     | 0.346±0.000     | 0.443±0.000     | 0.375±0.000     | 0.273±0.000     | 0.881±0.000     | 0.040±0.000     | 0.756±0.000     | 0.876±0.000     | 0.439±0.000     |
| LightGBM            | 0.253±0.000     | 0.054±0.000     | **0.112±0.000** | 0.302±0.000     | 0.325±0.000     | **0.384±0.000** | 0.295±0.000     | **0.272±0.000** | 0.877±0.000     | 0.011±0.000     | 0.702±0.000     | **0.863±0.000** | **0.395±0.000** |
| Trompt              | 0.261±0.003     | **0.015±0.005** | 0.118±0.001     | **0.262±0.001** | **0.323±0.001** | 0.418±0.003     | 0.329±0.009     | 0.312±0.002     | OOM             | **0.008±0.001** | 0.779±0.006     | 0.874±0.004     | 0.424±0.005     |
| ResNet              | 0.288±0.006     | 0.018±0.003     | 0.124±0.001     | 0.268±0.001     | 0.335±0.001     | 0.434±0.004     | 0.325±0.012     | 0.324±0.004     | 0.895±0.005     | 0.036±0.002     | 0.794±0.006     | 0.875±0.004     | 0.468±0.004     |
| FTTransformerBucket | 0.325±0.008     | 0.096±0.005     | 0.360±0.354     | 0.284±0.005     | 0.342±0.004     | 0.441±0.003     | 0.345±0.007     | 0.339±0.003     | OOM             | 0.105±0.011     | 0.807±0.010     | 0.885±0.008     | 0.468±0.006     |
| ExcelFormer         | 0.302±0.003     | 0.099±0.003     | 0.145±0.003     | 0.382±0.011     | 0.344±0.002     | 0.411±0.005     | 0.359±0.016     | 0.336±0.008     | OOM             | 0.192±0.014     | 0.794±0.005     | 0.890±0.003     | 0.445±0.005     |
| FTTransformer       | 0.335±0.010     | 0.161±0.022     | 0.140±0.002     | 0.277±0.004     | 0.335±0.003     | 0.445±0.003     | 0.361±0.018     | 0.345±0.005     | OOM             | 0.106±0.012     | 0.826±0.005     | 0.896±0.007     | 0.461±0.003     |
| TabNet              | 0.279±0.003     | 0.224±0.016     | 0.141±0.010     | 0.275±0.002     | 0.348±0.003     | 0.451±0.007     | 0.355±0.030     | 0.332±0.004     | 0.992±0.182     | 0.015±0.002     | 0.805±0.014     | 0.885±0.013     | 0.544±0.011     |
| TabTransformer      | 0.624±0.003     | 0.229±0.003     | 0.369±0.005     | 0.340±0.004     | 0.388±0.002     | 0.539±0.003     | 0.619±0.005     | 0.351±0.001     | 0.893±0.005     | 0.431±0.001     | 0.819±0.002     | 0.886±0.005     | 0.545±0.004     |


We see that some recent deep tabular models achieve performance competitive with strong GBDTs, despite being 5 to 100 times slower to train. Making deep tabular models even more performant with less compute is a fruitful direction for future research.
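
To make the comparison concrete, the sketch below tallies, for each of the 13 datasets, which model attains the lowest mean score. The values are transcribed from the regression table above, with `None` standing in for OOM.

```python
from collections import Counter
from typing import Dict, List, Optional

# Mean scores from the regression table (lower is better); None = OOM.
SCORES: Dict[str, List[Optional[float]]] = {
    "XGBoost":             [0.247, 0.077, 0.167, 1.119, 0.328, 1.024, 0.292, 0.606, 0.876, 0.023, 0.697, 0.865, 0.435],
    "CatBoost":            [0.265, 0.062, 0.128, 0.336, 0.346, 0.443, 0.375, 0.273, 0.881, 0.040, 0.756, 0.876, 0.439],
    "LightGBM":            [0.253, 0.054, 0.112, 0.302, 0.325, 0.384, 0.295, 0.272, 0.877, 0.011, 0.702, 0.863, 0.395],
    "Trompt":              [0.261, 0.015, 0.118, 0.262, 0.323, 0.418, 0.329, 0.312, None, 0.008, 0.779, 0.874, 0.424],
    "ResNet":              [0.288, 0.018, 0.124, 0.268, 0.335, 0.434, 0.325, 0.324, 0.895, 0.036, 0.794, 0.875, 0.468],
    "FTTransformerBucket": [0.325, 0.096, 0.360, 0.284, 0.342, 0.441, 0.345, 0.339, None, 0.105, 0.807, 0.885, 0.468],
    "ExcelFormer":         [0.302, 0.099, 0.145, 0.382, 0.344, 0.411, 0.359, 0.336, None, 0.192, 0.794, 0.890, 0.445],
    "FTTransformer":       [0.335, 0.161, 0.140, 0.277, 0.335, 0.445, 0.361, 0.345, None, 0.106, 0.826, 0.896, 0.461],
    "TabNet":              [0.279, 0.224, 0.141, 0.275, 0.348, 0.451, 0.355, 0.332, 0.992, 0.015, 0.805, 0.885, 0.544],
    "TabTransformer":      [0.624, 0.229, 0.369, 0.340, 0.388, 0.539, 0.619, 0.351, 0.893, 0.431, 0.819, 0.886, 0.545],
}


def winners(scores: Dict[str, List[Optional[float]]]) -> List[str]:
    """Return the best (lowest-scoring) model for each dataset column."""
    num_datasets = len(next(iter(scores.values())))
    result = []
    for j in range(num_datasets):
        # Ignore OOM entries when picking the best model on dataset j.
        candidates = {m: s[j] for m, s in scores.items() if s[j] is not None}
        result.append(min(candidates, key=candidates.get))
    return result


wins = Counter(winners(SCORES))
print(dict(sorted(wins.items())))  # {'LightGBM': 5, 'Trompt': 4, 'XGBoost': 4}
```

By this count the GBDTs win 9 of the 13 datasets and Trompt wins the remaining 4.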

We also benchmark different text encoders on a real-world tabular dataset ([Wine Reviews](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.datasets.MultimodalTextBenchmark.html#torch_frame.datasets.MultimodalTextBenchmark)) with one text column. The following table reports test accuracy:

| Test Acc   | Method          | Model Name                                                 | Source        |
|:-----------|:----------------|:-----------------------------------------------------------|:--------------|
| 0.7926     | Pre-trained     | sentence-transformers/all-distilroberta-v1 (125M params)   | Hugging Face  |
| 0.7998     | Pre-trained     | embed-english-v3.0 (dimension size: 1024)                  | Cohere        |
| 0.8102     | Pre-trained     | text-embedding-ada-002 (dimension size: 1536)              | OpenAI        |
| 0.8147     | Pre-trained     | voyage-01 (dimension size: 1024)                           | Voyage AI     |
| 0.8203     | Pre-trained     | intfloat/e5-mistral-7b-instruct (7B params)                | Hugging Face  |
| **0.8230** | LoRA Finetune   | DistilBERT (66M params)                                    | Hugging Face  |

The benchmark script for the Hugging Face text encoders is [`examples/transformers_text.py`](https://github.com/pyg-team/pytorch-frame/blob/master/examples/transformers_text.py), and the script for the remaining text encoders is [`examples/llm_embedding.py`](https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py).

## Installation

PyTorch Frame is available for Python 3.8 to Python 3.11.

```
pip install pytorch_frame
```

See [the installation guide](https://pytorch-frame.readthedocs.io/en/latest/get_started/installation.html) for other options.

## Cite

If you use PyTorch Frame in your work, please cite our paper (BibTeX below).
```
@article{hu2024pytorch,
  title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning},
  author={Hu, Weihua and Yuan, Yiwen and Zhang, Zecheng and Nitta, Akihiro and Cao, Kaidi and Kocijan, Vid and Leskovec, Jure and Fey, Matthias},
  journal={arXiv preprint arXiv:2404.00776},
  year={2024}
}
```


            

Package author: PyG Team <team@pyg.org>. Source distribution: [pytorch_frame-0.2.3.tar.gz](https://files.pythonhosted.org/packages/7f/5c/3c2225a2391665d0d91a445c1f34a6cce61c4667e35d73a926141cc033f9/pytorch_frame-0.2.3.tar.gz).
| 0.288\u00b10.006     | 0.018\u00b10.003     | 0.124\u00b10.001     | 0.268\u00b10.001     | 0.335\u00b10.001     | 0.434\u00b10.004     | 0.325\u00b10.012     | 0.324\u00b10.004     | 0.895\u00b10.005     | 0.036\u00b10.002     | 0.794\u00b10.006     | 0.875\u00b10.004     | 0.468\u00b10.004     |\n| FTTransformerBucket | 0.325\u00b10.008     | 0.096\u00b10.005     | 0.360\u00b10.354     | 0.284\u00b10.005     | 0.342\u00b10.004     | 0.441\u00b10.003     | 0.345\u00b10.007     | 0.339\u00b10.003     | OOM             | 0.105\u00b10.011     | 0.807\u00b10.010     | 0.885\u00b10.008     | 0.468\u00b10.006     |\n| ExcelFormer         | 0.302\u00b10.003     | 0.099\u00b10.003     | 0.145\u00b10.003     | 0.382\u00b10.011     | 0.344\u00b10.002     | 0.411\u00b10.005     | 0.359\u00b10.016     | 0.336\u00b10.008     | OOM             | 0.192\u00b10.014     | 0.794\u00b10.005     | 0.890\u00b10.003     | 0.445\u00b10.005     |\n| FTTransformer       | 0.335\u00b10.010     | 0.161\u00b10.022     | 0.140\u00b10.002     | 0.277\u00b10.004     | 0.335\u00b10.003     | 0.445\u00b10.003     | 0.361\u00b10.018     | 0.345\u00b10.005     | OOM             | 0.106\u00b10.012     | 0.826\u00b10.005     | 0.896\u00b10.007     | 0.461\u00b10.003     |\n| TabNet              | 0.279\u00b10.003     | 0.224\u00b10.016     | 0.141\u00b10.010     | 0.275\u00b10.002     | 0.348\u00b10.003     | 0.451\u00b10.007     | 0.355\u00b10.030     | 0.332\u00b10.004     | 0.992\u00b10.182     | 0.015\u00b10.002     | 0.805\u00b10.014     | 0.885\u00b10.013     | 0.544\u00b10.011     |\n| TabTransformer      | 0.624\u00b10.003     | 0.229\u00b10.003     | 0.369\u00b10.005     | 0.340\u00b10.004     | 0.388\u00b10.002     | 0.539\u00b10.003     | 0.619\u00b10.005     | 0.351\u00b10.001     | 0.893\u00b10.005     | 0.431\u00b10.001     | 0.819\u00b10.002     | 0.886\u00b10.005     | 0.545\u00b10.004     |\n\n\nWe see that some recent deep tabular models were able to achieve competitive model 
performance relative to strong GBDTs, despite being 5 to 100 times slower to train. Making deep tabular models perform even better with less compute is a fruitful direction for future research.\n\nWe also benchmark different text encoders on a real-world tabular dataset with one text column ([Wine Reviews](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.datasets.MultimodalTextBenchmark.html#torch_frame.datasets.MultimodalTextBenchmark)). The following table shows the test accuracy obtained with each encoder:\n\n| Test Acc   | Method          | Model Name                                                 | Source        |\n|:-----------|:----------------|:-----------------------------------------------------------|:--------------|\n| 0.7926     | Pre-trained     | sentence-transformers/all-distilroberta-v1 (125M params)   | Hugging Face  |\n| 0.7998     | Pre-trained     | embed-english-v3.0 (dimension size: 1024)                  | Cohere        |\n| 0.8102     | Pre-trained     | text-embedding-ada-002 (dimension size: 1536)              | OpenAI        |\n| 0.8147     | Pre-trained     | voyage-01 (dimension size: 1024)                           | Voyage AI     |\n| 0.8203     | Pre-trained     | intfloat/e5-mistral-7b-instruct (7B params)                | Hugging Face  |\n| **0.8230** | LoRA fine-tuned | DistilBERT (66M params)                                    | Hugging Face  |\n\nThe benchmark script for the Hugging Face text encoders is in this [file](https://github.com/pyg-team/pytorch-frame/blob/master/examples/transformers_text.py), and the script for the remaining text encoders is in this [file](https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py).\n\n## Installation\n\nPyTorch Frame is available for Python 3.8 to Python 3.11.\n\n```\npip install pytorch_frame\n```\n\nSee [the installation guide](https://pytorch-frame.readthedocs.io/en/latest/get_started/installation.html) for other options.\n\n## Cite\n\nIf you use PyTorch Frame in your work, please cite our 
paper (BibTeX below):\n\n```\n@article{hu2024pytorch,\n  title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning},\n  author={Hu, Weihua and Yuan, Yiwen and Zhang, Zecheng and Nitta, Akihiro and Cao, Kaidi and Kocijan, Vid and Leskovec, Jure and Fey, Matthias},\n  journal={arXiv preprint arXiv:2404.00776},\n  year={2024}\n}\n```\n\n",
    "bugtrack_url": null,
    "license": null,
    "summary": "Tabular Deep Learning Library for PyTorch",
    "version": "0.2.3",
    "project_urls": {
        "changelog": "https://github.com/pyg-team/pytorch-frame/blob/master/CHANGELOG.md",
        "documentation": "https://pytorch-frame.readthedocs.io",
        "homepage": "https://pyg.org",
        "repository": "https://github.com/pyg-team/pytorch-frame.git"
    },
    "split_keywords": [
        "deep-learning",
        " pytorch",
        " tabular-learning",
        " data-frame"
    ],
    "urls": [
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "cb252c20d903d1fb65dc96a28a1fbf79998cf7f885e0fbc750d12af29ec16404",
                "md5": "01f7467452f6cb5604aa6d647f3db534",
                "sha256": "ef0d41a92bc5a090a6dd562a4f948406ad4f9656bde1c7262c6fad29f6e5521a"
            },
            "downloads": -1,
            "filename": "pytorch_frame-0.2.3-py3-none-any.whl",
            "has_sig": false,
            "md5_digest": "01f7467452f6cb5604aa6d647f3db534",
            "packagetype": "bdist_wheel",
            "python_version": "py3",
            "requires_python": ">=3.8",
            "size": 143313,
            "upload_time": "2024-07-08T22:04:30",
            "upload_time_iso_8601": "2024-07-08T22:04:30.910527Z",
            "url": "https://files.pythonhosted.org/packages/cb/25/2c20d903d1fb65dc96a28a1fbf79998cf7f885e0fbc750d12af29ec16404/pytorch_frame-0.2.3-py3-none-any.whl",
            "yanked": false,
            "yanked_reason": null
        },
        {
            "comment_text": "",
            "digests": {
                "blake2b_256": "7f5c3c2225a2391665d0d91a445c1f34a6cce61c4667e35d73a926141cc033f9",
                "md5": "5d0e7fb9b9c10d2404c0b9229c33f158",
                "sha256": "33907a25641932751019f696262e3bf3a44e7679aabfc43e55f7243c6aa85c23"
            },
            "downloads": -1,
            "filename": "pytorch_frame-0.2.3.tar.gz",
            "has_sig": false,
            "md5_digest": "5d0e7fb9b9c10d2404c0b9229c33f158",
            "packagetype": "sdist",
            "python_version": "source",
            "requires_python": ">=3.8",
            "size": 102775,
            "upload_time": "2024-07-08T22:04:33",
            "upload_time_iso_8601": "2024-07-08T22:04:33.163977Z",
            "url": "https://files.pythonhosted.org/packages/7f/5c/3c2225a2391665d0d91a445c1f34a6cce61c4667e35d73a926141cc033f9/pytorch_frame-0.2.3.tar.gz",
            "yanked": false,
            "yanked_reason": null
        }
    ],
    "upload_time": "2024-07-08 22:04:33",
    "github": true,
    "gitlab": false,
    "bitbucket": false,
    "codeberg": false,
    "github_user": "pyg-team",
    "github_project": "pytorch-frame",
    "travis_ci": false,
    "coveralls": false,
    "github_actions": true,
    "lcname": "pytorch-frame"
}
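Each entry in the `urls` array above lists `md5`, `sha256`, and `blake2b_256` digests for the released wheel and sdist. The following is a minimal standard-library sketch, not part of PyTorch Frame or any PyPI tooling, for checking a downloaded file against the listed SHA-256 digest; the helper names `sha256_hexdigest` and `verify_artifact` and the local wheel path are hypothetical.

```python
import hashlib
from pathlib import Path
from typing import Union

# SHA-256 of pytorch_frame-0.2.3-py3-none-any.whl, copied from the metadata above.
WHEEL_SHA256 = "ef0d41a92bc5a090a6dd562a4f948406ad4f9656bde1c7262c6fad29f6e5521a"


def sha256_hexdigest(path: Union[str, Path], chunk_size: int = 1 << 20) -> str:
    """Hash a file in fixed-size chunks so large archives need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_artifact(path: Union[str, Path], expected_sha256: str) -> bool:
    """Return True if the file's SHA-256 matches the expected hex digest."""
    return sha256_hexdigest(path) == expected_sha256.strip().lower()


if __name__ == "__main__":
    # Hypothetical location of a previously downloaded wheel.
    wheel = Path("pytorch_frame-0.2.3-py3-none-any.whl")
    if wheel.exists():
        print("digest OK" if verify_artifact(wheel, WHEEL_SHA256) else "digest MISMATCH")
```

Reading in chunks keeps memory use constant regardless of archive size. pip's hash-checking mode (`--require-hashes`, with `--hash=sha256:<digest>` entries in a requirements file) performs the same comparison at install time.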
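The regression benchmark table in the project description bolds the lowest score in each column, so picking the best model per dataset amounts to a column-wise minimum that skips OOM entries. The following is a minimal plain-Python sketch using three columns copied from that table (dataset_0, dataset_1, and dataset_8). The helper `best_per_dataset` is hypothetical and not part of PyTorch Frame or its benchmark scripts.

```python
from typing import Dict, List, Optional

# Mean scores copied from the dataset_0, dataset_1 and dataset_8 columns of
# the regression table; None stands for an OOM (out-of-memory) entry.
SCORES: Dict[str, List[Optional[float]]] = {
    "XGBoost": [0.247, 0.077, 0.876],
    "CatBoost": [0.265, 0.062, 0.881],
    "LightGBM": [0.253, 0.054, 0.877],
    "Trompt": [0.261, 0.015, None],
    "ResNet": [0.288, 0.018, 0.895],
    "FTTransformerBucket": [0.325, 0.096, None],
    "ExcelFormer": [0.302, 0.099, None],
    "FTTransformer": [0.335, 0.161, None],
    "TabNet": [0.279, 0.224, 0.992],
    "TabTransformer": [0.624, 0.229, 0.893],
}


def best_per_dataset(scores: Dict[str, List[Optional[float]]]) -> List[str]:
    """Return, for each dataset column, the model with the lowest score.

    Lower is better in this table; OOM entries (None) are skipped.
    """
    num_datasets = len(next(iter(scores.values())))
    winners = []
    for col in range(num_datasets):
        candidates = [
            (row[col], name) for name, row in scores.items() if row[col] is not None
        ]
        # Tuples compare by score first, so min() picks the lowest score.
        winners.append(min(candidates)[1])
    return winners


print(best_per_dataset(SCORES))  # ['XGBoost', 'Trompt', 'XGBoost']
```

Skipping OOM entries, rather than treating them as infinitely bad, keeps models that did not fit in memory out of the comparison for that dataset only.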