collatable 0.6.0

Name: collatable
Version: 0.6.0
Summary: Constructing batched tensors for any machine learning task
Upload time: 2025-07-19 09:16:54
Author: altescy
Requires Python: <4.0,>=3.8
Keywords: python, machine learning
# Collatable

[![Actions Status](https://github.com/altescy/collatable/workflows/CI/badge.svg)](https://github.com/altescy/collatable/actions/workflows/ci.yml)
[![License](https://img.shields.io/github/license/altescy/collatable)](https://github.com/altescy/collatable/blob/main/LICENSE)
[![Python version](https://img.shields.io/pypi/pyversions/collatable)](https://github.com/altescy/collatable)
[![pypi version](https://img.shields.io/pypi/v/collatable)](https://pypi.org/project/collatable/)

Constructing batched tensors for any machine learning task

## Installation

```bash
pip install collatable
```

## Examples

The following scripts show how to tokenize/index/collate your dataset with `collatable`:

### Text Classification

```python
import collatable
from collatable import LabelField, MetadataField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
    ("this is awesome", "positive"),
    ("this is a bad movie", "negative"),
    ("this movie is an awesome movie", "positive"),
    ("this movie is too bad to watch", "negative"),
]

# Set up indexers for tokens and labels
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
token_indexer = TokenIndexer[str](specials=[PAD_TOKEN, UNK_TOKEN], default=UNK_TOKEN)
label_indexer = LabelIndexer[str]()

# Load training dataset
instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    for id_, (text, label) in enumerate(dataset):
        # Prepare each field with the corresponding field class
        text_field = TextField(
            text.split(),
            indexer=token_indexer,
            padding_value=token_indexer[PAD_TOKEN],
        )
        label_field = LabelField(
            label,
            indexer=label_indexer,
        )
        metadata_field = MetadataField({"id": id_})
        # Combine these fields into instance
        instance = dict(
            text=text_field,
            label=label_field,
            metadata=metadata_field,
        )
        instances.append(instance)

# Collate instances and build batch
output = collatable.collate(instances)
print(output)
```

Execution result:

```text
{'metadata': [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}],
 'text': {
    'token_ids': array([[ 2,  3,  4,  0,  0,  0,  0],
                        [ 2,  3,  5,  6,  7,  0,  0],
                        [ 2,  7,  3,  8,  4,  7,  0],
                        [ 2,  7,  3,  9,  6, 10, 11]]),
    'mask': array([[ True,  True,  True, False, False, False, False],
                   [ True,  True,  True,  True,  True, False, False],
                   [ True,  True,  True,  True,  True,  True, False],
                   [ True,  True,  True,  True,  True,  True,  True]])},
 'label': array([0, 1, 0, 1], dtype=int32)}
```
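As the output above shows, `collate` returns plain NumPy arrays (plus the untouched metadata list), so the batch can be handed to any framework. The snippet below is a minimal sketch, assuming PyTorch is installed, of converting that batch into tensors; it is not part of `collatable` itself:

```python
import torch

# `output` is the collated batch from the example above.
token_ids = torch.from_numpy(output["text"]["token_ids"]).long()
mask = torch.from_numpy(output["text"]["mask"])
labels = torch.from_numpy(output["label"]).long()

# The metadata field stays a plain list of dicts and is not converted.
print(token_ids.shape, mask.shape, labels.shape)
# torch.Size([4, 7]) torch.Size([4, 7]) torch.Size([4])
```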

### Sequence Labeling

```python
import collatable
from collatable import SequenceLabelField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
    (["my", "name", "is", "john", "smith"], ["O", "O", "O", "B", "I"]),
    (["i", "lived", "in", "japan", "three", "years", "ago"], ["O", "O", "O", "U", "O", "O", "O"]),
]

# Set up indexers for tokens and labels
PAD_TOKEN = "<PAD>"
token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,))
label_indexer = LabelIndexer[str]()

# Load training dataset
instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    for tokens, labels in dataset:
        text_field = TextField(tokens, indexer=token_indexer, padding_value=token_indexer[PAD_TOKEN])
        label_field = SequenceLabelField(labels, text_field, indexer=label_indexer)
        instance = dict(text=text_field, label=label_field)
        instances.append(instance)

output = collatable.collate(instances)
print(output)
```

Execution result:

```text
{'label': array([[0, 0, 0, 1, 2, 0, 0],
                 [0, 0, 0, 3, 0, 0, 0]]),
 'text': {
    'token_ids': array([[ 1,  2,  3,  4,  5,  0,  0],
                        [ 6,  7,  8,  9, 10, 11, 12]]),
    'mask': array([[ True,  True,  True,  True,  True, False, False],
                   [ True,  True,  True,  True,  True,  True,  True]])}}
```
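Padded label positions line up with the `False` entries of `mask`, so they should be excluded from loss and metric computation. A NumPy-only sketch of that, continuing from the `output` above (variable names are just illustrative):

```python
# Drop padded positions from each label sequence using the boolean mask.
valid_labels = [
    label_row[mask_row]
    for label_row, mask_row in zip(output["label"], output["text"]["mask"])
]
print(valid_labels)
# [array([0, 0, 0, 1, 2]), array([0, 0, 0, 3, 0, 0, 0])]
```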

### Relation Extraction

```python
import collatable
from collatable.extras.indexer import LabelIndexer, TokenIndexer
from collatable import AdjacencyField, ListField, SpanField, TextField

PAD_TOKEN = "<PAD>"
token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,))
label_indexer = LabelIndexer[str]()

instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    text = TextField(
        ["john", "smith", "was", "born", "in", "new", "york", "and", "now", "lives", "in", "tokyo"],
        indexer=token_indexer,
        padding_value=token_indexer[PAD_TOKEN],
    )
    spans = ListField([SpanField(0, 2, text), SpanField(5, 7, text), SpanField(11, 12, text)])
    relations = AdjacencyField([(0, 1), (0, 2)], spans, labels=["born-in", "lives-in"], indexer=label_indexer)
    instance = dict(text=text, spans=spans, relations=relations)
    instances.append(instance)

    text = TextField(
        ["tokyo", "is", "the", "capital", "of", "japan"],
        indexer=token_indexer,
        padding_value=token_indexer[PAD_TOKEN],
    )
    spans = ListField([SpanField(0, 1, text), SpanField(5, 6, text)])
    relations = AdjacencyField([(0, 1)], spans, labels=["capital-of"], indexer=label_indexer)
    instance = dict(text=text, spans=spans, relations=relations)
    instances.append(instance)

output = collatable.collate(instances)
print(output)
```

Execution result:

```text
{'text': {
    'token_ids': array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  5, 11],
                        [11, 12, 13, 14, 15, 16,  0,  0,  0,  0,  0,  0]]),
    'mask': array([[ True,  True,  True,  True,  True,  True,  True,  True,  True, True,  True,  True],
                   [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False]])},
 'spans': array([[[ 0,  2],
                  [ 5,  7],
                  [11, 12]],
                 [[ 0,  1],
                  [ 5,  6],
                  [-1, -1]]]),
 'relations': array([[[-1,  0,  1],
                      [-1, -1, -1],
                      [-1, -1, -1]],
                     [[-1,  2, -1],
                      [-1, -1, -1],
                      [-1, -1, -1]]], dtype=int32)}
```
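In the collated `relations` array, `-1` marks both "no relation" and padded span slots, while non-negative entries are label indices assigned by `label_indexer` (here `0` = born-in, `1` = lives-in, `2` = capital-of). A small NumPy sketch, with illustrative names, of recovering the labeled span pairs from the batch above:

```python
import numpy as np

for batch_index, adjacency in enumerate(output["relations"]):
    # Non-negative cells hold the label id for a (source span, target span) pair.
    sources, targets = np.nonzero(adjacency >= 0)
    for src, dst in zip(sources, targets):
        print(batch_index, (int(src), int(dst)), int(adjacency[src, dst]))
# 0 (0, 1) 0
# 0 (0, 2) 1
# 1 (0, 1) 2
```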


### Reference Implementation

The `extras` module provides reference implementations that help you use `collatable` effectively.
Here is an example of a text-to-text task that encodes raw texts/labels into token
ids and decodes them back to raw texts/labels:

```python
from dataclasses import dataclass
from typing import Sequence, Union

from collatable.extras import DataLoader, Dataset, DefaultBatchSampler, LabelIndexer, TokenIndexer
from collatable.extras.datamodule import DataModule, LabelFieldTransform, TextFieldTransform
from collatable.utils import debatched


@dataclass
class Text2TextExample:
    source: Union[str, Sequence[str]]
    target: Union[str, Sequence[str]]
    language: str


text2text_dataset = [
    Text2TextExample(source="how are you?", target="I am fine.", language="en"),
    Text2TextExample(source="what is your name?", target="My name is John.", language="en"),
    Text2TextExample(source="where are you?", target="I am in New-York.", language="en"),
    Text2TextExample(source="what is the time?", target="It is 10:00 AM.", language="en"),
    Text2TextExample(source="comment ça va?", target="Je vais bien.", language="fr"),
]

shared_token_indexer = TokenIndexer(default="<unk>", specials=["<pad>", "<unk>"])
language_indexer = LabelIndexer[str]()

text2text_datamodule = DataModule[Text2TextExample](
    fields={
        "source": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
        "target": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
        "language": LabelFieldTransform(indexer=language_indexer),
    }
)

with shared_token_indexer.context(train=True), language_indexer.context(train=True):
    text2text_datamodule.build(text2text_dataset)

dataloader = DataLoader(DefaultBatchSampler(batch_size=2))

text2text_instances = Dataset.from_iterable(text2text_datamodule(text2text_dataset))

for batch in dataloader(text2text_instances):
    print("Batch:")
    print(batch)
    print("Reconstruction:")
    for item in debatched(batch):
        print(text2text_datamodule.reconstruct(item))
    print()
```

Execution result:

```text
Batch:
{'target': {
    'token_ids': array([[16, 17, 18, 19,  0],
                        [20,  9,  7, 21, 19]]),
    'mask': array([[ True,  True,  True,  True, False],
                   [ True,  True,  True,  True,  True]])},
    'language': array([0, 0], dtype=int32),
 'source': {
    'token_ids': array([[2, 3, 4, 5, 0],
                        [6, 7, 8, 9, 5]]),
    'mask': array([[ True,  True,  True,  True, False],
                   [ True,  True,  True,  True,  True]])}}
Reconstruction:
{'source': ['how', 'are', 'you', '?'], 'target': ['I', 'am', 'fine', '.'], 'language': 'en'}
{'source': ['what', 'is', 'your', 'name', '?'], 'target': ['My', 'name', 'is', 'John', '.'], 'language': 'en'}

...
```
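Since the `DataLoader` in this example is simply called on the dataset to produce an iterator of batches, running several epochs is, presumably, just a loop around that call. A hypothetical sketch (the epoch count and loop body are placeholders, not part of `collatable`):

```python
NUM_EPOCHS = 3  # hypothetical value
for epoch in range(NUM_EPOCHS):
    for batch in dataloader(text2text_instances):
        source_ids = batch["source"]["token_ids"]  # NumPy array, shape (batch, seq_len)
        target_ids = batch["target"]["token_ids"]
        ...  # run your model's forward/backward pass here
```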

            
