"""
Data loader with simple API
---------------------------
.. autosummary::
:toctree: toctree/utils/
DataLoader
"""
import math
import torch
import torch.utils.data
from torchvision.transforms import ToTensor
from tqdm import tqdm
from mighty.utils.constants import DATA_DIR, BATCH_SIZE
from mighty.utils.data.normalize import get_normalize_inverse, get_normalize
class DataLoader:
    """
    Data loader with simple API.

    Parameters
    ----------
    dataset_cls : type
        Dataset class.
    transform : object, optional
        Torchvision transform object that implements ``__call__`` method.
        Default: ToTensor()
    loader_cls : type, optional
        A batches loader class.
        Default: torch.utils.data.DataLoader
    batch_size : int, optional
        Batch size. Default: 256
    eval_size : int or None, optional
        Evaluation size in the minimum number of samples. If None, the length
        of the dataset is used.
        Default: None
    num_workers : int, optional
        The number of workers passed to `loader_cls`.
        Default: 0
    """

    def __init__(self, dataset_cls, transform=ToTensor(),
                 loader_cls=torch.utils.data.DataLoader,
                 batch_size=BATCH_SIZE, eval_size=None, num_workers=0):
        self.dataset_cls = dataset_cls
        self.loader_cls = loader_cls
        self.transform = transform
        if eval_size is None:
            # No explicit cap: fall back to "whole dataset" (clamped below).
            eval_size = float('inf')
        # Instantiate the train fold once (downloading if necessary) only to
        # learn its length, so eval_size/batch_size can be clamped to it.
        dataset = self.dataset_cls(DATA_DIR, train=True, download=True)
        self.eval_size = min(eval_size, len(dataset))
        self.batch_size = min(batch_size, len(dataset))
        self.num_workers = num_workers
        self.normalize_inverse = get_normalize_inverse(self.transform)
        # Hack to check if the dataset is a set of (signal, label) pairs:
        # peek at one batch and test whether the second element is a tensor
        # of dtype long (the conventional label dtype).
        self.has_labels = False
        sample = self.sample()
        if isinstance(sample, (tuple, list)) and len(sample) > 1:
            labels = sample[1]
            self.has_labels = isinstance(labels, torch.Tensor) \
                              and labels.dtype is torch.long

    def get(self, train=True):
        """
        Returns a train or test loader.

        Parameters
        ----------
        train : bool, optional
            Train (True) or test (False) fold.
            Default: True

        Returns
        -------
        loader : torch.utils.data.DataLoader
            A data loader with batches.
        """
        dataset = self.dataset_cls(DATA_DIR, train=train, download=True,
                                   transform=self.transform)
        # Shuffle only the train fold; iterable-style datasets define their
        # own iteration order and must not be shuffled by the loader.
        shuffle = train
        if isinstance(dataset, torch.utils.data.IterableDataset):
            shuffle = False
        loader = self.loader_cls(dataset,
                                 batch_size=self.batch_size,
                                 shuffle=shuffle,
                                 num_workers=self.num_workers)
        return loader

    def eval(self, description=None):
        """
        Returns a generator over train samples with no shuffling.

        The generator exits after producing at least :attr:`eval_size`
        samples.

        Parameters
        ----------
        description : str or None, optional
            Message description. If None, the progress bar is hidden.
            Default: None

        Yields
        ------
        batch : torch.Tensor
            Eval batch, same as in train.
        """
        dataset = self.dataset_cls(DATA_DIR, train=True, download=True,
                                   transform=self.transform)
        eval_loader = self.loader_cls(dataset,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=self.num_workers)
        # Enough full batches to cover at least `eval_size` samples.
        n_batches = math.ceil(self.eval_size / self.batch_size)
        for batch_id, batch in tqdm(
                enumerate(eval_loader),
                desc=description,
                total=n_batches,
                disable=not description,
                leave=False):
            if batch_id >= n_batches:
                break
            yield batch

    def sample(self):
        """
        Returns the first batch from :meth:`DataLoader.eval`.

        No shuffling/sampling is performed.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            A tensor or a batch of tensors.
        """
        # `eval()` is a generator, hence already an iterator.
        return next(self.eval())

    @staticmethod
    def _shorten(str_repr):
        # Truncate long object reprs so __repr__ stays readable.
        str_repr = str(str_repr)
        return str_repr if len(str_repr) < 50 else f"{str_repr[:50]}..."

    def __repr__(self):
        # Fixed: the original had a stray ')' after num_workers, producing a
        # malformed repr string.
        return f"{self.__class__.__name__}({self.dataset_cls.__name__}, " \
               f"has_labels={self.has_labels}, " \
               f"transform={self._shorten(self.transform)}, " \
               f"batch_size={self.batch_size}, " \
               f"eval_size={self.eval_size}, " \
               f"num_workers={self.num_workers}, normalize_inverse=" \
               f"{self._shorten(self.normalize_inverse)})"

    def state_dict(self):
        """
        Returns the normalization statistics of :attr:`transform`, if any.

        Returns
        -------
        dict or None
            A dict with ``"mean"`` and ``"std"`` entries taken from the
            Normalize transform, or None if the transform pipeline has no
            normalization step.
        """
        normalize = get_normalize(self.transform)
        if normalize is None:
            return None
        return {
            "mean": normalize.mean,
            "std": normalize.std
        }