LogHistory

class deeplay.callbacks.history.LogHistory

Bases: Callback

A Keras-like history callback for Lightning. It keeps track of metrics and losses during training and validation.

Example:

    >>> history = LogHistory()
    >>> trainer = dl.Trainer(callbacks=[history])
    >>> trainer.fit(model, train_dataloader, val_dataloader)
    >>> history.history
    {"train_loss_epoch": {"value": [0.1, 0.2, 0.3], "epoch": [0, 1, 2], "step": [0, 100, 200]}}
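The layout of `history` in the example can be illustrated with a small, self-contained recorder. This is a hypothetical sketch of how a Keras-like history might accumulate epoch-level values. The class `MiniHistory` and its `record` method are illustrative names, not part of deeplay's API, and deeplay's actual implementation may differ.

```python
from collections import defaultdict


class MiniHistory:
    """Hypothetical Keras-like recorder; not deeplay's implementation."""

    def __init__(self):
        # metric name -> {"value": [...], "epoch": [...], "step": [...]}
        self.history = defaultdict(lambda: {"value": [], "epoch": [], "step": []})

    def record(self, metrics, epoch, step):
        # Append one entry per metric, tagged with its epoch and global step
        for name, value in metrics.items():
            entry = self.history[name]
            entry["value"].append(float(value))
            entry["epoch"].append(epoch)
            entry["step"].append(step)


recorder = MiniHistory()
for epoch, (loss, step) in enumerate([(0.1, 0), (0.2, 100), (0.3, 200)]):
    recorder.record({"train_loss_epoch": loss}, epoch=epoch, step=step)

print(dict(recorder.history))
```

Feeding it the values from the example reproduces the dictionary shown there. Each metric keeps parallel lists, so `value[i]`, `epoch[i]` and `step[i]` always describe the same entry.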

Attributes Summary

history

step_history

Methods Summary

on_train_batch_end(trainer, *args, **kwargs)

Called when the train batch ends.

on_train_epoch_end(trainer, *args, **kwargs)

Called when the train epoch ends.

on_validation_epoch_end(trainer, *args, **kwargs)

Called when the validation epoch ends.

plot(*args[, yscale])

Plot the history of the metrics and losses.

Attributes Documentation

history

step_history
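A common use of `history` after training is finding the epoch with the best value of a metric. The helper below is a hypothetical sketch, not a deeplay function. It assumes each metric entry holds parallel `"value"` and `"epoch"` lists, as in the class example.

```python
def best_epoch(history, metric, mode="min"):
    """Return (epoch, value) of the best recorded value for `metric`.

    Assumes history[metric] holds parallel "value" and "epoch" lists,
    as in the LogHistory example.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    entry = history[metric]
    pairs = list(zip(entry["epoch"], entry["value"]))
    if not pairs:
        raise ValueError(f"no values recorded for {metric!r}")
    pick = min if mode == "min" else max
    # Compare by value; the epoch travels along with it
    return pick(pairs, key=lambda pair: pair[1])


history = {"train_loss_epoch": {"value": [0.1, 0.2, 0.3], "epoch": [0, 1, 2], "step": [0, 100, 200]}}
print(best_epoch(history, "train_loss_epoch"))  # (0, 0.1)
```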

Methods Documentation

on_train_batch_end(trainer, *args, **kwargs) → None

Called when the train batch ends.

Note:

The value of outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
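The following is a worked illustration of this note. With gradient accumulation, the loss returned from `training_step` is scaled by `1 / accumulate_grad_batches` before it appears in `outputs["loss"]`. The pure-Python sketch below mimics only that arithmetic. The helper names are hypothetical. Inside a real callback, the factor is assumed to be available as `trainer.accumulate_grad_batches`.

```python
def normalized_loss(raw_loss, accumulate_grad_batches):
    # Loss as it would appear in outputs["loss"] (scaled per accumulation step)
    return raw_loss / accumulate_grad_batches


def recover_raw_loss(logged_loss, accumulate_grad_batches):
    # Undo the scaling to get back the value returned from training_step
    return logged_loss * accumulate_grad_batches


raw = 2.0
logged = normalized_loss(raw, accumulate_grad_batches=4)
print(logged)                       # 0.5
print(recover_raw_loss(logged, 4))  # 2.0
```

With `accumulate_grad_batches=1`, the default, the two values coincide. Per-batch losses therefore only need rescaling when accumulation is enabled.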

on_train_epoch_end(trainer, *args, **kwargs) → None

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # cache of per-step losses, consumed at the end of each epoch
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_validation_epoch_end(trainer, *args, **kwargs) → None

Called when the validation epoch ends.
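Train and validation metrics are not necessarily recorded for the same epochs. For example, validation may run only every few epochs when the Trainer's `check_val_every_n_epoch` is set. The minimal sketch below assumes the history layout from the class example and pairs the two metrics by their `"epoch"` lists rather than by list position. The helper `align_by_epoch`, the metric name `val_loss_epoch`, and the values are hypothetical, chosen for illustration.

```python
def align_by_epoch(history, train_key, val_key):
    """Pair train and validation values that share the same epoch.

    train_key / val_key are metric names in a LogHistory-style dict.
    """
    train = dict(zip(history[train_key]["epoch"], history[train_key]["value"]))
    val = dict(zip(history[val_key]["epoch"], history[val_key]["value"]))
    # Keep only epochs where both metrics were recorded, in epoch order
    return [(epoch, train[epoch], val[epoch]) for epoch in sorted(train.keys() & val.keys())]


history = {
    "train_loss_epoch": {"value": [0.9, 0.6, 0.4, 0.3], "epoch": [0, 1, 2, 3], "step": [0, 100, 200, 300]},
    # illustrative: validation only recorded every second epoch
    "val_loss_epoch": {"value": [0.7, 0.45], "epoch": [1, 3], "step": [100, 300]},
}
for epoch, train_loss, val_loss in align_by_epoch(history, "train_loss_epoch", "val_loss_epoch"):
    print(epoch, train_loss, val_loss)
```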

plot(*args, yscale='log', **kwargs)

Plot the history of the metrics and losses.
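`plot` defaults to `yscale='log'`. A logarithmic axis can only display positive values, so a plotting helper would typically drop non-positive entries first. The sketch below prepares per-metric `(epochs, values)` series from a LogHistory-style dict under that assumption. `plot_series` is a hypothetical helper, and this page does not state whether deeplay's `plot` filters values this way.

```python
def plot_series(history, yscale="log"):
    """Return {metric: (epochs, values)} ready for a line plot.

    With yscale="log", non-positive values are dropped because they
    cannot be shown on a logarithmic axis.
    """
    series = {}
    for name, entry in history.items():
        # Materialize the pairs so they can be iterated more than once
        points = list(zip(entry["epoch"], entry["value"]))
        if yscale == "log":
            points = [(e, v) for e, v in points if v > 0]
        epochs = [e for e, _ in points]
        values = [v for _, v in points]
        series[name] = (epochs, values)
    return series


# illustrative values; the 0.0 entry is dropped on a log scale
history = {"train_loss_epoch": {"value": [0.1, 0.0, 0.3], "epoch": [0, 1, 2], "step": [0, 100, 200]}}
print(plot_series(history))
print(plot_series(history, yscale="linear"))
```

Each series could then be drawn with matplotlib's `Axes.plot`, followed by `Axes.set_yscale("log")`.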