mighty.utils.data.loader.DataLoader

class mighty.utils.data.loader.DataLoader(dataset_cls, transform=ToTensor(), loader_cls=<class 'torch.utils.data.dataloader.DataLoader'>, batch_size=256, eval_size=None, num_workers=0)

Data loader with simple API.

Parameters:
dataset_cls : type

Dataset class.

transform : object, optional

Torchvision transform object that implements the __call__ method. Default: ToTensor()

loader_cls : type, optional

Batch loader class. Default: torch.utils.data.DataLoader

batch_size : int, optional

Batch size. Default: 256

eval_size : int or None, optional

Minimum number of samples to evaluate on. If None, the length of the dataset is used. Default: None

num_workers : int, optional

The number of workers passed to loader_cls. Default: 0

Methods

__init__(dataset_cls[, transform, ...])

eval([description])

Returns a generator over train samples with no shuffling.

get([train])

Returns a train or test loader.

sample()

Returns the first batch from DataLoader.eval().

state_dict()

eval(description=None)

Returns a generator over train samples with no shuffling.

The generator exits after producing at least eval_size samples.

Parameters:
description : str or None, optional

Message description. Default: None

Yields:
batch : torch.Tensor

Evaluation batch, in the same format as a training batch.
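The stopping rule described above ("at least eval_size samples") can be illustrated with a minimal, library-free sketch. The helpers `eval_batches` and `make_batches` are hypothetical names introduced only for this illustration and are not part of mighty. Plain Python lists stand in for tensors, and treating `eval_size=None` as "yield every batch" is an assumption that mirrors the documented fallback to the dataset length.

```python
from typing import Iterable, Iterator, List, Optional, Sequence


def eval_batches(batches: Iterable[Sequence],
                 eval_size: Optional[int] = None) -> Iterator[Sequence]:
    """Yield batches in order until at least `eval_size` samples were produced.

    Hypothetical helper, not mighty's implementation. With `eval_size=None`
    every batch is yielded (assumed equivalent to using the dataset length).
    """
    n_seen = 0
    for batch in batches:
        yield batch
        n_seen += len(batch)
        # Stop only after a whole batch, so the total may exceed eval_size.
        if eval_size is not None and n_seen >= eval_size:
            break


def make_batches(n_samples: int, batch_size: int) -> List[List[int]]:
    """Split range(n_samples) into consecutive, unshuffled batches."""
    samples = list(range(n_samples))
    return [samples[i:i + batch_size] for i in range(0, n_samples, batch_size)]


batches = make_batches(n_samples=10, batch_size=4)  # batch sizes: 4, 4, 2
produced = list(eval_batches(batches, eval_size=5))
n_produced = sum(len(batch) for batch in produced)
```

Because whole batches are yielded, `eval_size=5` with a batch size of 4 produces two batches (8 samples) in this sketch: the requested size is a lower bound, not an exact count.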

get(train=True)

Returns a train or test loader.

Parameters:
train : bool, optional

Train (True) or test (False) fold. Default: True

Returns:
loader : torch.utils.data.DataLoader

A data loader that yields batches.

sample()

Returns the first batch from DataLoader.eval().

No shuffling/sampling is performed.

Returns:
torch.Tensor or tuple of torch.Tensor

A tensor or a tuple of tensors.
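Since no shuffling or sampling is performed, taking the first evaluation batch should be deterministic. The following library-free sketch shows the idea under that assumption. `ordered_batches` and `first_batch` are hypothetical helpers written for this illustration, not mighty functions, and lists stand in for tensors.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def ordered_batches(samples: Iterable[int], batch_size: int) -> Iterator[List[int]]:
    """Lazily yield consecutive batches in the original sample order."""
    it = iter(samples)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def first_batch(samples: Iterable[int], batch_size: int) -> List[int]:
    """Return only the first batch; later batches are never materialized."""
    return next(ordered_batches(samples, batch_size))


data = list(range(10))
batch_a = first_batch(data, batch_size=4)
batch_b = first_batch(data, batch_size=4)  # same result on every call
```

In this sketch an empty input makes `next` raise `StopIteration`. How the real `sample()` behaves on an empty dataset is not stated in this reference.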