# SeqIO
*Task-based datasets, preprocessing, and evaluation for sequence models*
*Go to [SeqIO ReadTheDocs Documentation Page](https://seqio.readthedocs.io/).*
## Overview
**SeqIO** is a library for processing sequential data to be fed into downstream
sequence models. It uses
[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
to create scalable data pipelines but requires minimal use of TensorFlow. In
particular, with one line of code, the returned dataset can be transformed to a
numpy iterator and hence it is fully compatible with other frameworks such as
[JAX](https://github.com/google/jax) or
[PyTorch](https://pytorch.org/).
SeqIO assumes that the dataset is a sequence. Modalities such as text or audio
are naturally supported. Images are supported as long as they are represented as
sequences (e.g., [Image GPT](http://proceedings.mlr.press/v119/chen20s.html)).
SeqIO is a refactor of the
[`t5.data`](https://github.com/google-research/text-to-text-transfer-transformer/)
library used (in conjunction with the
[Mesh Tensorflow](https://github.com/tensorflow/mesh) Transformer
implementation) to train the T5 models introduced in [*Exploring the Limits of
Transfer Learning with a Unified Text-to-Text
Transformer*](https://arxiv.org/abs/1910.10683).
If you have used `t5.data` in the past and want to know how SeqIO differs,
please read [this section](#differences-from-t5data).
## Installation
### From PyPI
```sh
pip install seqio
```
### From Source
```sh
git clone https://github.com/google/seqio.git
cd seqio
pip install -e .
```
## Usage Tutorial
At a high level, we use SeqIO with the following steps.
1. Define a `Task` (and optionally a `Mixture`).
1. Define (or use an existing) `FeatureConverter` based on the model
architecture.
1. Use the top-level function `seqio.get_dataset` to obtain the
`tf.data.Dataset` instance.
We will look at each of these steps in detail.
### Defining a `Task`
The most important class in SeqIO is the `Task`. It is an abstraction that
combines:
* a raw *data source*
* one or more *preprocessing* steps
* a *vocabulary* to tokenize/detokenize each preprocessed feature for the
model
* a *postprocessor* to convert detokenized model outputs into a format for
evaluation
* one or more *metrics* to evaluate with
Oftentimes a `Task` lines up with a common benchmark. In this tutorial, we use the
[WMT 19 English-German](http://www.statmt.org/wmt19/translation-task.html)
machine translation task. In the end, our `Task` will look like this:
```py
import functools

import seqio
import tensorflow as tf

# `translate` and `bleu` are defined later in this tutorial.
seqio.TaskRegistry.add(
    "wmt19_ende",
    seqio.TfdsDataSource(tfds_name="wmt19_translate/de-en:1.0.0"),
    preprocessors=[
        functools.partial(
            translate, source_language='en', target_language='de'),
        seqio.preprocessors.tokenize, seqio.preprocessors.append_eos
    ],
    output_features={
        'inputs':
            seqio.Feature(
                seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),
                add_eos=False,
                dtype=tf.int32),
        'targets':
            seqio.Feature(
                seqio.SentencePieceVocabulary('/path/to/targets/vocab'),
                add_eos=True,
                dtype=tf.int32),
    },
    metric_fns=[bleu])
```
We typically add the `Task` to the global registry when we define it (as shown
above) to make it easier to use with model configs and flags. Thus, it must
have a unique string name (`"wmt19_ende"` in this case). Note, however, that
you may also instantiate a `seqio.Task` directly without adding it to the
registry, if desired.
We'll now break down each part of the task definition.
#### Data Source
Data sources are the first step in your pipeline, providing a way to load raw
data in many formats as a `tf.data.Dataset`.
All data sources are subclasses of the `DataSource` base class and are defined
in
[dataset_providers](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py).
Existing implementations include:
* `TfdsDataSource` for loading examples from
[TensorFlow Datasets](https://www.tensorflow.org/datasets).
* `TextLineDataSource` for loading examples from text files (e.g., tsv).
* `TFExampleDataSource` for loading
[`tf.train.Example`](https://www.tensorflow.org/tutorials/load_data/tfrecord)
protos from a file (e.g., a `TFRecord` file).
* `FunctionDataSource` for providing a custom function that returns a
`tf.data.Dataset`.
In our example, we are using the `TfdsDataSource`. We specify the name of the
WMT dataset in TFDS
([`"wmt19_translate"`](https://www.tensorflow.org/datasets/catalog/wmt19_translate)),
the specific config for the language pair that excludes the context for the open
domain setting (`"de-en"`), and the version number (`"1.0.0"`).
#### Output Features
The `output_features` field expects a dictionary that maps string feature names
to `seqio.Feature` objects. This defines what the `Task` is expected to produce
in its output examples. The output examples *may* contain additional fields, but
they *must* contain these fields in the specified format or exceptions will be
raised.
Each `Feature` includes:
* A `vocabulary`, which must subclass
[`seqio.Vocabulary`](https://github.com/google/seqio/tree/main/seqio/vocabularies.py),
to specify how the feature can be tokenized and detokenized. You may use
`seqio.PassThroughVocabulary` if tokenization is not necessary.
* `add_eos`, which specifies whether the feature should end with the
vocabulary's EOS token.
* The output `dtype` which must be a `tf.dtypes.DType`.
**Note:** specifying these options on `Feature` does not by itself ensure the
proper transformations are applied -- you must also include the necessary
preprocessors.
The [tasks used in T5](TODO) all produce "inputs" and "targets" features to be
consumed by the text-to-text model. For a decoder-only language model, only a
single feature (e.g., "targets") would be necessary. Nevertheless, SeqIO is
flexible enough to generate arbitrary output features that will be converted
into model features by the [`FeatureConverter`](#featureconverter) later in the
pipeline.
#### Preprocessors
Preprocessors are functions that transform one `tf.data.Dataset` into a new
`tf.data.Dataset`. Typically this involves executing a `map` over the given
dataset. The preprocessors provided to the `Task` will be executed sequentially.
As an example, let's look at the previously undefined `translate` from the
"wmt19_ende" example above.
```py
from typing import Mapping

import tensorflow as tf


def translate(dataset: tf.data.Dataset,
              source_language: str,
              target_language: str) -> tf.data.Dataset:
    """Convert translation examples to text2text pairs.

    For example, say the dataset returns examples of this format:
      {'de': 'Das ist gut.', 'en': 'That is good.'}
    If source_language = 'de', target_language = 'en', then the outputs will
    have the format:
      {'inputs': 'translate de to en: Das ist gut.',
       'targets': 'That is good.'}

    Args:
      dataset: the dataset of examples to process.
      source_language: source language code (e.g. 'en') to translate from.
      target_language: target language code (e.g. 'de') to translate to.

    Returns:
      A dataset of preprocessed examples with the format listed above.
    """
    def _translate(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
        src_str = f'translate {source_language}'
        tgt_str = f' to {target_language}: '
        return {
            'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),
            'targets': ex[target_language],
        }

    return dataset.map(_translate,
                       num_parallel_calls=tf.data.experimental.AUTOTUNE)
```
Each example in the TFDS dataset has the form `{'de':
'Das ist gut.', 'en': 'That is good.'}`. We convert this to "inputs" and
"targets" with the appropriate prompt to inform the model of the task.
A few **important** notes:
1. When instantiating a `Task`, the preprocessor functions can have the
following arguments: `dataset`, `output_features`, and `sequence_length`.
The first (positional) dataset argument is always required. If an argument
named `output_features` is provided, the
[output feature mapping](#output-features) will be passed to the
preprocessor. If `sequence_length` is provided, a mapping from feature name
to its *maximum* final sequence length
([provided by the caller](#getting-a-preprocessed-dataset)) will be
passed -- any sequences that are too long after preprocessing will be
automatically truncated. If a preprocessor function does have other
arguments, they must have default values or be bound (e.g., with
`functools.partial` as used in `translate`) before instantiating the `Task`.
1. Mapping functions operate on and return `tf.Tensor`s using TensorFlow
operations. This is more flexible than it may sound:
* Automatic
[AutoGraph](https://www.tensorflow.org/guide/function#autograph_transformations)
conversion allows you to write Python control flow in your
transformations.
* [tf.experimental.numpy](https://www.tensorflow.org/guide/tf_numpy)
provides a numpy interface.
* [`tf.py_function`](https://www.tensorflow.org/api_docs/python/tf/py_function)
allows you to wrap arbitrary Python code. Note: `tf.data` pipelines
using this function can only be run in the python process where they
were defined, and performance is limited by the python GIL.
See `tf.data.Dataset`
[documentation](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
for more details.
1. When calling `map`, it is important to **always** set
`num_parallel_calls=tf.data.experimental.AUTOTUNE` to avoid creating a
bottleneck. The `seqio.map_over_dataset` decorator helps enforce this as
follows.
```py
@seqio.map_over_dataset
def translate(ex: Mapping[str, tf.Tensor],
              source_language: str,
              target_language: str) -> Mapping[str, tf.Tensor]:
    """Convert a translation example to a text2text pair.

    For example, say the dataset returns examples of this format:
      {'de': 'Das ist gut.', 'en': 'That is good.'}
    If source_language = 'de', target_language = 'en', then the outputs will
    have the format:
      {'inputs': 'translate de to en: Das ist gut.',
       'targets': 'That is good.'}

    Args:
      ex: an example to process.
      source_language: source language code (e.g. 'en') to translate from.
      target_language: target language code (e.g. 'de') to translate to.

    Returns:
      A preprocessed example with the format listed above.
    """
    src_str = f'translate {source_language}'
    tgt_str = f' to {target_language}: '
    return {
        'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),
        'targets': ex[target_language],
    }
```
Note that `translate` takes as input an individual example. Then
`seqio.map_over_dataset` decorates it to a function that takes in a
`tf.data.Dataset` instance.
1. Stochastic operations must be
[stateless](https://www.tensorflow.org/guide/random_numbers#stateless_rngs)
if deterministic pipelines are needed. To get (optionally deterministic)
seeds for these operations, use the `seqio.map_over_dataset(num_seeds=n)`
decorator. For example:
```py
def random_chunk(
    dataset: tf.data.Dataset,
    sequence_length: Mapping[str, int]
) -> tf.data.Dataset:
    """Takes a random chunk out of each feature with size `sequence_length`."""

    @seqio.map_over_dataset(num_seeds=1)
    def take_chunk(
        ex: Mapping[str, tf.Tensor],
        seed
    ) -> Mapping[str, tf.Tensor]:
        new_ex = {}
        for k, v in ex.items():
            if k in sequence_length:
                length = sequence_length[k]
                # Integer bounds require an integer dtype (the default is float32).
                start_idx = tf.random.stateless_uniform(
                    (), seed, minval=0, maxval=tf.size(v) - (length + 1),
                    dtype=tf.int32)
                new_ex[k] = v[start_idx:start_idx+length]
            else:
                new_ex[k] = v
        return new_ex

    return take_chunk(dataset)
```
If `num_seeds > 1`, the arg will instead be called `seeds` and will contain
a sequence of seeds.
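To make the first note above concrete, here is a minimal sketch of passing arguments by name. It is plain Python and not SeqIO's implementation: `run_preprocessor`, `add_prefix`, and `first_chunk` are hypothetical names, and lists of dicts stand in for `tf.data.Dataset`.

```python
import functools
import inspect


def run_preprocessor(fn, dataset, output_features, sequence_length):
    """Hypothetical helper: pass the optional arguments only if `fn` names them."""
    params = inspect.signature(fn).parameters
    kwargs = {}
    if "output_features" in params:
        kwargs["output_features"] = output_features
    if "sequence_length" in params:
        kwargs["sequence_length"] = sequence_length
    return fn(dataset, **kwargs)


def add_prefix(dataset, prefix):
    # `prefix` is not one of the special names, so it must be bound in advance.
    return [{**ex, "inputs": prefix + ex["inputs"]} for ex in dataset]


def first_chunk(dataset, sequence_length):
    # Receives the caller's maximum lengths because it names `sequence_length`.
    return [
        {k: v[:sequence_length.get(k, len(v))] for k, v in ex.items()}
        for ex in dataset
    ]


steps = [
    functools.partial(add_prefix, prefix="translate en to de: "),
    first_chunk,
]
data = [{"inputs": "That is good."}]
for step in steps:
    data = run_preprocessor(
        step, data, output_features={}, sequence_length={"inputs": 9})
print(data)
```

Binding `prefix` with `functools.partial` mirrors how `translate` is bound in the task definition; an unbound extra argument without a default would make the call fail.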
In our "wmt19_ende" task, we also use the predefined preprocessors
`seqio.preprocessors.tokenize` and `seqio.preprocessors.append_eos`. The former
uses each `Feature.vocabulary` to tokenize it, and the latter appends
`Feature.vocabulary.eos_id` to the feature if `Feature.add_eos` is True. See
[preprocessors.py](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) for
their implementations and other useful preprocessors.
#### Postprocessor
During evaluation, the model outputs are first detokenized using the output
feature vocabulary. Before passing these predictions to the metric functions,
they can be run through a Python postprocessing function, alongside the full
input example. Similarly, the raw targets are run through this function before
being passed to the metrics. Since the postprocess function is used on both the
model output and the targets, it is passed an `is_target` boolean in case the
behavior should be different. It is also passed the fully preprocessed example,
including fields that were excluded from `output_features`.
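As an illustration of the calling convention described above, here is a minimal sketch of a postprocessor. The function name, the argument names (`output_or_target`, `example`), and the `answers` field are assumptions made for this example, not taken from SeqIO.

```python
def normalize_answer(output_or_target, example=None, is_target=False):
    """Hypothetical postprocessor: lowercases and strips whitespace."""
    text = output_or_target.strip().lower()
    if is_target:
        # Targets may draw on extra fields of the preprocessed example, here
        # an assumed list of acceptable reference answers.
        references = (example or {}).get("answers", [])
        return [a.strip().lower() for a in references] or [text]
    return text


prediction = normalize_answer("  Paris ")
target = normalize_answer(
    "Paris",
    example={"answers": ["Paris", "Paris, France"]},
    is_target=True)
```

Here the same function returns a single string for a model output and a list of references for a target, which is the kind of asymmetry the `is_target` flag exists to support.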
For the "wmt19_ende" task, we don't need a postprocessor. See the "trivia_qa_open"
task in the [Advanced Postprocessing `Task`](#advanced-postprocessing-task) for
an example postprocessor.
#### Metrics
Metrics are functions that are passed (by the [Evaluator](#evaluator)) the
fully-materialized list of postprocessed model outputs (or scores) and targets
and return a mapping from string names to `MetricValue` objects containing their
values. These are most commonly floating-point scalars, but may also be text,
images, audio, histograms, etc. (see
[metrics.py](https://github.com/google/seqio/tree/main/seqio/metrics.py) for the full list).
The first argument of a metric function must always be called `targets`. If the
second argument of a metric function is called `predictions`, it will be passed
the decoded and detokenized model prediction. If it is called `scores`, it will
be passed a list of log-likelihood scores for each example.
If multiple metric functions are provided, they will all be used and their
returned mappings merged.
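The naming rules above can be sketched as a small dispatcher. This is an illustration in plain Python under stated assumptions, not the `Evaluator`'s actual code: `run_metrics`, `accuracy`, and `mean_score` are hypothetical, and plain floats stand in for `MetricValue` objects.

```python
import inspect


def run_metrics(metric_fns, targets, predictions=None, scores=None):
    """Hypothetical dispatcher: picks the second argument by its name."""
    results = {}
    for fn in metric_fns:
        names = list(inspect.signature(fn).parameters)
        if names[0] != "targets":
            raise ValueError("first argument must be called `targets`")
        if names[1] == "predictions":
            values = fn(targets, predictions)
        elif names[1] == "scores":
            values = fn(targets, scores)
        else:
            raise ValueError(f"unsupported second argument: {names[1]}")
        # Mappings returned by all metric functions are merged.
        results.update(values)
    return results


def accuracy(targets, predictions):
    correct = sum(t == p for t, p in zip(targets, predictions))
    return {"accuracy": correct / len(targets)}


def mean_score(targets, scores):
    return {"mean_score": sum(scores) / len(scores)}


merged = run_metrics(
    [accuracy, mean_score],
    targets=["a", "b"],
    predictions=["a", "c"],
    scores=[-1.0, -3.0])
```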
##### Prediction Metrics
Prediction metrics are computed using the postprocessed targets and model
outputs (predictions). The args must be named `targets` and `predictions`.
Let's look at the metric function used for the "wmt19_ende" task. A standard metric
for the translation task is BLEU, and we use the `sacrebleu` implementation.
```py
from typing import Sequence

import sacrebleu


def bleu(targets: Sequence[str], predictions: Sequence[str]):
    """Computes BLEU score.

    Args:
      targets: list of strings or list of list of strings if multiple references
        are present.
      predictions: list of strings

    Returns:
      bleu_score across all targets and predictions
    """
    if isinstance(targets[0], list):
        targets = [[x for x in target] for target in targets]
    else:
        # Need to wrap targets in another list for corpus_bleu.
        targets = [targets]

    bleu_score = sacrebleu.corpus_bleu(predictions, targets,
                                       smooth_method="exp",
                                       smooth_value=0.0,
                                       force=False,
                                       lowercase=False,
                                       tokenize="intl",
                                       use_effective_order=False)
    return {"bleu": bleu_score.score}
```
##### Score Metrics
Score metrics are computed using the postprocessed targets and their
log-likelihood scores according to the model. The args must be named `targets`
and `scores`.
```py
import numpy as np


def perplexity(targets: Sequence[str], scores: Sequence[float]):
    # `scores` are log-likelihoods, so negate their mean before exponentiating.
    return {
        "perplexity": seqio.metrics.Scalar(np.exp(-np.mean(scores)))
    }
```
### Defining a `Mixture`
Once you have multiple `Task`s added to the `TaskRegistry`, you can define
`Mixture`s that will combine the examples from them according to some specified
rate. Examples will then be sampled from each task in proportion to its rate.
As an example, [Multilingual T5](http://goo.gle/mt5) uses a `Mixture` of
per-language `Task`s with tail languages up-weighted in the mixture.
There are 3 ways to specify the tasks and their rates:
1. Provide a rate along with each task's name (rates are normalized before
sampling). In this example, the rates provided are units of the final
mixture that come from the component tasks. Here, 1/(1+7) of the final
mixture will come from "task1".
```py
seqio.MixtureRegistry.add(
"mix1",
[("task1", 1), ("task2", 7)]
)
```
1. Provide a constant default rate for some or all tasks, which will be used
when only the name is provided. The example below will produce identical
mixing rates as the previous one.
```py
seqio.MixtureRegistry.add(
"mix1",
[("task1", 0.5), "task2"],
default_rate=3.5
)
```
1. Provide a function that generates the rate for each task at runtime. The
example below uses the provided
[`seqio.mixing_rate_num_examples`](https://github.com/google/seqio/tree/main/seqio/utils.py),
which uses the number of examples (computed during
[offline caching](#optional-offline-caching)) as the rate for each task.
```py
seqio.MixtureRegistry.add(
"mix2",
["task1", "task2"],
default_rate=seqio.mixing_rate_num_examples
)
```
You can also include `Mixture`s in your `Mixture`! For example, the following
mixture would contain 1/24 (from "mix1") + 1/3 of "task1", 7/24 (from "mix1") of
"task2", and 1/3 of "task3".
```py
seqio.MixtureRegistry.add(
"mix3",
["mix1", "task1", "task3"],
default_rate=1
)
```
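The fractions quoted for "mix3" can be checked with a few lines of arithmetic. The sketch below is plain Python and does not use SeqIO; `flatten_rates` is a hypothetical helper, and it assumes a nested mixture distributes its share among its own tasks according to its normalized rates.

```python
from fractions import Fraction


def flatten_rates(name, mixtures, weight=Fraction(1)):
    """Returns {task_name: fraction of examples} for a task or nested mixture."""
    if name not in mixtures:  # A plain task keeps the whole weight.
        return {name: weight}
    members = mixtures[name]
    total = sum(Fraction(rate) for _, rate in members)
    proportions = {}
    for member, rate in members:
        share = weight * Fraction(rate) / total  # Normalize the rates.
        for task, frac in flatten_rates(member, mixtures, share).items():
            proportions[task] = proportions.get(task, Fraction(0)) + frac
    return proportions


mixtures = {
    "mix1": [("task1", 1), ("task2", 7)],
    "mix3": [("mix1", 1), ("task1", 1), ("task3", 1)],
}
proportions = flatten_rates("mix3", mixtures)
for task, frac in sorted(proportions.items()):
    print(task, frac)
```

Under that assumption, "task1" ends up with 1/24 + 1/3 = 3/8 of the examples, "task2" with 7/24, and "task3" with 1/3.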
If sampling without replacement is important for your task, you can achieve that
by using either deterministic tasks or using dataset checkpointing (and not
running more than an epoch) for a non-deterministic task. Otherwise, the mixture
may sample with replacement.
### Getting a Preprocessed Dataset
Now that your `Task` (and/or `Mixture`) is defined, its primary use is to
generate a dataset.
You may first need to use `seqio.get_mixture_or_task(mixture_or_task_name)` to
access your dataset provider from the registry.
After that, you can call `get_dataset` to build the `tf.data.Dataset`. For
example:
```py
dataset = seqio.get_mixture_or_task("mix1").get_dataset(
sequence_length={"inputs": 256, "targets": 128},
split="train",
shuffle=True,
num_epochs=1,
shard_info=seqio.ShardInfo(index=0, num_shards=10),
use_cached=False,
seed=42
)
# Print the first 5 examples.
for _, ex in zip(range(5), dataset.as_numpy_iterator()):
print(ex)
```
Some notes on a few of the arguments:
* `sequence_length`: An *optional* mapping from feature name to *maximum*
length. Will be passed to the preprocessors with a `sequence_length`
argument. If not `None`, the final example features will be truncated if
they exceed the specified length. Note that this value may be required to be
set if any of the preprocessors use the `sequence_length` argument and do
not handle the `None` case.
* `num_epochs`: The number of times to repeat the source dataset.
Preprocessing will be re-applied with new seeds to enable new samples from
stochastic steps. Note that if the `CacheDatasetPlaceholder` is included
(see below) preprocessing is only re-applied after that step.
* `shard_info`: An optional sharding specification for loading a deterministic
subset of the dataset. Loading will be most efficient if the number of
shards evenly divides the number of shards in the raw data source.
* `use_cached`: Specifies whether to load from a pre-cached task for increased
performance or to do the preprocessing on-the-fly. See the
[following section](#optional-offline-caching) for details on how to cache
your task, which must be done before this can be set to `True`.
* `seed`: An optional seed to use for deterministic shuffling and (stateless)
stochastic ops. These operations will still be pseudorandom but will be
reproducible with the same seed. Set to `None` if determinism is not
desired.
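To see why the `shard_info` note recommends a shard count that evenly divides the number of source shards, consider a toy round-robin assignment of source files to shards. This is an assumption made for illustration, not a description of how SeqIO or TFDS assigns files; `files_for_shard` and the file names are hypothetical.

```python
def files_for_shard(files, index, num_shards):
    """Hypothetical round-robin assignment of source files to one shard."""
    return files[index::num_shards]


files = [f"train-{i:05d}-of-00010" for i in range(10)]

# 5 divides 10: every shard reads the same number of whole files.
even = [len(files_for_shard(files, i, 5)) for i in range(5)]

# 4 does not divide 10: some shards read more files than others.
uneven = [len(files_for_shard(files, i, 4)) for i in range(4)]

print(even, uneven)
```

With an even split each worker reads whole files and the load is balanced; with an uneven split some workers carry more data or the loader has to fall back to skipping examples within files.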
### (Optional) Offline Caching
For improved performance at load time and to avoid redundant computations for
commonly used tasks, you can pre-cache your `Task` with all or part of the
preprocessing done in advance of training; this partial preprocessing is
especially useful if the Task is stochastic and one wishes to cache the
deterministic operations while running the stochastic ones on the fly. Caching
stochastic SeqIO Mixtures in this way is not supported.
The first step to doing so is to add a
`seqio.CacheDatasetPlaceholder(required=False)` as one of the steps in your
preprocessing pipeline. All steps before the placeholder will be cached offline
and all steps after will be executed on the fly at load time. You may set
`required=True` if you want `get_dataset` to fail unless `use_cached=True`.
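The effect of the placeholder on the pipeline can be pictured as splitting the preprocessor list in two. The sketch below is plain Python, not SeqIO's implementation: `CACHE_PLACEHOLDER` stands in for `seqio.CacheDatasetPlaceholder`, `split_at_placeholder` is hypothetical, and strings stand in for preprocessor functions.

```python
CACHE_PLACEHOLDER = object()  # Stand-in for seqio.CacheDatasetPlaceholder.


def split_at_placeholder(preprocessors):
    """Returns (steps cached offline, steps run on the fly at load time)."""
    if CACHE_PLACEHOLDER not in preprocessors:
        # Without a placeholder nothing is cached in this sketch.
        return [], list(preprocessors)
    i = preprocessors.index(CACHE_PLACEHOLDER)
    return list(preprocessors[:i]), list(preprocessors[i + 1:])


pipeline = ["translate", "tokenize", CACHE_PLACEHOLDER,
            "random_chunk", "append_eos"]
offline, online = split_at_placeholder(pipeline)
print(offline, online)
```

Placing the stochastic and length-dependent `random_chunk` step after the placeholder keeps it out of the cached data, so it runs again on every load.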
Caveats:
* Any stochastic operations that you wish to be re-run when `num_epochs > 1`
or with a different `seed` *should* go after the placeholder since only a
single sample will be cached.
* Any preprocessing steps that use the `sequence_length` argument *must* come
after the `seqio.CacheDatasetPlaceholder` preprocessor since this is only
known at runtime, or an exception will be raised. If you wish to cache for a
specific sequence length, you can use
[`seqio.experimental.add_fully_cached_task`](https://github.com/google/seqio/tree/main/seqio/experimental.py).
Once your `Task` is registered, you can run
[`cache_tasks_main`](https://github.com/google/seqio/tree/main/seqio/scripts/cache_tasks_main.py)
to execute the offline preprocessing, providing it with the module containing
your task definitions via the `--module_import` flag. For very large datasets,
it's recommended you run this [Apache Beam](https://beam.apache.org/) script on
a distributed framework like
[Google Cloud DataFlow](https://beam.apache.org/documentation/runners/dataflow/).
Finally, you are ready to load the cached version of your `Task` (or a `Mixture`
containing it). You will need to add the path to the directory you passed to
`--output_cache_dir` via `seqio.add_global_cache_dirs(["/my/cache/dir"])`. Now
when you call `task_or_mixture.get_dataset(..., use_cached=True)`, the data will
be loaded from the cache directory instead of the raw data source.
### Feature Converters
The role of `Task` is to provide the dataset object with as few
model-specific features as possible (e.g., generic "inputs" and "targets"), while
the Feature Converters transform the model-agnostic features to model-specific
features (e.g., "encoder_input_tokens"). We refer to the former as "task features"
and the latter as "model features".
Let's use machine translation (English to German) as a running example.
The raw data consists of sentence pairs such as
```
"That is good.\tDas ist gut."
```
A task registered to `Task` (e.g.,
[wmt_t2t_ende_v003](t5/data/tasks.py?l=156&rcl=337594707))
reads these sentence pairs from the data source and applies a series of
[preprocessors](t5/data/preprocessors.py?rcl=343354647).
One of the internal representations looks like
```python
{"inputs": "translate English to German: That is good.",
"targets": "Das ist gut."}
```
The final output from the `Task` is a tokenized version of the parallel
sentences. In the following toy example (the token ids do not correspond to the
above string example), the dataset consists of 2 examples.
```python
dataset = [{"inputs": [7, 8, 5], "targets": [3, 9]},
{"inputs": [8, 4, 9, 3], "targets": [4]}]
```
The data is in the `tf.data.Dataset` format (i.e., each example is a dictionary
with "inputs" and "targets" fields).
The `FeatureConverter` then takes this as an input and converts to the
model-specific features. In addition, the feature converter performs padding and
optionally packing (for model implementations that support it) for efficiency.
For example, let's assume that we are using the standard Transformer
architecture with an encoder and a decoder. The output of the feature converter
is
```python
converted_dataset = [{
"encoder_input_tokens": [7, 8, 5, 1, 8, 4, 9, 3, 1, 0],
"encoder_segment_ids": [1, 1, 1, 1, 2, 2, 2, 2, 2, 0],
"encoder_positions": [0, 1, 2, 3, 0, 1, 2, 3, 4, 0],
"decoder_target_tokens": [3, 9, 1, 4, 1, 0, 0],
"decoder_input_tokens": [0, 3, 9, 0, 4, 0, 0],
"decoder_loss_weights": [1, 1, 1, 1, 1, 0, 0],
"decoder_positions": [0, 1, 2, 0, 1, 0, 0],
"decoder_segment_ids": [1, 1, 1, 2, 2, 0, 0],
}]
```
In this case, two task examples are packed into one. `*_segment_ids` and
`*_positions` are the fields used to denote the membership and position of each
packed token in the original sequence. The EOS ids (i.e., 1) are appended. In
addition, each field is padded to the specified length.
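The packed fields in this example can be reproduced with a short sketch. It is plain Python, not the `FeatureConverter` implementation: `pack` is a hypothetical helper that builds a single packed example, whereas real packing works over a whole dataset.

```python
def pack(sequences, length, eos_id=1, pad_id=0):
    """Packs token sequences into one padded example (hypothetical helper)."""
    tokens, segment_ids, positions = [], [], []
    for segment, seq in enumerate(sequences, start=1):
        seq = list(seq) + [eos_id]  # Append EOS to each sequence.
        if len(tokens) + len(seq) > length:
            raise ValueError("sequences do not fit in one packed example")
        tokens += seq
        segment_ids += [segment] * len(seq)  # Which sequence a token came from.
        positions += list(range(len(seq)))   # Index within that sequence.
    padding = length - len(tokens)
    return {
        "tokens": tokens + [pad_id] * padding,
        "segment_ids": segment_ids + [0] * padding,
        "positions": positions + [0] * padding,
    }


encoder = pack([[7, 8, 5], [8, 4, 9, 3]], length=10)
decoder = pack([[3, 9], [4]], length=7)
print(encoder)
print(decoder)
```

The `tokens`, `segment_ids`, and `positions` lists of `encoder` match `encoder_input_tokens`, `encoder_segment_ids`, and `encoder_positions` above, and those of `decoder` match `decoder_target_tokens`, `decoder_segment_ids`, and `decoder_positions`.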
We will look at the details of this example in the "Encoder-decoder
architecture: `seqio.EncDecFeatureConverter`" section.
#### Feature converters provided out of the box
We provide feature converters for three common architectures: encoder-decoder,
decoder-only and encoder-only. Here we describe how users can use the feature
converters for each of these architectures out of the box as a part of the SeqIO
library.
In the SeqIO library, each architecture has a class defining how the task
features are converted to model features. Since these feature converters are
already implemented, it is straightforward to use them by providing the class as
a `feature_converter` argument of the `seqio.get_dataset` function. The
following sections show example usage of `seqio.get_dataset`.
##### Encoder-decoder architecture: `seqio.EncDecFeatureConverter`
This is the architecture of the original Transformer paper. For the
English-to-German translation task, the following function call retrieves the
`tf.data.Dataset` object with the model features.
```python
dataset: tf.data.Dataset = seqio.get_dataset(
mixture_or_task_name="wmt_t2t_ende_v003",
task_feature_lengths={"inputs": 32, "targets": 32},
dataset_split="train",
shuffle=True,
feature_converter=seqio.EncDecFeatureConverter(pack=True)
)
```
The resulting dataset object has the following 8 fields:
|Feature name | Explanation |
|----------------------|---------------------------|
|`encoder_input_tokens` | Input tokens to the encoder. |
|`encoder_positions` | Position index in the sequence before packing. |
|`encoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer mean that they belong to the same sequence before packing. |
|`decoder_input_tokens` | Input tokens to the decoder. |
|`decoder_target_tokens`| Output tokens from the decoder. |
|`decoder_loss_weights` | A weight on each position that can be used as a mask. |
|`decoder_positions` | Position index in the sequence before packing. |
|`decoder_segment_ids` | Same as `encoder_segment_ids` but for decoder. |
##### Decoder-only architecture
This architecture consists of a single autoregressive stack, which we denote as
a "decoder".
A decoder autoregressively produces an output sequence.
Therefore, it can be used as a standard language model if the task dataset has
only "targets" features, i.e., self-supervised. If the task dataset also has an
"inputs" field, e.g., supervised machine translation, the decoder can still be
used by concatenating the inputs and targets fields. See [Raffel et al.
(2020)](https://arxiv.org/abs/1910.10683), Section 3.2.1 for a more detailed take
on this topic.
We support both use cases and refer to the former as *standard language model*
and the latter as *prefix language model*. Each of these models is described
separately below.
Note that we do not provide special features to denote how the dataset should be
consumed. For example, a Transformer-based fully autoregressive decoder has a
fully-causal self-attention layer. Since there are many ways of implementing the
masking pattern for such attention layer and, more importantly, SeqIO is not
limited to attention-based models, we leave it up to the model implementations
to apply the masking pattern. There is one exception, and we cover this in
the Prefix LM section below.
A common use pattern is to pretrain a decoder model with the left-to-right
language modeling objective (unsupervised) using `seqio.LMFeatureConverter` and
then fine-tune (supervised) using `seqio.PrefixLMFeatureConverter`.
###### Standard LM
For the standard language model, the task dataset only has a "targets" field.
Therefore, the sequence length specification only needs to specify targets.
```python
dataset: tf.data.Dataset = seqio.get_dataset(
mixture_or_task_name="standard_lm",
task_feature_lengths={"targets": 32},
dataset_split="train",
shuffle=True,
feature_converter=seqio.LMFeatureConverter(pack=True)
)
```
Note that "standard_lm" is not a registered task in the codebase. It is the
left-to-right language modeling task, i.e., predict the next token given the
previous tokens on some language corpus (e.g.,
[C4](https://www.tensorflow.org/datasets/catalog/c4)).
The output dataset has the following model features.
|Feature name | Explanation |
|----------------------|---------------------------|
|`decoder_target_tokens`| Output tokens from the decoder |
|`decoder_input_tokens` | Input tokens to the decoder |
|`decoder_loss_weights` | Binary mask to indicate where the loss should be taken |
|`decoder_positions` | Position index in the sequence before packing |
|`decoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer mean that they belong to the same sequence before packing. |
The `decoder_input_tokens` feature is a right-shifted version of
`decoder_target_tokens`, as used in standard teacher-forced autoregressive
training.
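The relationship between the two token features can be sketched as a right shift that restarts at every packed sequence. This is plain Python under stated assumptions, not SeqIO's code: `shift_right` is a hypothetical helper, and id 0 is assumed to serve as both the start token and padding, as in the toy example earlier in this document.

```python
def shift_right(target_tokens, segment_ids, bos_id=0):
    """Builds decoder inputs from packed targets (hypothetical helper)."""
    inputs = []
    for i, segment in enumerate(segment_ids):
        starts_sequence = i == 0 or segment != segment_ids[i - 1]
        if segment == 0 or starts_sequence:
            # Padding stays 0; each packed sequence starts from the BOS id.
            inputs.append(bos_id)
        else:
            # Otherwise the input is the previous target token.
            inputs.append(target_tokens[i - 1])
    return inputs


targets = [3, 9, 1, 4, 1, 0, 0]
segments = [1, 1, 1, 2, 2, 0, 0]
decoder_inputs = shift_right(targets, segments)
loss_weights = [1 if s > 0 else 0 for s in segments]
print(decoder_inputs, loss_weights)
```

Resetting at segment boundaries keeps the EOS of one packed sequence from being fed as the first input of the next.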
###### Prefix LM: `seqio.PrefixLMFeatureConverter`
If the input dataset has a notion of "inputs" and "targets", we can concatenate
them so that we can still use a single stack decoder. Therefore, the output only
contains "targets", just like the standard LM case.
We use the same toy example for the English-to-German translation task as a
running example:
```
{"inputs": "translate English to German: That is good.",
"targets": "Das ist gut."}
```
To be consumed by the decoder-only stack, `seqio.PrefixLMFeatureConverter`
concatenates them to form the new "targets". Consider a 2-layer decoder
architecture whose activations are shown below.
```
That   is     good   <EOS>  Das    ist    gut    <EOS>
|      |      |      |      |      |      |      |
u1     u2     u3     u4     u5     u6     u7     u8
|      |      |      |      |      |      |      |
v1     v2     v3     v4     v5     v6     v7     v8
|      |      |      |      |      |      |      |
<BOS>  That   is     good   <EOS>  Das    ist    gut
```
Let us denote the first layer's activation in the `i`th position as `vi`.
Similarly, let `ui` denote the activation of the second layer in the `i`th
position.
For attention-based sequence models such as Transformer decoders, the
self-attention layer is used to encode contextualized representation of the
sequence. At a given layer, each position's representation is computed as a
function of the representations of the tokens *before* its position in the
previous layer.
Referring to the toy example, when computing `u3` with fully-causal masking, we
do not use `v4`. This results in a representation `u3` of the word "is" that
does not take into account the word "good", which is unnecessarily limiting.
For Prefix LM, this issue is resolved by having the fully visible masking
pattern for the inputs portion only. For example, when computing `u2`, `v1`,
`v2`, `v3`, `v4` and `v5` are all visible and taken into account. For the tokens
in the "targets" of the `Task` dataset, we use the causal masking. For example,
when computing `u6`, all `vi` for `i <= 6` are taken into account but not `v7`.
<details>
<summary>Why is `v5` included in the inputs attention pattern?</summary>
In the same translation example, we note that when computing `u2`, the
activation corresponding to the position where \<EOS\> token was input (i.e.,
`v5`) was visible. This doesn't count as "cheating" because the model doesn't
see the next word "Das". This can provide additional context in building the
representation for "good". In this case, `u4` has the context that "good" is
the last word in the sentence.
</details>
`seqio.PrefixLMFeatureConverter` provides a feature `decoder_causal_attention`
to encode this information. For the above example, we have
```
decoder_causal_attention = [1, 1, 1, 1, 1, 0, 0, 0]
```
indicating that the non-causal attention can be applied to the first five
positions. This feature may seem trivial here, but for a packed dataset
the boundary between inputs and targets is more nuanced.
A final consideration for the prefix LM is that because we concatenate "inputs"
and "targets", which tokens are used for the loss computation is a modeling
decision. For example, we can penalize the model only for the "targets" tokens
or we may choose to also penalize it for the "inputs" tokens.
This is controlled by the `loss_on_targets_only` argument (defaults to `True`) of
the `seqio.PrefixLMFeatureConverter` constructor. In the above example, we would get
```
decoder_loss_weights = [0, 0, 0, 0, 1, 1, 1, 1]
```
This indicates that the last 4 positions are used for the loss computation.
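Both masks for the unpacked toy example can be derived from the input and target lengths alone. The sketch below is plain Python, not the converter's implementation: `prefix_lm_masks` and `visible` are hypothetical helpers, positions are 0-indexed, and packing is ignored.

```python
def prefix_lm_masks(num_inputs, num_targets, loss_on_targets_only=True):
    """Returns (decoder_causal_attention, decoder_loss_weights) for one example."""
    total = num_inputs + num_targets
    # Decoder inputs are shifted right by one, so the non-causal region covers
    # the start token plus all input tokens.
    causal_attention = [1] * (num_inputs + 1) + [0] * (num_targets - 1)
    if loss_on_targets_only:
        loss_weights = [0] * num_inputs + [1] * num_targets
    else:
        loss_weights = [1] * total
    return causal_attention, loss_weights


def visible(query, key, causal_attention):
    """True if position `query` may attend to position `key`."""
    both_in_prefix = causal_attention[query] == 1 and causal_attention[key] == 1
    return key <= query or both_in_prefix


# "That is good <EOS>" has 4 tokens, and so does "Das ist gut <EOS>".
causal_attention, loss_weights = prefix_lm_masks(4, 4)
print(causal_attention, loss_weights)
```

With these masks, position 1 (`u2`) can attend to positions 0 through 4 (`v1` to `v5`) but not to position 5, while position 5 (`u6`) attends only to positions up to itself.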
To get the dataset with prefix LM features, we can use
```python
dataset: tf.data.Dataset = seqio.get_dataset(
    mixture_or_task_name="wmt_t2t_ende_v003",
    task_feature_lengths={"inputs": 32, "targets": 32},
    dataset_split="train",
    shuffle=True,
    feature_converter=seqio.PrefixLMFeatureConverter(
        pack=True,
        loss_on_targets_only=True)
)
```
The resulting features have length 64 because it concatenates inputs and targets
each with length 32.
The output dataset has the following model features. Note that the only
additional feature is `decoder_causal_attention`.
|Feature name | Explanation |
|----------------------|---------------------------|
|`decoder_target_tokens`| Output tokens from the decoder |
|`decoder_input_tokens` | Input tokens to the decoder |
|`decoder_loss_weights` | Binary mask to indicate where the loss should be taken |
|`decoder_positions` | Position index in the sequence before packing |
|`decoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer belong to the same sequence before packing. |
|`decoder_causal_attention`| Binary mask denoting which tokens are in the non-causal masking region. |
###### Encoder-only architecture
Like the decoder-only architecture, this one is a single stack, but it is not
autoregressive.
One notable assumption is that the inputs and targets are *aligned*, i.e., they
have the same sequence length and the `i`th position in the targets corresponds
to the output representation of the `i`th token in the inputs.
A common model using encoder-only architecture is
[BERT](https://arxiv.org/abs/1810.04805). We provide the
`EncoderFeatureConverter` class to support the Masked Language Modeling (MLM)
objective from BERT.
We assume that a unique sentinel such as the `[MASK]` token is used to mask some
fraction of the input text and that the task is to recover the original text.
Therefore, the "targets" are naturally defined as the original text, whereas the
"inputs" are the masked text.
Encoder-only models are often used for classification tasks. In BERT, a special
token `[CLS]` is prepended to the input sequence. The last layer's activation
corresponding to this sentinel token is the contextualized representation of the
sequence. We assume that such a "classification" sentinel is prepended.
Consider the following example for the MLM task. The input dataset has two
examples, which are packed into one example. We assume that `mask_id = 9` and
that the `[CLS]` token has an id of 8.
```py
dataset = [{"inputs": [8, 9, 9, 3, 4], "targets": [8, 7, 4, 3, 4]},
           {"inputs": [8, 3, 9], "targets": [8, 3, 6]}]
converted_dataset = {
    "encoder_input_tokens": [8, 9, 9, 3, 4, 1, 8, 3, 9, 1, 0],
    "encoder_target_tokens": [8, 7, 4, 3, 4, 1, 8, 3, 6, 1, 0],
    "encoder_segment_ids": [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0],
    "encoder_positions": [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0],
    "encoder_loss_weights": [0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
}
```
Note that the packed sequence has a `[CLS]` token at the beginning of each
sequence. Also note that the loss is taken only on the masked positions.
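The conversion above can be reproduced with a few lines of plain Python. This is only a sketch of the bookkeeping, not the `EncoderFeatureConverter` implementation: the helper `pack_mlm_examples` is hypothetical, an EOS id of 1 and a padding id of 0 are assumed to match the example, and the sketch appends EOS itself, whereas in a real pipeline EOS may already have been appended by a preprocessor.

```python
from typing import Dict, List

EOS_ID = 1  # Assumed EOS id, matching the example above.
PAD_ID = 0  # Assumed padding id.


def pack_mlm_examples(
    examples: List[Dict[str, List[int]]], mask_id: int, length: int
) -> Dict[str, List[int]]:
  """Packs aligned MLM examples into one sequence padded to `length`."""
  features = {
      "encoder_input_tokens": [],
      "encoder_target_tokens": [],
      "encoder_segment_ids": [],
      "encoder_positions": [],
      "encoder_loss_weights": [],
  }
  for segment_id, ex in enumerate(examples, start=1):
    inputs = ex["inputs"] + [EOS_ID]
    targets = ex["targets"] + [EOS_ID]
    features["encoder_input_tokens"] += inputs
    features["encoder_target_tokens"] += targets
    # Segment ids mark which original example each position came from.
    features["encoder_segment_ids"] += [segment_id] * len(inputs)
    # Positions restart from zero for every packed example.
    features["encoder_positions"] += list(range(len(inputs)))
    # The loss is taken only where the input token is the mask sentinel.
    features["encoder_loss_weights"] += [int(t == mask_id) for t in inputs]
  for values in features.values():
    values += [PAD_ID] * (length - len(values))
  return features


packed = pack_mlm_examples(
    [{"inputs": [8, 9, 9, 3, 4], "targets": [8, 7, 4, 3, 4]},
     {"inputs": [8, 3, 9], "targets": [8, 3, 6]}],
    mask_id=9,
    length=11)
print(packed)
```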
To use the pre-defined `EncoderFeatureConverter`, provide `mask_id` as an
argument.
```py
dataset: tf.data.Dataset = seqio.get_dataset(
    mixture_or_task_name="some mlm task",
    task_feature_lengths={"inputs": 32, "targets": 32},
    dataset_split="train",
    shuffle=True,
    feature_converter=seqio.EncoderFeatureConverter(
        pack=True,
        mask_id=9)
)
```
The resulting dataset object has the following five fields:
|Feature name | Explanation |
|----------------------|---------------------------|
|`encoder_input_tokens` | Input tokens to the encoder |
|`encoder_positions` | Position index in the sequence before packing |
|`encoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer belong to the same sequence before packing. |
|`encoder_target_tokens`| Output tokens from the encoder |
|`encoder_loss_weights` | Binary mask to indicate where the loss should be taken |
###### Custom architectures
For a custom model architecture, you need to create a subclass of
`FeatureConverter` and override two methods `_convert_features` and
`get_model_feature_lengths` to define how task features are mapped to the model
features, including the length relationships. The existing feature converters
(e.g., `seqio.EncDecFeatureConverter`) follow the same pattern, which can be a
useful starting point.
### Evaluation
The SeqIO `Evaluator` class provides a way to evaluate models on SeqIO Tasks
and Mixtures. For an interactive walkthrough of SeqIO evaluation, see the
[Evaluation Notebook](https://github.com/google/seqio/blob/main/seqio/docs/tutorials.md).
The following is a deep-dive into the `Evaluator` class.
An Evaluator instance can be created by passing a SeqIO Task or
Mixture, and additional eval params like feature converter, split, sequence
lengths, seed, etc. The Evaluator init calls `get_dataset` for each Task to be
evaluated with the appropriate params, creating the `task_dataset`, and invokes
the model-specific feature converter on the `task_dataset` to create features
that can be passed to a model, called `model_dataset`. Both `task_dataset` and
`model_dataset` are stored in-memory so that the dataset can be reused across
multiple evaluations (e.g. on checkpoints from a training run). Both datasets
are enumerated so that even if the order of examples is changed during model
inference, the enumeration can be used to match model outputs to examples from
the `task_dataset`.
For Mixtures, each sub-Task is evaluated separately, regardless of mixing
rates, because in the context of eval benchmarks, Mixtures commonly refer to a
collection of Tasks belonging to that benchmark, each of which is evaluated
separately, e.g. SuperGLUE mixture.
Once an `Evaluator` instance is created with a SeqIO Task or Mixture, a model
can be evaluated by calling `evaluator.evaluate(...)` and passing a `predict_fn`
and/or a `predict_with_aux_fn` and/or a `score_fn` to interact with the model.
`predict_fn` takes the `model_dataset` as input and outputs a `Sequence[(index,
token_ids)]` where `token_ids` is the sequence of token ids generated by the
model for the input example whose index matches `index`. Therefore, even if
`predict_fn` mixes the order of the examples during prediction, the order can be
corrected as long as the correct index for each example is maintained. A common
example is the multi-host setup where the evaluation dataset is split amongst
multiple hosts that independently make predictions and combine the results
during which the ordering can be mixed. `predict_with_aux_fn` is similar to
`predict_fn`, except that it can also return a dictionary of auxiliary values
along with each sequence of `token_ids`, e.g. scores from the generated tokens.
The `score_fn` takes the `model_dataset` as input and returns a
`Sequence[(index, score)]` where `score` is the sequence of log likelihood
scores for the targets in the dataset. This simple interface allows users to
easily integrate the SeqIO evaluation flow with popular training frameworks in
TF and JAX.
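The role of the index can be illustrated with a toy stand-in for `predict_fn`. Everything in this sketch is hypothetical (the `tokens` feature name, the fake "model" that adds one to each token id, and the shuffling that imitates a multi-host setup); it only shows why returning `(index, token_ids)` pairs is enough to restore the original order.

```python
import random
from typing import Dict, List, Sequence, Tuple


def toy_predict_fn(
    model_dataset: Sequence[Tuple[int, Dict[str, List[int]]]]
) -> List[Tuple[int, List[int]]]:
  """Returns (index, token_ids) pairs in an arbitrary order."""
  # Fake "model": add one to every token id.
  outputs = [(index, [t + 1 for t in ex["tokens"]])
             for index, ex in model_dataset]
  # Imitate results arriving out of order, e.g. from multiple hosts.
  random.Random(0).shuffle(outputs)
  return outputs


# An enumerated dataset: each example carries its own index.
model_dataset = list(enumerate(
    [{"tokens": [3, 4]}, {"tokens": [5]}, {"tokens": [6, 7, 8]}]))

shuffled = toy_predict_fn(model_dataset)
# Sorting by index matches each output to its original example again.
ordered = [token_ids for _, token_ids in sorted(shuffled, key=lambda p: p[0])]
print(ordered)
```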
Corresponding to the model fns, users can configure three kinds of metric fns in
their Tasks, which are differentiated by their function signature. Metrics
computed on the outputs of `predict_fn` (and `predict_with_aux_fn`) have the
signature `targets` and `predictions` (and optionally `aux_values`), while
metrics computed on the outputs of `score_fn` have the signature `targets` and
`scores`. The `Evaluator` takes care of calling the correct model fns and
metric fns during evaluation. Here are examples of a prediction metric and a
score metric.
```py
import numpy as np
import scipy.special


def sequence_accuracy(targets, predictions):
  seq_acc = 100 * np.mean([p == t for p, t in zip(predictions, targets)])
  return {"sequence_accuracy": seq_acc}


def log_likelihood(targets, scores):
  log_likelihood = np.mean([scipy.special.logsumexp(el) for el in scores])
  return {"log_likelihood": log_likelihood}
```
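Since metric fns are told apart by their argument names, one way to picture the dispatch is the following sketch built on the standard `inspect` module. The helper `metric_kind` and its return labels are made up for this illustration; the actual logic inside SeqIO may differ.

```python
import inspect


def metric_kind(metric_fn) -> str:
  """Classifies a metric fn by the names of its arguments."""
  arg_names = list(inspect.signature(metric_fn).parameters)
  if not arg_names or arg_names[0] != "targets":
    raise ValueError("The first argument must be named `targets`.")
  if "scores" in arg_names:
    return "score"
  if "predictions" in arg_names:
    return "predict_with_aux" if "aux_values" in arg_names else "predict"
  raise ValueError(f"Unrecognized metric signature: {arg_names}")


# Toy metric fns; only their signatures matter here.
def toy_accuracy(targets, predictions):
  return {}


def toy_accuracy_with_aux(targets, predictions, aux_values):
  return {}


def toy_log_likelihood(targets, scores):
  return {}


kinds = [metric_kind(fn) for fn in
         (toy_accuracy, toy_accuracy_with_aux, toy_log_likelihood)]
print(kinds)
```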
There are 4 steps involved in the evaluation using predicted tokens:
+ the `predict_fn` or `predict_with_aux_fn` returns indices and output_tokens:
`Sequence[Tuple[int, Sequence[int]]]`, potentially with some auxiliary
values.
+ output tokens are decoded by `vocab.decode`
+ postprocessors configured in Tasks are applied to the decoded output. These
are denoted as predictions.
+ metric fns configured in Tasks are applied to the predictions and the cached
targets.
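The four steps above can be sketched end to end in plain Python. The vocabulary, the postprocessor, and the model outputs below are toy stand-ins invented for this illustration, and the sketch does not call the `Evaluator` itself.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical id-to-string vocabulary standing in for `vocab.decode`.
TOY_VOCAB = {3: "yes", 4: "no", 5: "maybe"}


def toy_decode(token_ids: Sequence[int]) -> str:
  return " ".join(TOY_VOCAB[t] for t in token_ids)


def toy_postprocess(text: str) -> str:
  return text.strip()


def sequence_accuracy(targets, predictions) -> Dict[str, float]:
  matches = [p == t for p, t in zip(predictions, targets)]
  return {"sequence_accuracy": 100 * sum(matches) / len(matches)}


# Step 1: the predict fn returns (index, token_ids), possibly out of order.
indexed_outputs: List[Tuple[int, List[int]]] = [(1, [4]), (0, [3])]
ordered = [ids for _, ids in sorted(indexed_outputs, key=lambda p: p[0])]
# Step 2: decode the output tokens.
decoded = [toy_decode(ids) for ids in ordered]
# Step 3: apply the postprocessor to obtain the predictions.
predictions = [toy_postprocess(text) for text in decoded]
# Step 4: apply the metric fn to the predictions and the cached targets.
cached_targets = ["yes", "maybe"]
result = sequence_accuracy(cached_targets, predictions)
print(result)
```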
There are 2 steps involved in the evaluation using scores:
+ the `score_fn` returns indices and scores: `Sequence[Tuple[int,
Sequence[float]]]`
+ metric fns configured in Tasks are applied to the scores and the cached
targets.
Training codebases like T5X provide integration with SeqIO evaluation to allow
evaluating checkpoints on SeqIO Tasks and Mixtures. See
[T5X Eval](https://github.com/google-research/t5x/blob/main/docs/usage/eval.md)
for instructions.
## Differences from `t5.data`
The original `t5` library introduced and implemented the `t5.data.Task`
abstraction for specifying preprocessing and evaluation metrics for text-to-text
tasks. When creating a task, users specify a source dataset of raw text, some
preprocessing steps, a vocabulary for tokenization, and evaluation metrics. The
fully-specified Task can then be used to pre-train or fine-tune an
encoder-decoder transformer model. However, the design included many baked-in
assumptions about the types of tasks users could specify.
SeqIO removes some of the constraints of this abstraction:
* Inputs and outputs are no longer required to be strings (e.g., they may be
images or audio).
* Architectures other than the original encoder-decoder are supported (e.g.,
decoder-only language models like GPT or encoder-only models like BERT).
* Users can control at which stage of the pipeline offline caching occurs.
* Users can control when and where EOS tokens are added.
Furthermore, SeqIO has been made more modular with respect to the Mesh
TensorFlow Transformer. This allows it to be used with other model
implementations with more consistency and much less code duplication.
## Advanced Postprocessing `Task`
### TriviaQA (Closed-book, open-domain version)
This version of TriviaQA was introduced in [Roberts et al.
2020](https://arxiv.org/abs/2002.08910).
```py
seqio.TaskRegistry.add(
    "trivia_qa_open",
    source=seqio.TfdsDataSource(
        tfds_name="trivia_qa/unfiltered.nocontext:1.1.0",
        splits={
            "train": "train[:90%]",
            "validation": "train[90%:]",
            "test": "validation"
        }),
    preprocessors=[
        tqa_open_preprocessor,
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ],
    output_features={
        "inputs": seqio.Feature(
            seqio.SentencePieceVocabulary("/path/to/inputs/vocab"),
            add_eos=False, dtype=tf.int32
        ),
        "targets": seqio.Feature(
            seqio.SentencePieceVocabulary("/path/to/targets/vocab"),
            add_eos=True, dtype=tf.int32
        ),
    },
    postprocess_fn=tqa_open_postprocessor,
    metric_fns=[tqa_metric])
```
In this example, we are using the `TfdsDataSource`. We specify the name of the
TriviaQA dataset in TFDS
([`"trivia_qa"`](https://www.tensorflow.org/datasets/catalog/trivia_qa)), the
specific config that excludes the context for the open domain setting
(`"unfiltered.nocontext"`), and the version number (`"1.1.0"`). We also override
the default splits to match what is commonly used for the open domain setting.
Specifically, we set our "test" split to be the TFDS "validation" split, and
create a small pseudo-"validation" set by taking examples out of the TFDS
"train" split.
The preprocessor `tqa_open_preprocessor` is defined as follows.
```py
def tqa_open_preprocessor(
    dataset: tf.data.Dataset,
    prefix: str = "trivia_qa question: "
) -> tf.data.Dataset:
  """Convert TriviaQA dataset to open domain qa examples.

  The function takes the trivia_qa TFDS dataset and emits examples of the
  form:
  {
    "inputs": "trivia_qa question: What are the names of the Olsen Twins?",
    "targets": "Mary-Kate and Ashley",
    "answers": ["Mary-Kate and Ashley", "Ashley and Mary-Kate"]
  }

  Args:
    dataset: a tf.data.Dataset to process.
    prefix: str, prefix to prepend to the inputs.

  Returns:
    a tf.data.Dataset
  """
  def tqa_map(ex):
    """Map TriviaQA example to text-to-text example."""
    return {
        "inputs": prefix + ex["question"],
        "targets": ex["answer"]["value"],
        "answers": ex["answer"]["aliases"],
    }

  return dataset.map(tqa_map, num_parallel_calls=tf.data.experimental.AUTOTUNE)
```
Or with the `seqio.map_over_dataset` decorator, we have
```py
from typing import Mapping


def tqa_open_preprocessor(
    dataset: tf.data.Dataset,
    prefix: str = "trivia_qa question: "
) -> tf.data.Dataset:

  @seqio.map_over_dataset
  def tqa_map(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Map TriviaQA example to text-to-text example."""
    return {
        "inputs": prefix + ex["question"],
        "targets": ex["answer"]["value"],
        "answers": ex["answer"]["aliases"],
    }

  return tqa_map(dataset)
```
Here we made a thin wrapper to emphasize that the function decorated by
`seqio.map_over_dataset` takes in an instance of `tf.data.Dataset`. In practice,
this wrapper is not necessary.
The postprocessor for this example is `tqa_open_postprocessor`, which is defined
as follows:
```py
def tqa_open_postprocessor(output_or_target, example=None, is_target=False):
  """Returns output as answer, or all answers if the full example is provided."""
  if is_target:
    return [a.decode("utf-8") for a in example["answers"]]
  else:
    return output_or_target.decode("utf-8")
```
When processing the target, we ignore `output_or_target` (equivalent to
`example["targets"]`) since it is just selecting a single answer in
`trivia_qa_open`. Instead, we extract the full list of answers from the example
and convert them from bytes to text. When handling the model output, we simply
convert it to text from detokenized bytes.
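As a quick usage illustration, the postprocessor can be called by hand on toy byte strings. The example dictionary below is invented and merely stands in for what would be passed during evaluation; the function body is repeated so that the snippet runs on its own.

```python
def tqa_open_postprocessor(output_or_target, example=None, is_target=False):
  """Returns output as answer, or all answers if the full example is provided."""
  if is_target:
    return [a.decode("utf-8") for a in example["answers"]]
  else:
    return output_or_target.decode("utf-8")


# Toy example with byte-string fields (an assumption for this sketch).
example = {
    "targets": b"Mary-Kate and Ashley",
    "answers": [b"Mary-Kate and Ashley", b"Ashley and Mary-Kate"],
}

# Target side: the single "targets" value is ignored in favor of all answers.
target = tqa_open_postprocessor(
    example["targets"], example=example, is_target=True)
# Prediction side: the detokenized model output is simply decoded to text.
prediction = tqa_open_postprocessor(b"Ashley and Mary-Kate", example=example)
print(target, prediction)
```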
The metric function `tqa_metric` is defined as:
```py
import re
import string
from typing import Mapping, Sequence

import numpy as np
import seqio


def tqa_metric(
    targets: Sequence[Sequence[str]],
    predictions: Sequence[str]
) -> Mapping[str, seqio.metrics.MetricValue]:
  """Computes official TriviaQA metrics.

  Args:
    targets: list of lists of strings
    predictions: list of strings

  Returns:
    dict with score_key: squad score across all targets and predictions
  """
  if len(targets) != len(predictions):
    raise ValueError("Number of targets and predictions must match.")

  def _normalize_answer(text):
    """Lower text and remove punctuation, articles and extra whitespace."""
    # Lowercase, as promised by the docstring.
    text = text.lower()
    # Remove articles.
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    # Remove punctuation.
    for punc in string.punctuation:
      text = text.replace(punc, "")
    # Normalize white space.
    text = " ".join(text.split())
    return text

  # Normalize answers before comparing.
  targets = [[_normalize_answer(t) for t in u] for u in targets]
  predictions = [_normalize_answer(p) for p in predictions]

  # A prediction counts as correct if it matches any of its ground truths.
  em = np.mean([
      max(pred == gt for gt in ground_truths)
      for pred, ground_truths in zip(predictions, targets)
  ])
  return {
      "exact_match": seqio.metrics.Scalar(em),
  }
```
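To see what the normalization buys, here is a small worked example in plain Python, without NumPy or SeqIO. The helper `normalize_answer` mirrors the steps described in the docstring above (the lowercasing step is taken from that docstring), and the targets and predictions are invented.

```python
import re
import string


def normalize_answer(text: str) -> str:
  """Lowercases, then removes articles, punctuation and extra whitespace."""
  text = text.lower()
  text = re.sub(r"\b(a|an|the)\b", " ", text)
  for punc in string.punctuation:
    text = text.replace(punc, "")
  return " ".join(text.split())


targets = [["Mary-Kate and Ashley", "Ashley and Mary-Kate"], ["The Beatles"]]
predictions = ["mary-kate and ashley.", "Rolling Stones"]

norm_targets = [[normalize_answer(t) for t in u] for u in targets]
norm_predictions = [normalize_answer(p) for p in predictions]

# A prediction is correct if it equals any of its normalized ground truths.
matches = [
    any(pred == gt for gt in ground_truths)
    for pred, ground_truths in zip(norm_predictions, norm_targets)
]
exact_match = sum(matches) / len(matches)
print(norm_predictions, exact_match)
```

The first prediction differs from its ground truth only in casing and punctuation, so it counts as a match after normalization, while the second does not match.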
## Citing SeqIO
Please use the following bibtex entry to cite SeqIO.
```
@article{roberts2022t5x,
url = {https://arxiv.org/abs/2203.17189},
author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra,
Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester,
Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and
Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and
Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and
Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and
Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy,
Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and
Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and
Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel,
Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov,
Alexander and Newlan, Joshua and Gesmundo, Andrea},
title = {Scaling Up Models and Data with $\texttt{t5x}$ and $\texttt{seqio}$},
journal={arXiv preprint arXiv:2203.17189},
year = {2022},
}
```
Raw data
{
"_id": null,
"home_page": "https://github.com/google/seqio/tree/nightly",
"name": "seqio-nightly",
"maintainer": null,
"docs_url": null,
"requires_python": null,
"maintainer_email": null,
"keywords": "sequence preprocessing nlp machinelearning",
"author": "Google Inc.",
"author_email": "no-reply@google.com",
"download_url": "https://files.pythonhosted.org/packages/6d/1f/967e04296da645633e33b0098d454d8b026745850438624e0623a5ecf529/seqio_nightly-0.0.18.dev20241219.tar.gz",
"platform": null,
"description": "# SeqIO\n\n*Task-based datasets, preprocessing, and evaluation for sequence models*\n\n*Go to [SeqIO ReadTheDocs Documentation Page](https://seqio.readthedocs.io/).*\n\n\n## Overview\n\n**SeqIO** is a library for processing sequential data to be fed into downstream\nsequence models. It uses\n[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)\nto create scalable data pipelines but requires minimal use of TensorFlow. In\nparticular, with one line of code, the returned dataset can be transformed to a\nnumpy iterator and hence it is fully compatible with other frameworks such as\n[JAX](https://github.com/google/jax) or\n[PyTorch](https://pytorch.org/).\n\nSeqIO assumes that the dataset is a sequence. Modalities such as text or audio\nare naturally supported. Images are supported as long as they are represented as\nsequences (e.g., [Image GPT](http://proceedings.mlr.press/v119/chen20s.html)).\n\nSeqIO is a refactor of the\n[`t5.data`](https://github.com/google-research/text-to-text-transfer-transformer/)\nlibrary used (in conjunction with the\n[Mesh Tensorflow](https://github.com/tensorflow/mesh) Transformer\nimplementation) to train the T5 models introduced in [*Exploring the Limits of\nTransfer Learning with a Unified Text-to-Text\nTransformer*](https://arxiv.org/abs/1910.10683).\n\nIf you have used `t5.data` in the past and want to know how SeqIO differs,\nplease read [this section](#differences-from-t5data).\n\n## Installation\n\n### From Pypi\n\n```sh\npip install seqio\n```\n\n### From Source\n\n```sh\ngit clone https://github.com/google/seqio.git\ncd seqio\npip install -e .\n```\n\n## Usage Tutorial\n\nAt a high level, we use SeqIO with the following steps.\n\n1. Define a `Task` (and optionally a `Mixture`).\n\n1. Define (or use an existing) a `FeatureConverter` based on the model\n architecture.\n\n1. 
Use the top-level function `seqio.get_dataset` to obtain the\n `tf.data.Dataset` instance.\n\nWe will look at each of these steps in detail.\n\n\n### Defining a `Task`\n\nThe most important class in SeqIO is the `Task`. It is an abstraction that\ncombines:\n\n * a raw *data source*\n * one or more *preprocessing* steps\n * a *vocabulary* to tokenize/detokenize each preprocessed feature for the\n model\n * a *postprocessor* to convert detokenized model outputs into a format for\n evaluation\n * one or more *metrics* to evaluate with\n\nOftentimes a `Task` lines up with a common benchmark. In this tutorial, we use\n[WMT 19 English-German](http://www.statmt.org/wmt19/translation-task.html)\nmachine translation task. In the end, our `Task` will look like this:\n\n\n```py\nseqio.TaskRegistry.add(\n \"wmt19_ende\",\n seqio.TfdsDataSource(tfds_name=\"wmt19_translate/de-en:1.0.0\"),\n preprocessors=[\n functools.partial(\n translate, source_language='en', target_language='de'),\n seqio.preprocessors.tokenize, seqio.preprocessors.append_eos\n ],\n output_features={\n 'inputs':\n seqio.Feature(\n seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),\n add_eos=False,\n dtype=tf.int32),\n 'targets':\n seqio.Feature(\n seqio.SentencePieceVocabulary('/path/to/targets/vocab'),\n add_eos=True,\n dtype=tf.int32),\n },\n metric_fns=[bleu])\n```\n\nWe typically add the `Task` to the global registry when we define it (as shown\nabove) to make it easier to use with model configs and flags. Thus, it must\nhave a unique string name (`\"wmt19_ende\"` in this case). 
Note, however, that\nyou may also instantiate a `seqio.Task` directly without adding it to the\nregistry, if desired.\n\nWe'll now break down each part of the task definition.\n\n#### Data Source\n\nData sources are the first step in your pipeline, providing a way to load raw\ndata in many formats as a `tf.data.Dataset`.\nAll data sources are subclasses of the `DataSource` base class and are defined\nin\n[dataset_providers](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py).\n\nExisting implementations include:\n\n * `TfdsDataSource` for loading examples from\n [TensorFlow Datasets](https://www.tensorflow.org/datasets).\n * `TextLineDataSource` for loading examples from text files (e.g., tsv).\n * `TFExampleDataSource` for loading\n [`tf.train.Example`](https://www.tensorflow.org/tutorials/load_data/tfrecord)\n protos from a file (e.g. a `TFRecord` file.)\n * `FunctionDataSource` for providing an custom function that returns a\n `tf.data.Dataset`.\n\nIn our example, we are using the `TfdsDataSource`. We specify the name of the\nWMT dataset in TFDS\n([`\"wmt19_translate\"`](https://www.tensorflow.org/datasets/catalog/wmt19_translate)),\nthe specific config for the language pair that excludes the context for the open\ndomain setting (`\"de-en\"`), and the version number (`\"1.0.0\"`).\n\n#### Output Features\n\nThe `output_features` field expects a dictionary that maps string feature names\nto `seqio.Feature` objects. This defines what the `Task` is expected to produce\nin its output examples. The output examples *may* contain additional fields, but\nthey *must* contain these fields in the specified format or exceptions will be\nraised.\n\nEach `Feature` includes:\n\n* A `vocabulary`, which must subclass\n [`seqio.Vocabulary`](https://github.com/google/seqio/tree/main/seqio/vocabularies.py),\n to specify how the feature can be tokenized and detokenized. 
You may use\n `seqio.PassThroughVocabulary` if tokenization is not necessary.\n* `add_eos`, which specifies whether the feature should end with the\n vocabulary's EOS token.\n* The output `dtype` which must be a `tf.dtypes.DType`.\n\n**Note:** specifying these options on `Feature` does not by itself ensure the\nproper transformations are applied -- you must also include the necessary\npreprocessors.\n\nThe [tasks used in T5](TODO) all produce \"inputs\" and \"targets\" features to be\nconsumed by the text-to-text model. For a decoder-only language model, only a\nsingle feature (e.g., \"targets\") would be necessary. Nevertheless, SeqIO is\nflexible enough to generate arbitrary output features what will be converted\ninto model features by the [`FeatureConverter`](#featureconverter) later in the\npipeline.\n\n#### Preprocessors\n\nPreprocessors are functions that transform one `tf.data.Dataset` into a new\n`tf.data.Dataset`. Typically this involves executing a `map` over the given\ndataset. The preprocessors provided to the `Task` will be executed sequentially.\n\nAs an example, let's look at the previously undefined `translate` from the\n\"wmt19_ende\" example above.\n\n```py\ndef translate(dataset: tf.data.Dataset,\n source_language: str,\n target_language: str) -> tf.data.Dataset:\n def _translate(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:\n \"\"\"Convert a translation example to a text2text pair.\n\n For example, say the dataset returns examples of this format:\n {'de': 'Das ist gut.', 'en': 'That is good.'}\n If source_language = 'de', target_language = 'en', then the outputs will\n have the format:\n {'inputs': 'translate de to en: Das ist gut.',\n 'targets': 'That is good.'}\n\n Args:\n ex: an example to process.\n source_language: source language code (e.g. 'en') to translate from.\n target_language: target language code (e.g. 
'de') to translate to.\n\n Returns:\n A preprocessed example with the format listed above.\n \"\"\"\n src_str = f'translate {source_language}'\n tgt_str = f' to {target_language}: '\n return {\n 'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),\n 'targets': ex[target_language],\n }\n\n return dataset.map(_translate,\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\n```\n\nThe TFDS dataset provides the dataset where each example has the form: `{'de':\n'Das ist gut.', 'en': 'That is good.'}`. We convert this to \"inputs\" and\n\"targets\" with the appropriate prompt to inform the model of the task.\n\n\nA few **important** notes:\n\n1. When instantiating a `Task`, the preprocessor functions can have the\n following arguments: `dataset`, `output_features`, and `sequence_length`.\n The first (positional) dataset argument is always required. If an argument\n named `output_features` is provided, the\n [output feature mapping](#output-features) will be passed to the\n preprocessor. If `sequence_length` is provided, a mapping from feature name\n to its *maximum* final sequence length\n ([provided by the caller](#getting-a-preprocessed-dataset)) will be\n passed -- any sequences that are too long after preprocessing will be\n automatically truncated. If a preprocessor function does have other\n arguments, they must have default values or be bound (e.g., with\n `functools.partial` as used in `translate`) before instantiating the `Task`.\n\n1. Mapping functions operate on and return `tf.Tensor`s using TensorFlow\n operations. 
This is more flexible than it may sound:\n\n * Automatic\n [AutoGraph](https://www.tensorflow.org/guide/function#autograph_transformations)\n conversion allow you to write python control flow in your\n transformations.\n * [tf.experimental.numpy](https://www.tensorflow.org/guide/tf_numpy)\n provides a numpy interface.\n * [`tf.py_function`](https://www.tensorflow.org/api_docs/python/tf/py_function)\n allows you to wrap arbitrary Python code. Note: `tf.data` pipelines\n using this function can only be run in the python process where they\n were defined, and performance is limited by the python GIL.\n\n See `tf.data.Dataset`\n [documentation](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)\n for more details.\n\n1. When calling `map`, it is important to **always** set\n `num_parallel_calls=tf.data.experimental.AUTOTUNE` to avoid creating a\n bottleneck. The `seqio.map_over_dataset` decorator helps enforce this as\n follows.\n\n ```py\n @seqio.map_over_dataset\n def translate(ex: Mapping[str, tf.Tensor],\n source_language: str,\n target_language: str) -> Mapping[str, tf.Tensor]:\n \"\"\"Convert a translation dataset to a text2text pair.\n\n For example, say the dataset returns examples of this format:\n {'de': 'Das ist gut.', 'en': 'That is good.'}\n If source_language = 'de', target_language = 'en', then the outputs will\n have the format:\n {'inputs': 'translate German to English: Das ist gut.',\n 'targets': 'That is good.'}\n\n Args:\n ex: an example to process.\n source_language: source language code (e.g. 'en') to translate from.\n target_language: target language code (e.g. 'de') to translate to.\n\n Returns:\n A preprocessed example with the format listed above.\n \"\"\"\n src_str = f'translate {source_language}'\n tgt_str = f' to {target_language}: '\n return {\n 'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),\n 'targets': ex[target_language],\n }\n ```\n\n Note that `translate` takes as input an individual example. 
Then\n `seqio.map_over_dataset` decorates it to a function that takes in a\n `tf.data.Dataset` instance.\n\n1. Stochastic operations must be\n [stateless](https://www.tensorflow.org/guide/random_numbers#stateless_rngs)\n if deterministic pipelines are needed. To get (optionally deterministic)\n seeds for these operations, use the `seqio.map_over_dataset(num_seeds=n)`\n decorator. For example:\n\n ```py\n def random_chunk(\n dataset: tf.data.Dataset,\n sequence_length: Mapping[str, int]\n ) -> tf.data.Dataset:\n \"\"\"Takes a random chunk out of each feature with size `sequence_length`.\"\"\"\n\n @seqio.map_over_dataset(num_seeds=1)\n def take_chunk(\n ex: Mapping[str, tf.Tensor],\n seed\n ) -> Mapping[str, tf.Tensor]:\n new_ex = {}\n for k, v in ex.items():\n if k in sequence_length:\n length = sequence_length[k]\n start_idx = tf.random.stateless_uniform(\n (), seed, 0, tf.size(v) - (length + 1))\n new_ex[k] = v[start_idx:start_idx+length]\n else:\n new_ex[k] = v\n return new_ex\n\n return take_chunk(dataset)\n ```\n\n If `num_seeds > 1`, the arg will instead be called `seeds` and will contain\n a sequence of seeds.\n\nIn our \"wmt_19_ende\" task, we also use the predefined preprocessors\n`seqio.preprocessors.tokenize` and `seqio.preprocessors.append_eos`. The former\nuses each `Feature.vocabulary` to tokenize it, and the the latter appends\n`Feature.vocabulary.eos_id` to the feature if the `Feature.add_eos` is True. See\n[preprocessors.py](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) for\ntheir implementations and other useful preprocessors.\n\n#### Postprocessor\n\nDuring evaluation, the model outputs are first detokenized using the output\nfeature vocabulary. Before passing these predictions to the metric functions,\nthey can be run through a Python postprocessing function, alongside the full\ninput example. Similarly, the raw targets are run through this function before\nbeing passed to the metrics. 
Since the postprocess function is used on both the\nmodel output and the targets, it is passed an `is_target` boolean in case the\nbehavior should be different. It is also passed the fully preprocessed example,\nincluding fields that were excluded from `output_features`.\n\nFor the \"wmt19_ende\", we don't need any postprocessors. See \"trivia_qa_open\"\ntask in the [Advanced Postprocessing `Task`](#advanced-postprocessing-task) for\nan example postprocessor.\n\n#### Metrics\n\nMetrics are functions that are passed (by the [Evaluator](#evaluator)) the\nfully-materialized list of postprocessed model outputs (or scores) and targets\nand return a mapping from string names to `MetricValue` objects containing their\nvalues. These are most commonly floating-point scalars, but may also be text,\nimages, audio, histograms, etc (see\n[metrics.py](https://github.com/google/seqio/tree/main/seqio/metrics.py) for the full list).\n\nThe first argument of a metric function must always be called `targets`. If the\nsecond argument of a metric function is called `predictions`, it will be passed\nthe decoded and detokenized model prediction. If it is called `scores`, it will\nbe passed a list of log-likelihood scores for each example.\n\nIf multiple metric functions are provided, they will all be used and their\nreturned mappings merged.\n\n##### Prediction Metrics\n\nPrediction metrics are computed using the postprocessed targets and model\noutputs (predictions). The args must be named `targets` and `predictions`.\n\nLet's look at the metric function used for \"wmt19_ende\" task. 
A standard metric\nfor the translation task is BLEU and we use `sacrebleu` implementation.\n\n```py\ndef bleu(targets: Sequence[str], predictions: Sequence[str]):\n \"\"\"Computes BLEU score.\n\n Args:\n targets: list of strings or list of list of strings if multiple references\n are present.\n predictions: list of strings\n\n Returns:\n bleu_score across all targets and predictions\n \"\"\"\n if isinstance(targets[0], list):\n targets = [[x for x in target] for target in targets]\n else:\n # Need to wrap targets in another list for corpus_bleu.\n targets = [targets]\n\n bleu_score = sacrebleu.corpus_bleu(predictions, targets,\n smooth_method=\"exp\",\n smooth_value=0.0,\n force=False,\n lowercase=False,\n tokenize=\"intl\",\n use_effective_order=False)\n return {\"bleu\": bleu_score.score}\n```\n\n\n##### Score Metrics\n\nScore metrics are computed using the postprocessed targets and their\nlog-likelihood scores according to the model. The args must be named `targets`\nand `scores`.\n\n```py\ndef perplexity(targets: Sequence[str], scores: Sequence[float]):\n return {\n \"perplexity\": seqio.metrics.Scalar(np.exp(np.mean(scores)))\n }\n```\n\n### Defining a `Mixture`\n\nOnce you have multiple `Task`s added to the `TaskRegistry`, you can define\n`Mixture`s that will combine the examples from them according to some specified\nrate. Examples will then be sampled from each task in proportion to its rate.\n\nAs an example, [Multilingual T5](http://goo.gle/mt5) uses a `Mixture` of\nper-language `Task`s with tail languages up-weighted in the mixture.\n\nThere are 3 ways to specify the tasks and their rates:\n\n1. Provide a rate along with each task's name (rates are normalized before\n sampling). In this example, the rates provided are units of the final\n mixture that come from the component tasks. Here, 1/(1+7) of the final\n mixture will come from \"task1\".\n\n ```py\n seqio.MixtureRegistry.add(\n \"mix1\",\n [(\"task1\", 1), (\"task2\", 7)]\n )\n ```\n\n1. 
Provide a constant default rate for some or all tasks, which will be used\n    when only the name is provided. The example below will produce the same\n    mixing rates as the previous one.\n\n    ```py\n    seqio.MixtureRegistry.add(\n        \"mix1\",\n        [(\"task1\", 0.5), \"task2\"],\n        default_rate=3.5\n    )\n    ```\n\n1. Provide a function that generates the rate for each task at runtime. The\n    example below uses the provided\n    [`seqio.mixing_rate_num_examples`](https://github.com/google/seqio/tree/main/seqio/utils.py),\n    which uses the number of examples (computed during\n    [offline caching](#optional-offline-caching)) as the rate for each task.\n\n    ```py\n    seqio.MixtureRegistry.add(\n        \"mix2\",\n        [\"task1\", \"task2\"],\n        default_rate=seqio.mixing_rate_num_examples\n    )\n    ```\n\nYou can also include `Mixture`s in your `Mixture`! For example, the following\nmixture would contain 1/24 (from \"mix1\") + 1/3 of \"task1\", 7/24 (from \"mix1\") of\n\"task2\", and 1/3 of \"task3\".\n\n```py\nseqio.MixtureRegistry.add(\n    \"mix3\",\n    [\"mix1\", \"task1\", \"task3\"],\n    default_rate=1\n)\n```\n\nIf sampling without replacement is important for your task, you can achieve that\nby using either deterministic tasks or using dataset checkpointing (and not\nrunning more than an epoch) for a non-deterministic task. Otherwise, the mixture\nmay sample with replacement.\n\n### Getting a Preprocessed Dataset\n\nNow that your `Task` (and/or `Mixture`) is defined, its primary use is\nto generate a dataset.\n\nYou may first need to use `seqio.get_mixture_or_task(mixture_or_task_name)` to\naccess your dataset provider from the registry.\n\nAfter that, you can call `get_dataset` to build the `tf.data.Dataset`. 
For\nexample:\n\n```py\ndataset = seqio.get_mixture_or_task(\"mix1\").get_dataset(\n sequence_length={\"inputs\": 256, \"targets\": 128},\n split=\"train\",\n shuffle=True,\n num_epochs=1,\n shard_info=seqio.ShardInfo(index=0, num_shards=10),\n use_cached=False,\n seed=42\n)\n\n# Print the first 5 examples.\nfor _, ex in zip(range(5), dataset.as_numpy_iterator()):\n print(ex)\n```\n\nSome notes on a few of the arguments:\n\n* `sequence_length`: An *optional* mapping from feature name to *maximum*\n length. Will be passed to the preprocessors with a `sequence_length`\n argument. If not `None`, the final example features will be truncated if\n they exceed the specified length. Note that this value may be required to be\n set if any of the preprocessors use the `sequence_length` argument and do\n not handle the `None` case.\n* `num_epochs`: The number of times to repeat the source dataset.\n Preprocessing will be re-applied with new seeds to enable new samples from\n stochastic steps. Note that if the `CacheDatasetPlaceholder` is included\n (see below) preprocessing is only re-applied after that step.\n* `shard_info`: An optional sharding specification for loading a deterministic\n subset of the dataset. Loading will be most efficient if the number of\n shards evenly divides the number of shards in the raw data source.\n* `use_cached`: Specifies whether to load from a pre-cached task for increased\n performance or to do the preprocessing on-the-fly. See the\n [following section](#optional-offline-caching) for details on how to cache\n your task, which must be done before this can be set to `True`.\n* `seed`: An optional seed to use for deterministic shuffling and (stateless)\n stochastic ops. These operations will still be pseudorandom but will be\n reproducible with the same seed. 
Set to `None` if determinism is not\n desired.\n\n### (Optional) Offline Caching\n\nFor improved performance at load time and to avoid redundant computations for\ncommonly used tasks, you can pre-cache your `Task` with all or part of the\npreprocessing done in advance of training; this partial preprocessing is\nespecially useful if the Task is stochastic and one wishes to cache the\ndeterministic operations while running the stochastic ones on the fly. Caching\nstochastic SeqIO Mixtures in this way is not supported.\n\nThe first step to doing so is to add a\n`seqio.CacheDatasetPlaceholder(required=False)` as one of the steps in your\npreprocessing pipeline. All steps before the placeholder will be cached offline\nand all steps after will be executed on the fly at load time. You may set\n`required=True` if you want `get_dataset` to fail unless `use_cached=True`.\n\nCaveats:\n\n* Any stochastic operations that you wish to be re-run when `num_epochs > 1`\n or with a different `seed` *should* go after the placeholder since only a\n single sample will be cached.\n* Any preprocessing steps that use the `sequence_length` argument *must* come\n after the `seqio.CacheDatasetPlaceholder` preprocessor since this is only\n known at runtime, or an exception will be raised. If you wish to cache for a\n specific sequence length, you can use\n [`seqio.experimental.add_fully_cached_task`](https://github.com/google/seqio/tree/main/seqio/experimental.py).\n\nOnce your `Task` is registered, you can run\n[`cache_tasks_main`](https://github.com/google/seqio/tree/main/seqio/scripts/cache_tasks_main.py)\nto execute the offline preprocessing, providing it with the module containing\nyour task definitions via the `--module_import` flag. 
For very large datasets,\nit's recommended you run this [Apache Beam](https://beam.apache.org/) script on\na distributed framework like\n[Google Cloud Dataflow](https://beam.apache.org/documentation/runners/dataflow/).\n\nFinally, you are ready to load the cached version of your `Task` (or a `Mixture`\ncontaining it). You will need to add the path to the directory you passed to\n`--output_cache_dir` via `seqio.add_global_cache_dirs([\"/my/cache/dir\"])`. Now\nwhen you call `task_or_mixture.get_dataset(..., use_cached=True)`, the data will\nbe loaded from the cache directory instead of the raw data source.\n\n### Feature Converters\n\nThe role of a `Task` is to provide the dataset object with as few\nmodel-specific features as possible (e.g., generic \"inputs\" and \"targets\"), while the Feature\nConverters transform the model-agnostic features to model-specific features\n(e.g., \"encoder_input_tokens\"). We refer to the former as \"task features\" and\nthe latter as \"model features\".\n\nLet's use machine translation (English to German) as a running example.\n\nThe raw data consists of sentence pairs such as\n\n```\n\"That is good\\tDas ist gut.\"\n```\n\nA task registered to `Task` (e.g.,\n[wmt_t2t_ende_v003](t5/data/tasks.py?l=156&rcl=337594707))\nreads these sentence pairs from the data source and applies a series of\n[preprocessors](t5/data/preprocessors.py?rcl=343354647).\nOne of the internal representations looks like\n\n```python\n{\"inputs\": \"translate English to German: That is good.\",\n \"targets\": \"Das ist gut.\"}\n```\n\nThe final output from the `Task` is a tokenized version of the parallel\nsentences. 
In the following toy example (the token ids do not correspond to the\nabove string example), the dataset consists of 2 examples.\n\n```python\ndataset = [{\"inputs\": [7, 8, 5], \"targets\": [3, 9]},\n           {\"inputs\": [8, 4, 9, 3], \"targets\": [4]}]\n```\n\nThe dataset is a `tf.data.Dataset` (i.e., each example is a dictionary with\n\"inputs\" and \"targets\" fields).\n\nThe `FeatureConverter` then takes this as an input and converts it to the\nmodel-specific features. In addition, the feature converter performs padding and\noptionally packing (for model implementations that support it) for efficiency.\nFor example, let's assume that we are using the standard Transformer\narchitecture with an encoder and a decoder. The output of the feature converter\nis\n\n```python\nconverted_dataset = [{\n    \"encoder_input_tokens\": [7, 8, 5, 1, 8, 4, 9, 3, 1, 0],\n    \"encoder_segment_ids\": [1, 1, 1, 1, 2, 2, 2, 2, 2, 0],\n    \"encoder_positions\": [0, 1, 2, 3, 0, 1, 2, 3, 4, 0],\n    \"decoder_target_tokens\": [3, 9, 1, 4, 1, 0, 0],\n    \"decoder_input_tokens\": [0, 3, 9, 0, 4, 0, 0],\n    \"decoder_loss_weights\": [1, 1, 1, 1, 1, 0, 0],\n    \"decoder_positions\": [0, 1, 2, 0, 1, 0, 0],\n    \"decoder_segment_ids\": [1, 1, 1, 2, 2, 0, 0],\n}]\n```\n\nIn this case, two task examples are packed into one. `*_segment_ids` and\n`*_positions` are the fields used to denote the membership and position of each packed\ntoken in the original sequence. The EOS ids (i.e., 1) are appended. In addition,\neach field is padded to the specified length.\n\nWe will look at the details of this example in the Encoder-decoder architecture:\n`seqio.EncDecFeatureConverter` section.\n\n\n#### Feature converters provided out of the box\n\nWe provide feature converters for three common architectures: encoder-decoder,\ndecoder-only and encoder-only. 
Here we describe how users can use the feature\nconverters for each of these architectures out of the box as a part of the SeqIO\nlibrary.\n\nIn the SeqIO library, each architecture has a class defining how the task\nfeatures are converted to model features. Since these feature converters are\nalready implemented, it is straightforward to use them by passing an instance of the class as\nthe `feature_converter` argument of the `seqio.get_dataset` function. The\nfollowing sections show example usage of `seqio.get_dataset`.\n\n##### Encoder-decoder architecture: `seqio.EncDecFeatureConverter`\nThis is the architecture of the original Transformer paper. For the\nEnglish-to-German translation task, the following function call retrieves the\n`tf.data.Dataset` object with the model features.\n\n```python\ndataset: tf.data.Dataset = seqio.get_dataset(\n    mixture_or_task_name=\"wmt_t2t_ende_v003\",\n    task_feature_lengths={\"inputs\": 32, \"targets\": 32},\n    dataset_split=\"train\",\n    shuffle=True,\n    feature_converter=seqio.EncDecFeatureConverter(pack=True)\n)\n```\n\nThe resulting dataset object has the following 8 fields:\n\n|Feature name | Explanation |\n|----------------------|---------------------------|\n|`encoder_input_tokens` | Input tokens to the encoder. |\n|`encoder_positions` | Position index in the sequence before packing.|\n|`encoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer belong to the same sequence before packing. |\n|`decoder_input_tokens` | Input tokens to the decoder. |\n|`decoder_target_tokens`| Output tokens from the decoder. |\n|`decoder_loss_weights` | A weight on each position that can be used as a mask. |\n|`decoder_positions` | Position index in the sequence before packing. 
|\n|`decoder_segment_ids` | Same as `encoder_segment_ids` but for the decoder.|\n\n##### Decoder-only architecture\n\nThis architecture consists of a single autoregressive stack, which we denote as\na \"decoder\".\n\nA decoder autoregressively produces an output sequence.\nTherefore, it can be used as a standard language model if the task dataset has\nonly \"targets\" features, i.e., self-supervised. If the task dataset also has an\n\"inputs\" field, e.g., supervised machine translation, the decoder can still be\nused by concatenating the inputs and targets fields. See [Raffel et al.\n(2020)](https://arxiv.org/abs/1910.10683), Section 3.2.1 for a more detailed take\non this topic.\n\nWe support both use cases and refer to the former as *standard language model*\nand the latter as *prefix language model*. Each of these models is described\nseparately below.\n\nNote that we do not provide special features to denote how the dataset should be\nconsumed. For example, a Transformer-based fully autoregressive decoder has a\nfully-causal self-attention layer. Since there are many ways of implementing the\nmasking pattern for such an attention layer and, more importantly, SeqIO is not\nlimited to attention-based models, we leave it up to the model implementations\nto apply the masking pattern. 
There is one exception, and we cover this in\nthe Prefix LM section below.\n\nA common use pattern is to pretrain a decoder model with the left-to-right\nlanguage modeling objective (unsupervised) using `seqio.LMFeatureConverter` and\nthen fine-tune (supervised) using `seqio.PrefixLMFeatureConverter`.\n\n\n###### Standard LM\n\nFor the standard language model, the task dataset only has a \"targets\" field.\nTherefore, the sequence length specification only needs to specify targets.\n\n```python\ndataset: tf.data.Dataset = seqio.get_dataset(\n    mixture_or_task_name=\"standard_lm\",\n    task_feature_lengths={\"targets\": 32},\n    dataset_split=\"train\",\n    shuffle=True,\n    feature_converter=seqio.LMFeatureConverter(pack=True)\n)\n```\n\nNote that \"standard_lm\" is not a registered task in the codebase. It is the\nleft-to-right language modeling task, i.e., predict the next token given the\nprevious tokens on some language corpus (e.g.,\n[C4](https://www.tensorflow.org/datasets/catalog/c4)).\n\nThe output dataset has the following model features.\n\n|Feature name | Explanation |\n|----------------------|---------------------------|\n|`decoder_target_tokens`| Output tokens from the decoder |\n|`decoder_input_tokens` | Input tokens to the decoder |\n|`decoder_loss_weights` | Binary mask to indicate where the loss should be taken |\n|`decoder_positions` | Position index in the sequence before packing|\n|`decoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer belong to the same sequence before packing. |\n\n`decoder_input_tokens` is `decoder_target_tokens` shifted right by one position, as used in\nstandard teacher-forced autoregressive training.\n\n\n\n###### Prefix LM: `seqio.PrefixLMFeatureConverter`\n\nIf the input dataset has a notion of \"inputs\" and \"targets\", we can concatenate\nthem so that we can still use a single stack decoder. 
Therefore, the output only\ncontains \"targets\" just like the standard LM case.\n\nWe use the same toy example for the English-to-German translation task as a running\nexample:\n\n```\n{\"inputs\": \"translate English to German: That is good.\",\n \"targets\": \"Das ist gut.\"}\n```\n\nTo be consumed by the decoder-only stack, `seqio.PrefixLMFeatureConverter`\nconcatenates them to form the new \"targets\". Consider a 2-layer decoder architecture\nwhose activations are shown below:\n\n```\n\nThat  is    good  <EOS> Das   ist   gut   <EOS>\n |     |     |     |     |     |     |     |\n u1    u2    u3    u4    u5    u6    u7    u8\n |     |     |     |     |     |     |     |\n v1    v2    v3    v4    v5    v6    v7    v8\n |     |     |     |     |     |     |     |\n<BOS> That  is    good  <EOS> Das   ist   gut\n\n```\n\nLet us denote the first layer's activation in the `i`th position as `vi`.\nSimilarly, let `ui` denote the activation of the second layer in the `i`th\nposition.\n\n\nFor attention-based sequence models such as Transformer decoders, the\nself-attention layer is used to encode a contextualized representation of the\nsequence. At a given layer, each position's representation is computed as a\nfunction of the representations of the tokens *at or before* its position in the\nprevious layer.\n\nReferring to the toy example, when computing `u2` with fully-causal masking, we\ndo not use `v3`. This results in a representation `u2` of the word \"is\" that\ndoes not take into account the word \"good\", which is unnecessarily limiting.\n\nFor Prefix LM, this issue is resolved by having the fully visible masking\npattern for the inputs portion only. For example, when computing `u2`, `v1`,\n`v2`, `v3`, `v4` and `v5` are all visible and taken into account. For the tokens\nin the \"targets\" of the `Task` dataset, we use causal masking. 
For example,\nwhen computing `u6`, all `vi` for `i <= 6` are taken into account but not `v7`.\n\n<details>\n <summary>Why is `v5` included in the inputs attention pattern?</summary>\n In the same translation example, we note that when computing `u2`, the\n activation corresponding to the position where the \\<EOS\\> token was input (i.e.,\n `v5`) was visible. This doesn't count as \"cheating\" because the model doesn't\n see the next word \"Das\". This can provide additional context in building the\n representation for \"good\". In this case, `u4` has the context that \"good\" is\n the last word in the sentence.\n</details>\n\n`seqio.PrefixLMFeatureConverter` provides a feature `decoder_causal_attention`\nto encode this information. For the above example, we have\n\n\n```\ndecoder_causal_attention = [1, 1, 1, 1, 1, 0, 0, 0]\n```\n\nindicating that the non-causal attention can be applied to the first five\npositions. Note that this feature may seem trivial, but for a packed dataset\nthe boundaries between inputs and targets are more nuanced.\n\n\nA final consideration for the prefix LM is that because we concatenate \"inputs\"\nand \"targets\", which tokens are used for the loss computation is a modeling\ndecision. For example, we can penalize the model only for the \"targets\" tokens\nor we may choose to also penalize building the representation for \"inputs\" tokens.\nThis is controlled by the `loss_on_targets_only` argument (defaults to `True`) of the\n`seqio.PrefixLMFeatureConverter` constructor. 
In the above example, we would get\n\n```\ndecoder_loss_weights = [0, 0, 0, 0, 1, 1, 1, 1]\n```\n\nThis indicates that the last 4 positions are used for the loss computation.\n\nTo get the dataset with prefix LM features, we can use\n\n```python\ndataset: tf.data.Dataset = seqio.get_dataset(\n    mixture_or_task_name=\"wmt_t2t_ende_v003\",\n    task_feature_lengths={\"inputs\": 32, \"targets\": 32},\n    dataset_split=\"train\",\n    shuffle=True,\n    feature_converter=seqio.PrefixLMFeatureConverter(\n        pack=True,\n        loss_on_targets_only=True)\n)\n```\n\nThe resulting features have length 64 because the converter concatenates inputs and targets,\neach with length 32.\n\nThe output dataset has the following model features. Note that the only\nadditional feature is `decoder_causal_attention`.\n\n|Feature name | Explanation |\n|----------------------|---------------------------|\n|`decoder_target_tokens`| Output tokens from the decoder |\n|`decoder_input_tokens` | Input tokens to the decoder |\n|`decoder_loss_weights` | Binary mask to indicate where the loss should be taken |\n|`decoder_positions` | Position index in the sequence before packing|\n|`decoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer belong to the same sequence before packing. |\n|`decoder_causal_attention`| Binary mask denoting which tokens are in the non-causal masking region.|\n\n##### Encoder-only architecture\nLike the decoder-only architecture, this one is a single stack, but not\nautoregressive.\n\nOne notable assumption is that the inputs and targets are *aligned*, i.e., they\nhave the same sequence length and the `i`th position in the targets corresponds to\nthe output representation of the `i`th token in the inputs.\n\nA common model using the encoder-only architecture is\n[BERT](https://arxiv.org/abs/1810.04805). 
We provide the `EncoderFeatureConverter`\nclass to support the Masked Language Modeling (MLM) objective from BERT.\n\nWe assume that a unique sentinel token such as `[MASK]` is used to mask some\nfraction of the input text and the task is to recover the original text.\nTherefore, the \"targets\" are naturally defined as the original text, whereas the\n\"inputs\" are the masked text.\n\nEncoder-only models are often used for classification tasks. In BERT, a special\ntoken `[CLS]` is prepended to the input sequence. The last layer's activation\ncorresponding to this sentinel token is the contextualized representation of the\nsequence. We assume that such a \"classification\" sentinel is prepended.\n\nConsider the following example for the MLM task. The input dataset has two\nexamples, which are packed into one example. We assume that `mask_id = 9` and the\n`[CLS]` token has an id of 8.\n\n```py\ndataset = [{\"inputs\": [8, 9, 9, 3, 4], \"targets\": [8, 7, 4, 3, 4]},\n           {\"inputs\": [8, 3, 9], \"targets\": [8, 3, 6]}]\n\nconverted_dataset = {\n    \"encoder_input_tokens\": [8, 9, 9, 3, 4, 1, 8, 3, 9, 1, 0],\n    \"encoder_target_tokens\": [8, 7, 4, 3, 4, 1, 8, 3, 6, 1, 0],\n    \"encoder_segment_ids\": [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0],\n    \"encoder_positions\": [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0],\n    \"encoder_loss_weights\": [0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],\n}\n```\n\nNote that the packed sequence has the `[CLS]` token at the beginning of each\nsequence. 
Also note that the loss is taken only on the masked positions.\n\nTo use the pre-defined `EncoderFeatureConverter`, provide `mask_id` as an\nargument.\n\n```py\ndataset: tf.data.Dataset = seqio.get_dataset(\n    mixture_or_task_name=\"some mlm task\",\n    task_feature_lengths={\"inputs\": 32, \"targets\": 32},\n    dataset_split=\"train\",\n    shuffle=True,\n    feature_converter=seqio.EncoderFeatureConverter(\n        pack=True,\n        mask_id=9)\n)\n```\n\nThe resulting dataset object has the following 5 fields:\n\n|Feature name | Explanation |\n|----------------------|---------------------------|\n|`encoder_input_tokens` | Input tokens to the encoder |\n|`encoder_positions` | Position index in the sequence before packing|\n|`encoder_segment_ids` | Sequence membership before packing. Two positions with the same positive integer belong to the same sequence before packing. |\n|`encoder_target_tokens`| Output tokens from the encoder |\n|`encoder_loss_weights` | Binary mask to indicate where the loss should be taken |\n\n##### Custom architectures\nFor a custom model architecture, you need to create a subclass of\n`FeatureConverter` and override two methods, `_convert_features` and\n`get_model_feature_lengths`, to define how task features are mapped to the model\nfeatures, including the length relationships. The existing feature converters\n(e.g., `seqio.EncDecFeatureConverter`) follow the same pattern, which can be a\nuseful starting point.\n\n### Evaluation\n\nThe SeqIO `Evaluator` class provides a way to evaluate models on SeqIO Tasks\nand Mixtures. For an interactive walkthrough of SeqIO evaluation, see the\n[Evaluation Notebook](https://github.com/google/seqio/blob/main/seqio/docs/tutorials.md).\nThe following is a deep-dive into the `Evaluator` class.\n\nAn Evaluator instance can be created by passing a SeqIO Task or\nMixture, and additional eval params like feature converter, split, sequence\nlengths, seed, etc. 
The Evaluator init calls `get_dataset` for each Task to be\nevaluated with the appropriate params, creating the `task_dataset`, and invokes\nthe model-specific feature converter on the `task_dataset` to create features\nthat can be passed to a model, called `model_dataset`. Both `task_dataset` and\n`model_dataset` are stored in-memory so that the dataset can be reused across\nmultiple evaluations (e.g. on checkpoints from a training run). Both datasets\nare enumerated so that even if the order of examples is changed during model\ninference, the enumeration can be used to match model outputs to examples from\nthe `task_dataset`.\n\nFor Mixtures, each sub-Task is evaluated separately, regardless of mixing\nrates, because in the context of eval benchmarks, Mixtures commonly refer to a\ncollection of Tasks belonging to that benchmark, each of which is evaluated\nseparately, e.g. SuperGLUE mixture.\n\nOnce an `Evaluator` instance is created with a SeqIO Task or Mixture, a model\ncan be evaluated by calling `evaluator.evaluate(...)` and passing a `predict_fn`\nand/or a `predict_with_aux_fn` and/or a `score_fn` to interact with the model.\n`predict_fn` takes the `model_dataset` as input and outputs a `Sequence[(index,\ntoken_ids)]` where `token_ids` is the sequence of token ids generated by the\nmodel for the input example whose index matches `index`. Therefore, even if\n`predict_fn` mixes the order of the examples during prediction, the order can be\ncorrected as long as the correct index for each example is maintained. A common\nexample is the multi-host setup where the evaluation dataset is split amongst\nmultiple hosts that independently make predictions and combine the results\nduring which the ordering can be mixed. `predict_with_aux_fn` is similar to\n`predict_fn`, except that it can also return a dictionary of auxiliary values\nalong with each sequence of `token_ids`, e.g. 
scores from the generated tokens.\nThe `score_fn` takes the `model_dataset` as input and returns a\n`Sequence[(index, score)]` where `score` is the sequence of log likelihood\nscores for the targets in the dataset. This simple interface allows users to\neasily integrate the SeqIO evaluation flow with popular training frameworks in\nTF and JAX.\n\nCorresponding to the model fns, users can configure three kinds of metric fns in\ntheir Tasks, which are differentiated by their function signature. Metrics\ncomputed on the outputs of `predict_fn` (and `predict_with_aux_fn`) have the\nsignature `targets` and `predictions` (and optionally `aux_values`), while\nmetrics computed on the outputs of `score_fn` have the signature `targets` and\n`scores`. The `Evaluator` takes care of calling the correct model fns and\nmetric fns during evaluation. Here are examples of a prediction metric and a score metric.\n\n```py\ndef sequence_accuracy(targets, predictions):\n  seq_acc = 100 * np.mean([p == t for p, t in zip(predictions, targets)])\n  return {\"sequence_accuracy\": seq_acc}\n\ndef log_likelihood(targets, scores):\n  log_likelihood = np.mean([scipy.special.logsumexp(el) for el in scores])\n  return {\"log_likelihood\": log_likelihood}\n```\n\nThere are 4 steps involved in the evaluation using predicted tokens:\n\n+ the `predict_fn` or `predict_with_aux_fn` returns indices and output_tokens:\n  `Sequence[Tuple[int, Sequence[int]]]`, potentially with some auxiliary\n  values.\n+ output tokens are decoded by `vocab.decode`\n+ postprocessors configured in Tasks are applied to the decoded output. 
These\n  are denoted as predictions.\n+ metric fns configured in Tasks are applied to the predictions and the cached\n  targets.\n\nThere are 2 steps involved in the evaluation using scores:\n\n+ the `score_fn` returns indices and scores:\n  `Sequence[Tuple[int, Sequence[float]]]`\n+ metric fns configured in Tasks are applied to the scores and the cached\n  targets.\n\nTraining codebases like T5X provide integration with SeqIO evaluation to allow\nevaluating checkpoints on SeqIO Tasks and Mixtures. See\n[T5X Eval](https://github.com/google-research/t5x/blob/main/docs/usage/eval.md)\nfor instructions.\n\n## Differences from `t5.data`\n\nThe original `t5` library introduced and implemented the `t5.data.Task`\nabstraction for specifying preprocessing and evaluation metrics for text-to-text\ntasks. When creating a task, users specify a source dataset of raw text, some\npreprocessing steps, a vocabulary for tokenization, and evaluation metrics. The\nfully-specified Task can then be used to pre-train or fine-tune an\nencoder-decoder transformer model. However, the design included many baked-in\nassumptions about the types of tasks users could specify.\n\nSeqIO removes some of the constraints of this abstraction:\n\n* Inputs and outputs are no longer required to be strings (e.g., they may be\n  images or audio).\n* Architectures other than the original encoder-decoder are supported (e.g.,\n  decoder-only language models like GPT or encoder-only models like BERT).\n* Users can control at which stage of the pipeline offline caching occurs.\n* Users can control when and where EOS tokens are added.\n\nFurthermore, SeqIO has been made more modular with respect to the Mesh\nTensorFlow Transformer. 
This allows it to be used with other model\nimplementations with more consistency and much less code duplication.\n\n## Advanced Postprocessing `Task`\n\n### TriviaQA (Closed-book, open-domain version)\nThis version of TriviaQA was introduced in [Roberts et al.\n2020](https://arxiv.org/abs/2002.08910).\n\n\n```py\nseqio.TaskRegistry.add(\n \"trivia_qa_open\",\n source=seqio.TfdsDataSource(\n tfds_name=\"trivia_qa/unfiltered.nocontext:1.1.0\",\n splits={\n \"train\": \"train[:90%]\",\n \"validation\": \"train[90%:]\",\n \"test\": \"validation\"\n }),\n preprocessors=[\n tqa_open_preprocessor,\n seqio.preprocessors.tokenize,\n seqio.preprocessors.append_eos,\n ],\n output_features={\n \"inputs\": seqio.Feature(\n seqio.SentencePieceVocabulary(\"/path/to/inputs/vocab\"),\n add_eos=False, dtype=tf.int32\n ),\n \"targets\": seqio.Feature(\n seqio.SentencePieceVocabulary(\"/path/to/targets/vocab\"),\n add_eos=True, dtype=tf.int32\n ),\n },\n postprocess_fn=tqa_open_postprocessor,\n metric_fns=[tqa_metric])\n```\n\nIn this example, we are using the `TfdsDataSource`. We specify the name of the\nTriviaQA dataset in TFDS\n([`\"trivia_qa\"`](https://www.tensorflow.org/datasets/catalog/trivia_qa)), the\nspecific config that excludes the context for the open domain setting\n(`\"unfiltered.nocontext\"`), and the version number (`\"1.1.0\"`). 
We also override\nthe default splits to match what is commonly used for the open domain setting.\nSpecifically, we set our \"test\" split to be the TFDS \"validation\" split, and\ncreate a small pseudo-\"validation\" set by taking examples out of the TFDS\n\"train\" split.\n\nThe preprocessor `tqa_open_preprocessor` is defined as follows.\n\n```py\ndef tqa_open_preprocessor(\n    dataset: tf.data.Dataset,\n    prefix: str = \"trivia_qa question: \"\n) -> tf.data.Dataset:\n  \"\"\"Convert TriviaQA dataset to open domain qa examples.\n\n  The function takes the trivia_qa TFDS dataset and emits examples of the\n  form:\n  {\n    \"inputs\": \"trivia_qa question: What are the names of the Olsen Twins?\",\n    \"targets\": \"Mary-Kate and Ashley\",\n    \"answers\": [\"Mary-Kate and Ashley\", \"Ashley and Mary-Kate\"]\n  }\n\n  Args:\n    dataset: a tf.data.Dataset to process.\n    prefix: str, prefix to prepend to the inputs.\n\n  Returns:\n    a tf.data.Dataset\n  \"\"\"\n  def tqa_map(ex):\n    \"\"\"Map TriviaQA example to text-to-text example.\"\"\"\n    return {\n        \"inputs\": prefix + ex[\"question\"],\n        \"targets\": ex[\"answer\"][\"value\"],\n        \"answers\": ex[\"answer\"][\"aliases\"],\n    }\n\n  return dataset.map(tqa_map, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n```\n\nOr with the `seqio.map_over_dataset` decorator, we have\n\n```py\ndef tqa_open_preprocessor(\n    dataset: tf.data.Dataset,\n    prefix: str = \"trivia_qa question: \"\n) -> tf.data.Dataset:\n\n  @seqio.map_over_dataset\n  def tqa_map(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:\n    \"\"\"Map TriviaQA example to text-to-text example.\"\"\"\n    return {\n        \"inputs\": prefix + ex[\"question\"],\n        \"targets\": ex[\"answer\"][\"value\"],\n        \"answers\": ex[\"answer\"][\"aliases\"],\n    }\n\n  return tqa_map(dataset)\n```\n\nHere we made a thin wrapper to emphasize that the function decorated by\n`seqio.map_over_dataset` takes in an instance of `tf.data.Dataset`. 
In practice,\nthis wrapper is not necessary.\n\n\nThe postprocessor for this example is `tqa_open_postprocessor`, which is defined\nas follows:\n\n```py\ndef tqa_open_postprocessor(output_or_target, example=None, is_target=False):\n  \"\"\"Returns output as answer, or all answers if the full example is provided.\"\"\"\n  if is_target:\n    return [a.decode(\"utf-8\") for a in example[\"answers\"]]\n  else:\n    return output_or_target.decode(\"utf-8\")\n```\n\nWhen processing the target, we ignore `output_or_target` (equivalent to\n`example[\"targets\"]`) since it is just selecting a single answer in\n`trivia_qa_open`. Instead, we extract the full list of answers from the example\nand convert them from bytes to text. When handling the model output, we simply\nconvert it to text from detokenized bytes.\n\nThe metric function `tqa_metric` is defined as:\n\n```py\ndef tqa_metric(\n    targets: Sequence[Sequence[str]],\n    predictions: Sequence[str]\n) -> Mapping[str, seqio.metrics.MetricValue]:\n  \"\"\"Computes official TriviaQA metrics.\n\n  Args:\n    targets: list of lists of strings\n    predictions: list of strings\n\n  Returns:\n    dict with score_key: squad score across all targets and predictions\n  \"\"\"\n\n  if len(targets) != len(predictions):\n    raise ValueError(\"Number of targets and predictions must match.\")\n\n  def _normalize_answer(text):\n    \"\"\"Lower text and remove punctuation, articles and extra whitespace.\"\"\"\n    # Lowercase, as the docstring promises.\n    text = text.lower()\n    # Remove articles.\n    text = re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\n    # Remove punctuation.\n    for punc in string.punctuation:\n      text = text.replace(punc, '')\n    # Normalize white space\n    text = \" \".join(text.split())\n    return text\n\n  # Normalize answers before comparing.\n  targets = [[_normalize_answer(t) for t in u] for u in targets]\n  predictions = [_normalize_answer(p) for p in predictions]\n\n  em = np.mean([\n      max(pred == gt for gt in ground_truths)\n      for pred, ground_truths in zip(predictions, targets)\n  ])\n  return {\n      \"exact_match\": 
seqio.metrics.Scalar(em),\n }\n```\n\n## Citing SeqIO\nPlease use the following bibtex entry to cite SeqIO.\n\n```\n@article{roberts2022t5x,\n url = {https://arxiv.org/abs/2203.17189},\n author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra,\n Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester,\n Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and\n Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and\n Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and\n Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and\n Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy,\n Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and\n Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and\n Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel,\n Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov,\n Alexander and Newlan, Joshua and Gesmundo, Andrea},\n title = {Scaling Up Models and Data with $\\texttt{t5x}$ and $\\texttt{seqio}$},\n journal={arXiv preprint arXiv:2203.17189},\n year = {2022},\n}\n```\n\n",
"bugtrack_url": null,
"license": "Apache 2.0",
"summary": "SeqIO: Task-based datasets, preprocessing, and evaluation for sequence models.",
"version": "0.0.18.dev20241219",
"project_urls": {
"Homepage": "https://github.com/google/seqio/tree/nightly"
},
"split_keywords": [
"sequence",
"preprocessing",
"nlp",
"machinelearning"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "1f9a849b192bf3feaa062764160a231ab20ac8717f5abd53b9db0f2b32eb6fdc",
"md5": "fb098553665747e4e64383e8c470d773",
"sha256": "8c87a9d312008da077d0243c0ef2d31b67e1e6fe233fcc0494d6963e56e600f4"
},
"downloads": -1,
"filename": "seqio_nightly-0.0.18.dev20241219-py3-none-any.whl",
"has_sig": false,
"md5_digest": "fb098553665747e4e64383e8c470d773",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": null,
"size": 357447,
"upload_time": "2024-12-19T07:01:56",
"upload_time_iso_8601": "2024-12-19T07:01:56.736810Z",
"url": "https://files.pythonhosted.org/packages/1f/9a/849b192bf3feaa062764160a231ab20ac8717f5abd53b9db0f2b32eb6fdc/seqio_nightly-0.0.18.dev20241219-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "6d1f967e04296da645633e33b0098d454d8b026745850438624e0623a5ecf529",
"md5": "b7eb1ac8f1ffe23a91598ee519989921",
"sha256": "202950bc5b1112156923586b978f3f29bf26b2cc0ce1ccc39926033081da2ee5"
},
"downloads": -1,
"filename": "seqio_nightly-0.0.18.dev20241219.tar.gz",
"has_sig": false,
"md5_digest": "b7eb1ac8f1ffe23a91598ee519989921",
"packagetype": "sdist",
"python_version": "source",
"requires_python": null,
"size": 363904,
"upload_time": "2024-12-19T07:01:59",
"upload_time_iso_8601": "2024-12-19T07:01:59.548208Z",
"url": "https://files.pythonhosted.org/packages/6d/1f/967e04296da645633e33b0098d454d8b026745850438624e0623a5ecf529/seqio_nightly-0.0.18.dev20241219.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-12-19 07:01:59",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "google",
"github_project": "seqio",
"travis_ci": false,
"coveralls": false,
"github_actions": true,
"lcname": "seqio-nightly"
}