deeplay.trainer Module
Extension of the Lightning Trainer.
This module extends the PyTorch Lightning Trainer class to include additional functionality for managing callbacks and progress bars. It provides a custom _DeeplayCallbackConnector to configure default callbacks and a Trainer subclass that offers methods to enable or disable specific callbacks like progress bars and logging.
Key Features
Enhanced Callback Connector
The _DeeplayCallbackConnector class extends the Lightning _CallbackConnector to add default callbacks such as TQDMProgressBar and LogHistory if they are not explicitly provided.
Custom Trainer Class
The Trainer class extends the Lightning Trainer to provide convenience methods for enabling and disabling progress bars (tqdm or rich) and log history callbacks.
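The two features above can be illustrated with a minimal, library-free sketch. It is not the deeplay implementation: the classes `Callback`, `ProgressBar`, `TQDMBar`, `RichBar`, and `History` and the functions `with_default_callbacks` and `replace_progress_bar` are hypothetical stand-ins, and the exact checks performed by _DeeplayCallbackConnector are assumed here, not taken from the source.

```python
from typing import List, Optional


class Callback:
    """Hypothetical stand-in for a Lightning callback."""


class ProgressBar(Callback):
    """Hypothetical stand-in for any progress bar callback."""


class TQDMBar(ProgressBar):
    """Hypothetical stand-in for a tqdm-style progress bar."""


class RichBar(ProgressBar):
    """Hypothetical stand-in for a rich-style progress bar."""


class History(Callback):
    """Hypothetical stand-in for a log history callback."""


def with_default_callbacks(callbacks: List[Callback]) -> List[Callback]:
    """Return a copy of `callbacks` with missing defaults appended.

    Assumption: a default is added only when no callback of that kind
    was passed explicitly.
    """
    result = list(callbacks)
    if not any(isinstance(cb, ProgressBar) for cb in result):
        result.append(TQDMBar())
    if not any(isinstance(cb, History) for cb in result):
        result.append(History())
    return result


def replace_progress_bar(
    callbacks: List[Callback], new_bar: Optional[ProgressBar] = None
) -> List[Callback]:
    """Drop every progress bar, then append `new_bar` unless it is None."""
    result = [cb for cb in callbacks if not isinstance(cb, ProgressBar)]
    if new_bar is not None:
        result.append(new_bar)
    return result


# No callbacks given: both defaults are added.
defaults = with_default_callbacks([])
# An explicit progress bar is kept; only the history default is added.
explicit = with_default_callbacks([RichBar()])
# Switching to another progress bar, or disabling it entirely.
switched = replace_progress_bar(defaults, RichBar())
disabled = replace_progress_bar(defaults)
```

Returning a new list instead of mutating the input keeps the sketch easy to reason about; a real trainer would more likely update its own callback list in place.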
Module Structure
Classes:
_DeeplayCallbackConnector: Connector to configure default callbacks.
Trainer: Extended trainer with additional methods for managing callbacks.
Examples
The following example shows how to use the extended trainer to disable the progress bar, plot the training history, and switch between the TQDM and rich progress bars:
```python
import deeplay as dl
import torch

# Create training dataset.
num_samples = 10 ** 2
data = torch.randn(num_samples, 2)
labels = (data.sum(dim=1) > 0).long()

dataset = torch.utils.data.TensorDataset(data, labels)
dataloader = dl.DataLoader(dataset, batch_size=16, shuffle=True)

# Create neural network and classifier application.
mlp = dl.MediumMLP(in_features=2, out_features=2)
classifier = dl.Classifier(mlp, optimizer=dl.Adam(), num_classes=2).build()

# Train neural network with progress bar disabled.
trainer = dl.Trainer(max_epochs=100)
trainer.disable_progress_bar()
trainer.fit(classifier, dataloader)

# Return and plot training history.
history = trainer.history
history.plot()

# Retrain with TQDM progress bar enabled.
trainer.tqdm_progress_bar()
trainer.fit(classifier, dataloader)

# Retrain with rich progress bar enabled.
trainer.rich_progress_bar()
trainer.fit(classifier, dataloader)
```
Classes
A keras-like history callback for lightning.
The base class for progress bars in Lightning.
A progress bar for displaying training progress with Rich.
A progress bar for displaying training progress with TQDM.