ViT#

class deeplay.models.visiontransformer.ViT(*args, **kwargs)#

Bases: DeeplayModule

Vision Transformer (ViT) model.

Parameters#

image_size : int

Size of the input image. The image is assumed to be square.

patch_size : int

Size of the patch. The image is divided into patches of size patch_size x patch_size pixels.

in_channels : int or None

Number of input channels. If None, the input shape is inferred from the first forward pass.

hidden_features : Sequence[int]

Number of hidden features for each layer of the transformer encoder.

out_features : int

Number of output features.

num_heads : int

Number of attention heads in multihead attention layers of the transformer encoder.

out_activation : template-like or None

Specification for the output activation of the model (Default: nn.Identity).
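
To make the patch arithmetic concrete: the image is split into (image_size / patch_size)^2 patches, each carrying in_channels x patch_size^2 values before embedding. A quick sketch with illustrative values (32, 4, and 3 are not defaults):

>>> image_size, patch_size, in_channels = 32, 4, 3
>>> (image_size // patch_size) ** 2  # number of patches (tokens)
64
>>> in_channels * patch_size ** 2    # values per flattened patch
48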

Configurables#

  • in_channels (int or None): Number of input channels. If None, the input shape is inferred from the first forward pass.

  • hidden_features (list[int]): Number of hidden units in each transformer layer.

  • out_features (int): Number of output features.

  • num_heads (int): Number of attention heads in multihead attention layers.

  • patch_embedder (template-like): Specification for the patch embedder (Default: dl.Patchify).

  • positional_embedder (template-like): Specification for the positional embedder (Default: dl.PositionalEmbedding).

  • transformer_encoder (template-like): Specification for the transformer encoder layer (Default: dl.TransformerEncoderLayer).

  • dense_top (template-like): Specification for the dense top layer (Default: dl.MultiLayerPerceptron).
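
These can be overridden after construction and before creation. A minimal sketch, assuming deeplay's standard .configure() keyword API (not documented on this page):

>>> from deeplay.models.visiontransformer import ViT
>>> vit = ViT(
...     image_size=32,
...     patch_size=4,
...     hidden_features=[384] * 7,
...     out_features=10,
...     num_heads=12,
... )
>>> vit.configure(num_heads=8)  # override a configurable before .create()
>>> vit = vit.create()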

Constraints#

  • input_shape: (batch_size, in_channels, image_size, image_size)

  • output_shape: (batch_size, out_features)

Examples#

>>> import torch
>>> from deeplay.models.visiontransformer import ViT
>>> vit = ViT(
...     image_size=32,
...     patch_size=4,
...     hidden_features=[384] * 7,
...     out_features=10,
...     num_heads=12,
... ).create()
>>> # Testing on a batch of 2
>>> x = torch.randn(2, 3, 32, 32)
>>> vit(x).shape
torch.Size([2, 10])

Return Values#

The forward method returns the output tensor, of shape (batch_size, out_features).

Attributes Summary

hidden

Return the hidden layers of the network.

input

Return the input layer of the network.

output

Return the last layer of the network.

Methods Summary

forward(x)

Define the computation performed at every call.

Attributes Documentation

hidden#

Return the hidden layers of the network. Equivalent to .transformer_encoder.

input#

Return the input layer of the network. Equivalent to .patch_embedder.

output#

Return the last layer of the network. Equivalent to .dense_top.
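
A quick sketch of these accessors, reusing the vit instance from the Examples section:

>>> embedder = vit.input   # same module as vit.patch_embedder
>>> encoder = vit.hidden   # same module as vit.transformer_encoder
>>> head = vit.output      # same module as vit.dense_top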

Methods Documentation

forward(x)#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than forward() directly, since the instance call takes care of running the registered hooks while a direct forward() call silently ignores them.
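
The sketch below illustrates the difference, reusing vit and x from the Examples section and assuming the created model behaves as a standard torch.nn.Module:

>>> activations = {}
>>> def save_output(module, inputs, output):
...     activations["out"] = output
...
>>> handle = vit.register_forward_hook(save_output)
>>> _ = vit(x)           # instance call: the hook fires
>>> _ = vit.forward(x)   # direct call: the hook is silently skipped
>>> handle.remove()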