deeplay.trainer Module

Extension of the Lightning Trainer.

This module extends the PyTorch Lightning Trainer class with additional functionality for managing callbacks and progress bars. It provides a custom _DeeplayCallbackConnector that configures default callbacks, and a Trainer subclass with methods to enable or disable specific callbacks such as progress bars and log history.

Key Features

  • Enhanced Callback Connector

    The _DeeplayCallbackConnector class extends the Lightning _CallbackConnector to add default callbacks, such as TQDMProgressBar and LogHistory, when they are not explicitly provided.

  • Custom Trainer Class

    The Trainer class extends the Lightning Trainer to provide convenience methods for enabling and disabling progress bars (tqdm or rich) and log history callbacks.
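The "add a default only when none was provided" behavior of the connector can be illustrated with a plain-Python sketch. It assumes that a default is skipped whenever a callback of the same base type is already present, so any progress bar suppresses the default TQDMProgressBar. The classes and the `with_default_callbacks` helper are hypothetical stand-ins, not the Lightning or deeplay implementations.

```python
# Hypothetical stand-ins for the Lightning/deeplay callback classes.
class Callback:
    """Stand-in for a Lightning callback base class."""

class ProgressBar(Callback):
    """Stand-in for the progress-bar base class."""

class TQDMProgressBar(ProgressBar):
    pass

class RichProgressBar(ProgressBar):
    pass

class LogHistory(Callback):
    pass

def with_default_callbacks(callbacks):
    """Hypothetical helper: return a new list with missing defaults appended.

    A default is skipped if any supplied callback is an instance of the
    matching base type, so a RichProgressBar suppresses the TQDM default.
    """
    result = list(callbacks)
    # (base type to look for, factory that builds the default instance)
    defaults = [(ProgressBar, TQDMProgressBar), (LogHistory, LogHistory)]
    for base, factory in defaults:
        if not any(isinstance(cb, base) for cb in result):
            result.append(factory())
    return result

print([type(cb).__name__ for cb in with_default_callbacks([])])
# ['TQDMProgressBar', 'LogHistory']
print([type(cb).__name__ for cb in with_default_callbacks([RichProgressBar()])])
# ['RichProgressBar', 'LogHistory']
```

Checking against a base type rather than an exact class lets a user-supplied progress bar of any kind take precedence over the default.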

Module Structure

Classes:

  • _DeeplayCallbackConnector: Connector to configure default callbacks.

  • Trainer: Extended trainer with additional methods for managing callbacks.
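The progress-bar switching offered by the Trainer can be sketched as "remove every existing progress bar, then optionally add a new one". The method names `disable_progress_bar`, `tqdm_progress_bar` and `rich_progress_bar` are taken from the example in this module. The `TrainerSketch` class, its plain `callbacks` list, the `_set_progress_bar` helper and the stand-in classes are hypothetical. The real Trainer works through Lightning's callback machinery and may differ.

```python
# Hypothetical stand-ins for the Lightning/deeplay callback classes.
class Callback: ...
class ProgressBar(Callback): ...
class TQDMProgressBar(ProgressBar): ...
class RichProgressBar(ProgressBar): ...
class LogHistory(Callback): ...

class TrainerSketch:
    """Hypothetical trainer that keeps its callbacks in a plain list."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def _set_progress_bar(self, bar):
        # Drop every existing progress bar, then add the new one (if any).
        self.callbacks = [cb for cb in self.callbacks if not isinstance(cb, ProgressBar)]
        if bar is not None:
            self.callbacks.append(bar)

    def disable_progress_bar(self):
        self._set_progress_bar(None)

    def tqdm_progress_bar(self):
        self._set_progress_bar(TQDMProgressBar())

    def rich_progress_bar(self):
        self._set_progress_bar(RichProgressBar())

trainer = TrainerSketch([TQDMProgressBar(), LogHistory()])
trainer.rich_progress_bar()
print([type(cb).__name__ for cb in trainer.callbacks])  # ['LogHistory', 'RichProgressBar']
trainer.disable_progress_bar()
print([type(cb).__name__ for cb in trainer.callbacks])  # ['LogHistory']
```

Non-progress-bar callbacks such as LogHistory are left untouched, so switching bars does not affect history logging.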

Examples

The following example shows how to use the extended trainer to disable the progress bar, switch between TQDM and Rich progress bars, and access the training history:

```python
import deeplay as dl
import torch

# Create training dataset.
num_samples = 10 ** 2
data = torch.randn(num_samples, 2)
labels = (data.sum(dim=1) > 0).long()

dataset = torch.utils.data.TensorDataset(data, labels)
dataloader = dl.DataLoader(dataset, batch_size=16, shuffle=True)

# Create neural network and classifier application.
mlp = dl.MediumMLP(in_features=2, out_features=2)
classifier = dl.Classifier(mlp, optimizer=dl.Adam(), num_classes=2).build()

# Train neural network with progress bar disabled.
trainer = dl.Trainer(max_epochs=100)
trainer.disable_progress_bar()
trainer.fit(classifier, dataloader)

# Retrieve and plot training history.
history = trainer.history
history.plot()

# Retrain with TQDM progress bar enabled.
trainer.tqdm_progress_bar()
trainer.fit(classifier, dataloader)

# Retrain with Rich progress bar enabled.
trainer.rich_progress_bar()
trainer.fit(classifier, dataloader)
```

Classes

  • LogHistory(): A Keras-like history callback for Lightning.

  • ProgressBar(): The base class for progress bars in Lightning.

  • RichProgressBar([refresh_rate, leave, ...]): A progress bar for displaying training progress with Rich.

  • TQDMProgressBar([refresh_rate]): A progress bar for displaying training progress with TQDM.

  • Trainer(*[, accelerator, strategy, devices, ...]): Extended Lightning trainer with additional methods for managing callbacks.

  • pl_Trainer: Alias of the Lightning Trainer class.
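A "Keras-like history" generally means a mapping from each metric name to the list of its per-epoch values. The plain-Python sketch below illustrates that data structure. `HistorySketch` and its `on_epoch_end(metrics)` method are hypothetical and are not the LogHistory API. A real Lightning callback would more likely read the logged values from `trainer.callback_metrics` inside a hook such as `on_train_epoch_end(trainer, pl_module)`.

```python
from collections import defaultdict

class HistorySketch:
    """Hypothetical Keras-style history: metric name -> list of per-epoch values."""

    def __init__(self):
        self.history = defaultdict(list)

    def on_epoch_end(self, metrics):
        # Append each scalar metric for this epoch to its own series.
        for name, value in metrics.items():
            self.history[name].append(float(value))

hist = HistorySketch()
hist.on_epoch_end({"train_loss": 0.9, "val_loss": 1.0})
hist.on_epoch_end({"train_loss": 0.5, "val_loss": 0.7})
print(dict(hist.history))  # {'train_loss': [0.9, 0.5], 'val_loss': [1.0, 0.7]}
```

Storing one list per metric makes plotting a learning curve a matter of plotting each list against its epoch index.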