# Source code for mighty.monitor.batch_timer

"""
Timer and schedulers
--------------------

.. autosummary::
    :toctree: toctree/monitor

    BatchTimer
    ScheduleStep
    ScheduleExp
"""


import math
from abc import ABC, abstractmethod
from functools import wraps
from typing import Callable


# Public names exported by ``from ... import *``: the timer class, the shared
# ``timer`` instance and the update schedules.
__all__ = [
    "BatchTimer",
    "timer",
    "Schedule",
    "ScheduleStep",
    "ScheduleExp",
]


class BatchTimer:
    """
    A global batch timer.

    Counts the elapsed batches and converts that count into epochs, given
    the number of batches in one epoch.
    """

    def __init__(self):
        # user-defined no. of epochs to run
        self.n_epochs = None
        # placeholder epoch length; the real value is provided via init()
        self.batches_in_epoch = 1
        # number of batches elapsed so far
        self.batch_id = 0

    def init(self, batches_in_epoch):
        """
        Initialize the timer by providing the epoch length.

        Parameters
        ----------
        batches_in_epoch : int
            The number of batches in an epoch.
        """
        self.batches_in_epoch = batches_in_epoch

    @property
    def epoch(self):
        """
        Returns
        -------
        int
            Epoch id: the epoch progress truncated to an integer.
        """
        progress = self.epoch_progress()
        return int(progress)

    def epoch_progress(self):
        """
        Returns
        -------
        float
            Epoch progress: elapsed batches divided by the epoch length.
        """
        elapsed = self.batch_id
        epoch_length = self.batches_in_epoch
        return elapsed / epoch_length

    def is_epoch_finished(self):
        """
        Returns
        -------
        bool
            Whether it's the end of an epoch (True) or in the middle of
            training (False). Always False before the first batch.
        """
        if not self.batch_id > 0:
            # nothing has been processed yet
            return False
        return self.batch_id % self.batches_in_epoch == 0

    def tick(self):
        """
        Increments the number of elapsed batches by 1.
        """
        self.batch_id = self.batch_id + 1

    def set_epoch(self, epoch):
        """
        Manually set the epoch.

        The batch counter is moved to the first batch of the given epoch.

        Parameters
        ----------
        epoch : int
            A new epoch.
        """
        first_batch_of_epoch = self.batches_in_epoch * epoch
        self.batch_id = first_batch_of_epoch
# Shared module-level timer instance; the schedules below read its state.
timer = BatchTimer()


class Schedule(ABC):
    """
    Schedule the next update program.

    An instance is used as a decorator: the decorated function is executed
    only when the global ``timer`` has reached the batch id returned by
    :meth:`next_batch_update`; otherwise the call is skipped.
    """

    def __init__(self):
        # -1 means "no update recorded yet"; it is replaced on a call of
        # the decorated function (see __call__).
        self.last_batch_update = -1

    @abstractmethod
    def next_batch_update(self):
        """
        Returns
        -------
        int
            The next batch id when an update is needed.
        """
        return 0

    def __call__(self, func: Callable):
        """
        Decorate `func` so that it runs only on scheduled batches.

        Parameters
        ----------
        func : Callable
            The function to throttle.

        Returns
        -------
        Callable
            A wrapper that forwards its arguments to `func` and returns
            its result when an update is due, and returns None otherwise.
        """

        @wraps(func)
        def wrapped(*args, **kwargs):
            if self.last_batch_update == -1:
                # restore the last trained batch
                self.last_batch_update = timer.batch_id - 1
            if timer.batch_id >= self.next_batch_update():
                self.last_batch_update = timer.batch_id
                # Propagate the result instead of silently dropping it.
                return func(*args, **kwargs)
            # The call is skipped: no update is due at this batch.
            return None

        return wrapped
class ScheduleStep(Schedule):
    """
    Performs an update each
    ``epoch_step * timer.batches_in_epoch + batch_step`` batches.

    Parameters
    ----------
    epoch_step : int, optional
        Each epoch step.
        Default: 1
    batch_step : int, optional
        Each batch step.
        Default: 0
    """

    def __init__(self, epoch_step=1, batch_step=0):
        super().__init__()
        self.batch_step = batch_step
        self.epoch_step = epoch_step

    def next_batch_update(self):
        """
        Returns
        -------
        int
            The batch id of the last update plus the update period.
        """
        # The period is recomputed on every call because
        # timer.batches_in_epoch is updated in run-time.
        batches_per_epoch_step = timer.batches_in_epoch * self.epoch_step
        period = batches_per_epoch_step + self.batch_step
        return self.last_batch_update + period
class ScheduleExp(Schedule):
    """
    Schedule updates at batches that are powers of two: 1, 2, 4, 8, 16, ...

    Handy for the first epoch.
    """

    def next_batch_update(self):
        """
        Returns
        -------
        int
            The smallest power of two strictly greater than the batch id
            of the last update, or 1 if the last update id is not
            positive.
        """
        last_update = self.last_batch_update
        if last_update > 0:
            # For a positive integer n, n.bit_length() equals
            # floor(log2(n)) + 1. It is computed exactly, whereas
            # math.log2 goes through floating point and can round up to
            # the next integer just below a large power of two.
            next_power = int(last_update).bit_length()
        else:
            next_power = 0
        return 2 ** next_power