Source code for fkat.data.datasets.map
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import TypeVar
from collections.abc import Callable, Iterator
from fkat.data.datasets import SizedDataset
from torch.utils.data import IterableDataset
# Index/key type accepted by the datasets' ``__getitem__`` / ``__getitems__``.
# Declared contravariant because it only appears in argument position.
T_in = TypeVar("T_in", contravariant=True)
# Sample type produced by the wrapped (source) dataset, i.e. the input of the
# mapping function.
T_from = TypeVar("T_from", covariant=True)
# Sample type produced by the mapping function, i.e. what the mapped datasets
# return/yield.
T_to = TypeVar("T_to", covariant=True)
class MapDataset(SizedDataset[T_in, T_to]):
    """A sized :class:`Dataset` view that applies ``fn`` to every sample of a wrapped :class:`Dataset`.

    The wrapped dataset is kept by reference; ``fn`` is invoked each time a
    sample (or batch of samples) is requested.
    """

    def __init__(
        self,
        dataset: SizedDataset[T_in, T_from],
        fn: Callable[[T_from], T_to],
    ) -> None:
        """Wrap ``dataset`` so that its samples are passed through ``fn``.

        Args:
            dataset (SizedDataset): The :class:`Dataset` supplying the raw samples.
            fn (Callable[[T_from], T_to]): Function applied to each raw sample.
        """
        self.dataset = dataset
        self.fn = fn

    def __len__(self) -> int:
        """Report the number of samples.

        Returns:
            int: ``len()`` of the wrapped :class:`Dataset`.
        """
        return len(self.dataset)

    def __getitems__(self, idxs: list[T_in]) -> list[T_to]:
        """Fetch several samples at once and transform each of them.

        Args:
            idxs (List[T_in]): Indices of the requested samples.

        Returns:
            List[T_to]: Transformed samples, in the order the source returned them.
        """
        # Delegate to the source's own batched accessor when it exposes one;
        # otherwise fall back to fetching the samples one index at a time.
        fetch_many = getattr(self.dataset, "__getitems__", None)
        raw_samples = fetch_many(idxs) if fetch_many else [self.dataset[i] for i in idxs]
        return [self.fn(raw) for raw in raw_samples]

    def __getitem__(self, idx: T_in) -> T_to:
        """Fetch a single sample and transform it.

        Args:
            idx (T_in): Index of the requested sample.

        Returns:
            T_to: The result of applying ``fn`` to the source sample at ``idx``.
        """
        return self.fn(self.dataset[idx])
class IterableMapDataset(IterableDataset[T_to]):
    """An :class:`IterableDataset` view that applies ``fn`` to every sample
    yielded by a wrapped :class:`IterableDataset`."""

    def __init__(
        self,
        dataset: IterableDataset[T_from],
        fn: Callable[[T_from], T_to],
    ) -> None:
        """Wrap ``dataset`` so that the samples it yields are passed through ``fn``.

        Args:
            dataset (IterableDataset): The :class:`IterableDataset` supplying the raw samples.
            fn (Callable[[T_from], T_to]): Function applied to each raw sample.
        """
        self.dataset = dataset
        self.fn = fn

    def __iter__(self) -> Iterator[T_to]:
        """Iterate over the wrapped dataset, transforming samples one at a time.

        Yields:
            T_to: The result of applying ``fn`` to the next source sample.
        """
        for raw in self.dataset:
            yield self.fn(raw)