DeiT

class cv.backbones.DeiT.model.DeiT(**kwargs)

Bases: ViT

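Data-efficient Image Transformer (DeiT) model class from the paper “Training data-efficient image transformers & distillation through attention” (Touvron et al., 2021).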

This class implements the DeiT architecture, which makes Vision Transformers more data-efficient to train by adding knowledge distillation. DeiT uses a teacher-student framework in which the student model (DeiT) is trained to mimic the output of a pre-trained teacher model.
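The token layout that supports this teacher-student setup can be sketched as follows. This is a minimal sketch, not this library's code: the class name DeiTTokens and its attributes are hypothetical. Following the DeiT paper, a learnable distillation token sits next to the usual class token, and the positional embeddings cover both extra tokens as well as the image patches.

```python
import torch
import torch.nn as nn


class DeiTTokens(nn.Module):
    """Hypothetical helper: prepends class and distillation tokens to patch embeddings."""

    def __init__(self, num_patches: int, d_model: int) -> None:
        super().__init__()
        # One learnable class token and one learnable distillation token
        # (initialized to zeros here for simplicity).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
        # Positional embeddings cover the patches plus the two extra tokens.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, d_model))

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (batch_size, num_patches, d_model)
        b = patches.shape[0]
        cls = self.cls_token.expand(b, -1, -1)
        dist = self.dist_token.expand(b, -1, -1)
        tokens = torch.cat([cls, dist, patches], dim=1)
        return tokens + self.pos_embed


image_size, patch_size, d_model = 224, 16, 768
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196
tokens = DeiTTokens(num_patches, d_model)(torch.randn(2, num_patches, d_model))
print(tokens.shape)  # torch.Size([2, 198, 768])
```

After the encoder, position 0 (the class token) feeds the classification head and position 1 (the distillation token) feeds the distillation head, as in the paper.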

Parameters:
  • num_classes (int) – Number of output classes for classification.

  • d_model (int) – Dimensionality of the model’s hidden representations.

  • image_size (int) – Size of the input image in pixels; must be divisible by patch_size.

  • patch_size (int) – Size of each image patch in pixels.

  • classifier_mlp_d (int) – Dimensionality of the intermediate MLP in the classification head.

  • encoder_mlp_d (int) – Dimensionality of the feed-forward layers in the Transformer encoder.

  • encoder_num_heads (int) – Number of attention heads in the Transformer encoder.

  • num_encoder_blocks (int) – Number of blocks in the Transformer encoder.

  • dropout (float) – Dropout probability applied after the linear projection and within the MLP layers.

  • encoder_dropout (float) – Dropout probability applied within the Transformer encoder blocks.

  • encoder_attention_dropout (float) – Dropout probability applied to the attention layers.

  • patchify_technique (str) – Method used to divide the input image into patches: “linear” (unfolding) or “convolutional” (convolution).

  • stochastic_depth (bool) – Whether to use stochastic depth regularization.

  • stochastic_depth_mp (float) – Maximum drop probability for stochastic depth, i.e. the upper bound on the probability of dropping a layer during training.

  • layer_scale (float or None) – Scaling factor for LayerScale initialization. If None, LayerScale is disabled.

  • return_logits_type (str) – Type of logits to return: “classification”, “distillation”, or “fusion”.

  • teacher_model_name (str) – Name of the pre-trained teacher model used for distillation.

  • in_channels (int) – Number of input channels, typically 3 for RGB images.
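The two patchify_technique options can be illustrated with plain PyTorch. This is a minimal sketch, assuming that “linear” means unfolding non-overlapping patches and applying a shared nn.Linear, and that “convolutional” means an nn.Conv2d whose kernel size and stride equal patch_size; the sizes are hypothetical and the library's own patch-embedding code may differ. It also shows why image_size must be divisible by patch_size: the patches have to tile the image exactly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes for illustration.
batch, in_channels, image_size, patch_size, d_model = 2, 3, 32, 8, 64
x = torch.randn(batch, in_channels, image_size, image_size)

# "convolutional": a Conv2d whose kernel size and stride both equal patch_size.
conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
conv_tokens = conv(x).flatten(2).transpose(1, 2)  # (batch, num_patches, d_model)

# "linear": unfold non-overlapping patches, then apply a shared linear projection.
linear = nn.Linear(in_channels * patch_size * patch_size, d_model)
with torch.no_grad():
    # Reuse the conv weights so that both paths compute the same function.
    linear.weight.copy_(conv.weight.reshape(d_model, -1))
    linear.bias.copy_(conv.bias)
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (batch, C*p*p, num_patches)
linear_tokens = linear(patches.transpose(1, 2))  # (batch, num_patches, d_model)

print(conv_tokens.shape)  # torch.Size([2, 16, 64])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))  # True
```

With the convolution weights copied into the linear layer, both paths produce the same tokens in this sketch, so the two options differ in implementation rather than in what they can express.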

Note

Make sure the pre-trained weights for the teacher model exist; loadTeacherModel() loads them from a checkpoint.

Example

>>> from cv.backbones.DeiT.model import DeiT
>>> model = DeiT(num_classes=1000, d_model=768, image_size=224, patch_size=16, classifier_mlp_d=2048, encoder_mlp_d=3072, encoder_num_heads=12, num_encoder_blocks=12, dropout=0.1, encoder_dropout=0.0, encoder_attention_dropout=0.0, patchify_technique="linear", stochastic_depth=False, stochastic_depth_mp=0.0, layer_scale=None, return_logits_type="fusion", teacher_model_name="ConvNeXt", in_channels=3)

loadTeacherModel()

Loads and prepares the pre-trained teacher model for distillation.

This method loads the teacher model from checkpoints, sets it to evaluation mode, and disables gradient computation.
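In PyTorch, the evaluation-mode and no-gradient steps typically look like the sketch below. The helper freeze_teacher and the small stand-in network are hypothetical; checkpoint loading is omitted because the checkpoint path and format are not documented here.

```python
import torch
import torch.nn as nn


def freeze_teacher(teacher: nn.Module) -> nn.Module:
    """Hypothetical helper: prepare a teacher network for distillation."""
    teacher.eval()  # evaluation mode, e.g. dropout is disabled
    for param in teacher.parameters():
        param.requires_grad_(False)  # no gradients flow into the teacher
    return teacher


# Stand-in teacher; the real one would be loaded from a checkpoint.
teacher = freeze_teacher(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10)))

with torch.no_grad():
    teacher_logits = teacher(torch.randn(4, 8))

print(teacher.training)  # False
print(any(p.requires_grad for p in teacher.parameters()))  # False
```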

student_model(x)

Performs a forward pass of the student model, including the handling of the distillation token.

This method passes the input through the student network and computes logits from the class token and, depending on the return_logits_type setting, from the distillation token.

Parameters:

x (Tensor) – Input tensor of shape (batch_size, channels, height, width).

Returns:

The output logits from the student model, selected according to return_logits_type:
  • If return_logits_type is “classification”: Returns logits from the classification head.

  • If return_logits_type is “distillation”: Returns logits from the distillation token classifier.

  • If return_logits_type is “fusion”: Returns the average of logits from both classifiers.

Return type:

Tensor
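The selection logic described above can be summarized in a short sketch. The function select_logits is hypothetical, and it assumes the class-token and distillation-token logits have already been computed by their respective heads; raising an error on an unknown option is a choice made for this sketch, not documented behavior.

```python
import torch


def select_logits(cls_logits: torch.Tensor, dist_logits: torch.Tensor,
                  return_logits_type: str) -> torch.Tensor:
    """Hypothetical helper mirroring the three return_logits_type options."""
    if return_logits_type == "classification":
        return cls_logits  # class-token head only
    if return_logits_type == "distillation":
        return dist_logits  # distillation-token head only
    if return_logits_type == "fusion":
        return (cls_logits + dist_logits) / 2  # average of both heads
    raise ValueError(f"unknown return_logits_type: {return_logits_type!r}")


cls_logits = torch.tensor([[2.0, 0.0]])
dist_logits = torch.tensor([[0.0, 4.0]])
print(select_logits(cls_logits, dist_logits, "fusion"))  # tensor([[1., 2.]])
```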

forward(x)

Performs a forward pass of the DeiT model.

This method computes the outputs of both the student and teacher models and returns them as a tuple (student output, teacher output).

Parameters:

x (Tensor) – Input tensor of shape (batch_size, channels, height, width).

Returns:

A tuple with two elements:
  • Tensor: Output from the student model (student_model method).

  • Tensor: Output from the teacher model (computed without gradients).

Return type:

tuple

Example

>>> import torch
>>> student_out, teacher_out = model(torch.randn(1, 3, 224, 224))  # input shape: (batch_size, channels, height, width)
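The (student, teacher) tuple returned by forward can feed a distillation loss. The sketch below is a simplified version of the hard-label distillation objective from the DeiT paper: it averages the cross-entropy against the ground-truth labels and the cross-entropy against the teacher's argmax predictions. Because forward returns a single student tensor, both terms use that tensor here, whereas the paper applies them to the class-token and distillation-token heads separately. The function name and stand-in tensors are hypothetical, and this library's own training loss is not documented in this section.

```python
import torch
import torch.nn.functional as F


def hard_distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                           labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical loss: half ground-truth CE, half CE against the teacher's hard labels."""
    teacher_labels = teacher_logits.argmax(dim=-1)  # teacher's predicted classes
    loss_true = F.cross_entropy(student_logits, labels)
    loss_teacher = F.cross_entropy(student_logits, teacher_labels)
    return 0.5 * loss_true + 0.5 * loss_teacher


# Stand-in outputs with the shapes of (student_out, teacher_out) for 4 images, 10 classes.
student_out = torch.randn(4, 10, requires_grad=True)
teacher_out = torch.randn(4, 10)  # the teacher output carries no gradient
labels = torch.randint(0, 10, (4,))

loss = hard_distillation_loss(student_out, teacher_out, labels)
loss.backward()  # gradients reach only the student logits
```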