Source code for fkat.data.datasets.dict
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from typing_extensions import override
from fkat.data.datasets import SizedDataset
from fkat.utils.config import to_primitive_container
[docs]class DictDataset(SizedDataset[tuple[str, Any], dict[str, Any]]):
""":class:`Dataset` that can get samples from one of the :class:`Dataset` using a mapping."""
def __init__(
self,
datasets: dict[str, SizedDataset[Any, dict[str, Any]]],
key: str = "dataset",
) -> None:
"""Create a :class:`Dataset` that can get samples from one of the :class:`Dataset` using a mapping.
Args:
datasets (Dict[str, SizedDataset[Any, Dict[str, Any]]]): A mapping from labels to :class:`Dataset`\\s.
key (str): The name of the field to reflect the :class:`Dataset` the samples were provided from.
Defaults to "dataset".
Returns:
None
"""
self.datasets = to_primitive_container(datasets)
self.len = sum(len(dataset) for dataset in datasets.values())
self.key = key
@override
def __len__(self) -> int:
"""Get :class:`Dataset` size.
Returns:
int: :class:`Dataset` size.
"""
return self.len
def _wrap(self, name: str, item: dict[str, Any]) -> dict[str, Any]:
if not isinstance(item, dict) or self.key in item:
raise RuntimeError(f"Datasets must return a dict without {self.key} key")
item[self.key] = name
return item
def __getitems__(self, name_and_idxs: tuple[str, list[Any]]) -> list[dict[str, Any]]:
"""Get a batch of samples from the target :class:`Dataset` at the specified indices.
Args:
name_and_idxs (Tuple[str, List[Any]]): Samples' :class:`Dataset` and indices.
Returns:
List[Dict[str, Any]]: A batch of samples.
"""
name, idxs = name_and_idxs
if getitems := getattr(self.datasets[name], "__getitems__", None):
batch = getitems(idxs)
else:
batch = [self.datasets[name][idx] for idx in idxs]
for b in batch:
self._wrap(name, b)
return batch
@override
def __getitem__(self, idx: tuple[str, Any]) -> dict[str, Any]:
"""Get a sample from the target :class:`Dataset` at the specified index.
Args:
idx (Tuple[str, Any]): Sample :class:`Dataset` and index.
Returns:
Dict[str, Any]: A sample.
"""
name, idx_ = idx
sample = self.datasets[name][idx_]
return self._wrap(name, sample)