ViT#
- class deeplay.models.visiontransformer.ViT(*args, **kwargs)#
Bases: DeeplayModule

Vision Transformer (ViT) model.
Parameters#
- image_size (int)
Size of the input image. The image is assumed to be square.
- patch_size (int)
Size of the patch. The image is divided into patches of size patch_size x patch_size pixels.
- in_channels (int or None)
Number of input channels. If None, the input shape is inferred from the first forward pass.
- hidden_features (Sequence[int])
Number of hidden features for each layer of the transformer encoder.
- out_features (int)
Number of output features.
- num_heads (int)
Number of attention heads in the multihead attention layers of the transformer encoder.
- out_activation (template-like or None)
Specification for the output activation of the model (Default: nn.Identity).
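As a rough sketch of how image_size and patch_size interact (pure Python, independent of deeplay; the helper name is hypothetical):

```python
def patch_grid(image_size: int, patch_size: int) -> tuple[int, int]:
    """Hypothetical helper: patches per side and total patch count.

    The image is assumed square and evenly divisible by patch_size,
    matching the constraints described below.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side, per_side * per_side

# A 32x32 image with 4x4 patches yields an 8x8 grid of 64 patches,
# i.e. a 64-token sequence for the transformer encoder.
print(patch_grid(32, 4))  # -> (8, 64)
```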
Configurables#
in_channels (int): Number of input features. If None, the input shape is inferred from the first forward pass.
hidden_features (list[int]): Number of hidden units in each transformer layer.
out_features (int): Number of output features.
num_heads (int): Number of attention heads in multihead attention layers.
patch_embedder (template-like): Specification for the patch embedder (Default: dl.Patchify).
positional_embedder (template-like): Specification for the positional embedder (Default: dl.PositionalEmbedding).
transformer_encoder (template-like): Specification for the transformer encoder layer (Default: dl.TransformerEncoderLayer).
dense_top (template-like): Specification for the dense top layer (Default: dl.MultiLayerPerceptron).
Constraints#
input_shape: (batch_size, in_channels, image_size, image_size)
output_shape: (batch_size, out_features)
Examples#
>>> vit = ViT(
...     image_size=32,
...     patch_size=4,
...     hidden_features=[384,] * 7,
...     out_features=10,
...     num_heads=12,
... ).create()
>>> # Testing on a batch of 2
>>> x = torch.randn(2, 3, 32, 32)
>>> vit(x).shape
torch.Size([2, 10])
Return Values#
The forward method returns the processed tensor.
Attributes Summary
hidden: Return the hidden layers of the network.
input: Return the input layer of the network.
output: Return the last layer of the network.
Methods Summary
forward(x): Define the computation performed at every call.
Attributes Documentation
- hidden#
Return the hidden layers of the network. Equivalent to .transformer_encoder.
- input#
Return the input layer of the network. Equivalent to .patch_embedder.
- output#
Return the last layer of the network. Equivalent to .dense_top.
Methods Documentation
- forward(x)#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.
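To illustrate why calling the instance matters, here is a schematic stand-in for a Module-like class (pure Python; this is not PyTorch's or deeplay's actual implementation, just a sketch of the hook-dispatch pattern the note describes):

```python
class MiniModule:
    """Hypothetical Module-like class: __call__ runs hooks, forward() does not."""

    def __init__(self):
        self._hooks = []

    def register_hook(self, fn):
        # Hooks receive (module, input, output), mimicking forward hooks.
        self._hooks.append(fn)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        out = self.forward(x)
        for hook in self._hooks:
            hook(self, x, out)  # hooks only fire when the instance is called
        return out

m = MiniModule()
seen = []
m.register_hook(lambda mod, inp, out: seen.append(out))
m(3)          # runs forward and the hook
m.forward(3)  # runs forward only; the hook is silently skipped
print(seen)   # -> [6]
```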