o
    hJ                     @   sH  d Z ddlZddlZddlmZ ddlmZmZm	Z	m
Z
 ddlZddlZddlmZ ddlmZmZmZ ddlmZ dd	lmZmZmZmZ dd
lmZ ddlmZmZ ddlmZm Z m!Z!m"Z"m#Z#m$Z$ ddl%m&Z& e#'e(Z)dZ*dZ+g dZ,dZ-dZ.dgZ/G dd dej0Z1G dd dej0Z2G dd dej0Z3G dd dej0Z4G dd dej0Z5G dd dej0Z6G dd  d ej0Z7G d!d" d"ej0Z8G d#d$ d$ej0Z9G d%d& d&eZ:d'Z;d(Z<e!d)e;G d*d+ d+e:Z=G d,d- d-ej0Z>e!d.e;G d/d0 d0e:Z?e!d1e;G d2d3 d3e:Z@eG d4d5 d5eZAe!d6e;G d7d8 d8e:ZBdS )9z PyTorch DeiT model.    N)	dataclass)OptionalSetTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputBaseModelOutputWithPoolingImageClassifierOutputMaskedImageModelingOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings   )
DeiTConfigr   z(facebook/deit-base-distilled-patch16-224)r      i   ztabby, tabby catc                       sR   e Zd ZdZddededdf fddZdd	ejd
e	ej
 dejfddZ  ZS )DeiTEmbeddingszv
    Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
    Fconfiguse_mask_tokenreturnNc                    s   t    ttdd|j| _ttdd|j| _|r*ttdd|jnd | _	t
|| _| jj}ttd|d |j| _t|j| _d S )Nr      )super__init__r   	Parametertorchzeroshidden_size	cls_tokendistillation_token
mask_tokenDeiTPatchEmbeddingspatch_embeddingsnum_patchesposition_embeddingsDropouthidden_dropout_probdropout)selfr   r   r-   	__class__ \/var/www/html/ai/venv/lib/python3.10/site-packages/transformers/models/deit/modeling_deit.pyr#   I   s   
 
zDeiTEmbeddings.__init__pixel_valuesbool_masked_posc                 C   s   |  |}| \}}}|d ur*| j||d}|d|}|d|  ||  }| j|dd}	| j|dd}
tj	|	|
|fdd}|| j
 }| |}|S )N      ?r   dim)r,   sizer*   expand	unsqueezetype_asr(   r)   r%   catr.   r1   )r2   r7   r8   
embeddings
batch_size
seq_length_mask_tokensmask
cls_tokensdistillation_tokensr5   r5   r6   forwardT   s   


zDeiTEmbeddings.forward)FN)__name__
__module____qualname____doc__r   boolr#   r%   Tensorr   
BoolTensorrJ   __classcell__r5   r5   r3   r6   r   D   s    *r   c                       s6   e Zd ZdZ fddZdejdejfddZ  ZS )r+   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j}}|j|j}}t|tjj	r|n||f}t|tjj	r)|n||f}|d |d  |d |d   }|| _|| _|| _|| _
tj||||d| _d S )Nr   r   )kernel_sizestride)r"   r#   
image_size
patch_sizenum_channelsr'   
isinstancecollectionsabcIterabler-   r   Conv2d
projection)r2   r   rV   rW   rX   r'   r-   r3   r5   r6   r#   m   s   
 zDeiTPatchEmbeddings.__init__r7   r    c              
   C   s   |j \}}}}|| jkrtd|| jd ks|| jd kr5td| d| d| jd  d| jd  d	| |ddd}|S )	NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   zInput image size (*z) doesn't match model (z).r!   )shaperX   
ValueErrorrV   r^   flatten	transpose)r2   r7   rC   rX   heightwidthxr5   r5   r6   rJ   |   s   
(zDeiTPatchEmbeddings.forward)	rL   rM   rN   rO   r#   r%   rQ   rJ   rS   r5   r5   r3   r6   r+   f   s    r+   c                
       sv   e Zd Zdeddf fddZdejdejfddZ		dd
eej de	de
eejejf eej f fddZ  ZS )DeiTSelfAttentionr   r    Nc                    s   t    |j|j dkr t|ds td|jf d|j d|j| _t|j|j | _| j| j | _t	j
|j| j|jd| _t	j
|j| j|jd| _t	j
|j| j|jd| _t	|j| _d S )Nr   embedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)bias)r"   r#   r'   num_attention_headshasattrra   intattention_head_sizeall_head_sizer   Linearqkv_biasquerykeyvaluer/   attention_probs_dropout_probr1   r2   r   r3   r5   r6   r#      s   
zDeiTSelfAttention.__init__rf   c                 C   s6   |  d d | j| jf }||}|ddddS )Nr9   r   r!   r   r   )r=   rk   rn   viewpermute)r2   rf   new_x_shaper5   r5   r6   transpose_for_scores   s   
z&DeiTSelfAttention.transpose_for_scoresF	head_maskoutput_attentionsc                 C   s   |  |}| | |}| | |}| |}t||dd}|t| j	 }t
jj|dd}	| |	}	|d urA|	| }	t|	|}
|
dddd }
|
 d d | jf }|
|}
|rj|
|	f}|S |
f}|S )Nr9   r;   r   r!   r   r   )rr   rz   rs   rt   r%   matmulrc   mathsqrtrn   r   
functionalsoftmaxr1   rx   
contiguousr=   ro   rw   )r2   hidden_statesr{   r|   mixed_query_layer	key_layervalue_layerquery_layerattention_scoresattention_probscontext_layernew_context_layer_shapeoutputsr5   r5   r6   rJ      s$   



zDeiTSelfAttention.forwardNF)rL   rM   rN   r   r#   r%   rQ   rz   r   rP   r   r   rJ   rS   r5   r5   r3   r6   rg      s    rg   c                       sF   e Zd ZdZdeddf fddZdejdejdejfd	d
Z  Z	S )DeiTSelfOutputz
    The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    r   r    Nc                    s.   t    t|j|j| _t|j| _d S rK   )	r"   r#   r   rp   r'   denser/   r0   r1   rv   r3   r5   r6   r#         
zDeiTSelfOutput.__init__r   input_tensorc                 C      |  |}| |}|S rK   r   r1   r2   r   r   r5   r5   r6   rJ         

zDeiTSelfOutput.forward)
rL   rM   rN   rO   r   r#   r%   rQ   rJ   rS   r5   r5   r3   r6   r      s    $r   c                       s~   e Zd Zdeddf fddZdee ddfddZ			dd
ej	de
ej	 dedeeej	ej	f eej	 f fddZ  ZS )DeiTAttentionr   r    Nc                    s*   t    t|| _t|| _t | _d S rK   )r"   r#   rg   	attentionr   outputsetpruned_headsrv   r3   r5   r6   r#      s   


zDeiTAttention.__init__headsc                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r;   )lenr   r   rk   rn   r   r   rr   rs   rt   r   r   ro   union)r2   r   indexr5   r5   r6   prune_heads   s   zDeiTAttention.prune_headsFr   r{   r|   c                 C   s4   |  |||}| |d |}|f|dd   }|S )Nr   r   )r   r   )r2   r   r{   r|   self_outputsattention_outputr   r5   r5   r6   rJ      s   zDeiTAttention.forwardr   )rL   rM   rN   r   r#   r   rm   r   r%   rQ   r   rP   r   r   rJ   rS   r5   r5   r3   r6   r      s    r   c                       s<   e Zd Zdeddf fddZdejdejfddZ  ZS )	DeiTIntermediater   r    Nc                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S rK   )r"   r#   r   rp   r'   intermediate_sizer   rY   
hidden_actstrr   intermediate_act_fnrv   r3   r5   r6   r#     s
   
zDeiTIntermediate.__init__r   c                 C   r   rK   )r   r   )r2   r   r5   r5   r6   rJ     r   zDeiTIntermediate.forward	rL   rM   rN   r   r#   r%   rQ   rJ   rS   r5   r5   r3   r6   r     s    r   c                       sB   e Zd Zdeddf fddZdejdejdejfdd	Z  ZS )

DeiTOutputr   r    Nc                    s.   t    t|j|j| _t|j| _	d S rK   )
r"   r#   r   rp   r   r'   r   r/   r0   r1   rv   r3   r5   r6   r#     r   zDeiTOutput.__init__r   r   c                 C   s    |  |}| |}|| }|S rK   r   r   r5   r5   r6   rJ     s   

zDeiTOutput.forwardr   r5   r5   r3   r6   r     s    $r   c                       sl   e Zd ZdZdeddf fddZ		ddejd	eej d
e	de
eejejf eej f fddZ  ZS )	DeiTLayerz?This corresponds to the Block class in the timm implementation.r   r    Nc                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   eps)r"   r#   chunk_size_feed_forwardseq_len_dimr   r   r   intermediater   r   r   	LayerNormr'   layer_norm_epslayernorm_beforelayernorm_afterrv   r3   r5   r6   r#   '  s   



zDeiTLayer.__init__Fr   r{   r|   c                 C   s`   | j | |||d}|d }|dd  }|| }| |}| |}| ||}|f| }|S )N)r|   r   r   )r   r   r   r   r   )r2   r   r{   r|   self_attention_outputsr   r   layer_outputr5   r5   r6   rJ   1  s   


zDeiTLayer.forwardr   )rL   rM   rN   rO   r   r#   r%   rQ   r   rP   r   r   rJ   rS   r5   r5   r3   r6   r   $  s    r   c                       sb   e Zd Zdeddf fddZ				ddejd	eej d
ededede	e
ef fddZ  ZS )DeiTEncoderr   r    Nc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r5   )r   ).0rE   r   r5   r6   
<listcomp>S  s    z(DeiTEncoder.__init__.<locals>.<listcomp>F)	r"   r#   r   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingrv   r3   r   r6   r#   P  s   
 
zDeiTEncoder.__init__FTr   r{   r|   output_hidden_statesreturn_dictc                 C   s   |rdnd }|r
dnd }t | jD ]8\}}	|r||f }|d ur$|| nd }
| jr6| jr6| |	j||
|}n|	||
|}|d }|rI||d f }q|rQ||f }|s_tdd |||fD S t|||dS )Nr5   r   r   c                 s   s    | ]	}|d ur|V  qd S rK   r5   )r   vr5   r5   r6   	<genexpr>z  s    z&DeiTEncoder.forward.<locals>.<genexpr>)last_hidden_stater   
attentions)	enumerater   r   training_gradient_checkpointing_func__call__tupler   )r2   r   r{   r|   r   r   all_hidden_statesall_self_attentionsilayer_modulelayer_head_masklayer_outputsr5   r5   r6   rJ   V  s6   

zDeiTEncoder.forward)NFFT)rL   rM   rN   r   r#   r%   rQ   r   rP   r   r   r   rJ   rS   r5   r5   r3   r6   r   O  s&    	
r   c                   @   sH   e Zd ZdZeZdZdZdZdgZ	de
ejejejf ddfd	d
ZdS )DeiTPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    deitr7   Tr   moduler    Nc                 C   s   t |tjtjfr0tjj|jjt	j
d| jjd|jj|j_|jdur.|jj  dS dS t |tjrE|jj  |jjd dS dS )zInitialize the weightsg        )meanstdNr:   )rY   r   rp   r]   inittrunc_normal_weightdatator%   float32r   initializer_rangedtyperj   zero_r   fill_)r2   r   r5   r5   r6   _init_weights  s   

z!DeiTPreTrainedModel._init_weights)rL   rM   rN   rO   r   config_classbase_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modulesr   r   rp   r]   r   r   r5   r5   r5   r6   r     s    &r   aF  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DeiTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aL  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`DeiTImageProcessor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z^The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Zddedededdf fdd	Zdefd
dZdd Ze	e
eeeeded						ddeej deej deej dee dee dee deeef fddZ  ZS )	DeiTModelTFr   add_pooling_layerr   r    Nc                    s\   t  | || _t||d| _t|| _tj|j	|j
d| _|r%t|nd | _|   d S )N)r   r   )r"   r#   r   r   rB   r   encoderr   r   r'   r   	layernorm
DeiTPoolerpooler	post_init)r2   r   r   r   r3   r5   r6   r#     s   
zDeiTModel.__init__c                 C   s   | j jS rK   )rB   r,   )r2   r5   r5   r6   get_input_embeddings  s   zDeiTModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r2   heads_to_pruner   r   r5   r5   r6   _prune_heads  s   zDeiTModel._prune_headsvision)
checkpointoutput_typer   modalityexpected_outputr7   r8   r{   r|   r   r   c                 C   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}|du r&td| || j j}| jjj	j
j}|j|kr?||}| j||d}| j|||||d}	|	d }
| |
}
| jdurc| |
nd}|sz|duro|
|fn|
f}||	dd  S t|
||	j|	jdS )z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)r8   r{   r|   r   r   r   r   )r   pooler_outputr   r   )r   r|   r   use_return_dictra   get_head_maskr   rB   r,   r^   r   r   r   r   r   r   r   r   r   )r2   r7   r8   r{   r|   r   r   expected_dtypeembedding_outputencoder_outputssequence_outputpooled_outputhead_outputsr5   r5   r6   rJ     s<   


zDeiTModel.forward)TFNNNNNN)rL   rM   rN   r   rP   r#   r+   r   r   r   DEIT_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r%   rQ   rR   r   r   rJ   rS   r5   r5   r3   r6   r     sB     	
r   c                       s*   e Zd Zdef fddZdd Z  ZS )r   r   c                    s*   t    t|j|j| _t | _d S rK   )r"   r#   r   rp   r'   r   Tanh
activationrv   r3   r5   r6   r#   !  s   
zDeiTPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r  )r2   r   first_token_tensorr   r5   r5   r6   rJ   &  s   

zDeiTPooler.forward)rL   rM   rN   r   r#   rJ   rS   r5   r5   r3   r6   r      s    r   aW  DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    c                       s   e Zd Zdeddf fddZeeeee	d						dde
ej de
ej d	e
ej d
e
e de
e de
e deeef fddZ  ZS )DeiTForMaskedImageModelingr   r    Nc                    sX   t  | t|ddd| _ttj|j|jd |j	 ddt
|j| _|   d S )NFT)r   r   r!   r   )in_channelsout_channelsrT   )r"   r#   r   r   r   
Sequentialr]   r'   encoder_striderX   PixelShuffledecoderr   rv   r3   r5   r6   r#   <  s   

z#DeiTForMaskedImageModeling.__init__r   r   r7   r8   r{   r|   r   r   c                 C   sH  |dur|n| j j}| j||||||d}|d }|ddddf }|j\}	}
}t|
d  }}|ddd|	|||}| |}d}|dur| j j| j j	 }|d||}|
| j j	d
| j j	dd }tjj||dd	}||  | d
  | j j }|s|f|dd  }|dur|f| S |S t|||j|jdS )aM  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```N)r8   r{   r|   r   r   r   r   r9   g      ?r!   none)	reductiongh㈵>)lossreconstructionr   r   )r   r   r   r`   rm   rx   reshaper  rV   rW   repeat_interleaver?   r   r   r   l1_losssumrX   r   r   r   )r2   r7   r8   r{   r|   r   r   r   r   rC   sequence_lengthrX   rd   re   reconstructed_pixel_valuesmasked_im_lossr=   rG   reconstruction_lossr   r5   r5   r6   rJ   M  sF   (	
 z"DeiTForMaskedImageModeling.forwardr   )rL   rM   rN   r   r#   r   r   r   r   r  r   r%   rQ   rR   rP   r   r   rJ   rS   r5   r5   r3   r6   r  /  s2    

r  z
    DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    c                       s   e Zd Zdeddf fddZeeeee	d						dde
ej de
ej d	e
ej d
e
e de
e de
e deeef fddZ  ZS )DeiTForImageClassificationr   r    Nc                    sR   t  | |j| _t|dd| _|jdkrt|j|jnt | _	| 
  d S NF)r   r   )r"   r#   
num_labelsr   r   r   rp   r'   Identity
classifierr   rv   r3   r5   r6   r#     s
   $z#DeiTForImageClassification.__init__r  r7   r{   labelsr|   r   r   c                 C   s~  |dur|n| j j}| j|||||d}|d }| |dddddf }	d}
|dur||	j}| j jdu r\| jdkrBd| j _n| jdkrX|jt	j
ksS|jt	jkrXd| j _nd| j _| j jdkrzt }| jdkrt||	 | }
n+||	|}
n%| j jdkrt }||	d| j|d}
n| j jdkrt }||	|}
|s|	f|dd  }|
dur|
f| S |S t|
|	|j|jd	S )
aM  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DeiTForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: magpie
        ```Nr   r   r   
regressionsingle_label_classificationmulti_label_classificationr9   )r  logitsr   r   )r   r   r   r  r   deviceproblem_typer  r   r%   longrm   r
   squeezer	   rw   r   r   r   r   )r2   r7   r{   r   r|   r   r   r   r   r$  r  loss_fctr   r5   r5   r6   rJ     sN   ,

"


z"DeiTForImageClassification.forwardr   )rL   rM   rN   r   r#   r   r   r   r   r  r   r%   rQ   rP   r   r   rJ   rS   r5   r5   r3   r6   r    s2    

r  c                   @   sh   e Zd ZU dZdZejed< dZejed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dS )+DeiTForImageClassificationWithTeacherOutputa5  
    Output type of [`DeiTForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the cls_logits and distillation logits.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
            class token).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nr$  
cls_logitsdistillation_logitsr   r   )rL   rM   rN   rO   r$  r%   FloatTensor__annotations__r+  r,  r   r   r   r   r5   r5   r5   r6   r*    s   
 r*  a  
    DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
    the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::

           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    c                       s   e Zd Zdeddf fddZeeeee	e
ed					ddeej deej d	ee d
ee dee deee	f fddZ  ZS )%DeiTForImageClassificationWithTeacherr   r    Nc                    sv   t  | |j| _t|dd| _|jdkrt|j|jnt | _	|jdkr0t|j|jnt | _
|   d S r  )r"   r#   r  r   r   r   rp   r'   r  cls_classifierdistillation_classifierr   rv   r3   r5   r6   r#   B  s     z.DeiTForImageClassificationWithTeacher.__init__)r   r   r   r   r7   r{   r|   r   r   c                 C   s   |d ur|n| j j}| j|||||d}|d }| |d d dd d f }| |d d dd d f }	||	 d }
|sI|
||	f|dd   }|S t|
||	|j|jdS )Nr   r   r   r!   )r$  r+  r,  r   r   )r   r   r   r0  r1  r*  r   r   )r2   r7   r{   r|   r   r   r   r   r+  r,  r$  r   r5   r5   r6   rJ   S  s,   z-DeiTForImageClassificationWithTeacher.forward)NNNNN)rL   rM   rN   r   r#   r   r   r   _IMAGE_CLASS_CHECKPOINTr*  r  _IMAGE_CLASS_EXPECTED_OUTPUTr   r%   rQ   rP   r   r   rJ   rS   r5   r5   r3   r6   r/  5  s6    
r/  )CrO   collections.abcrZ   r   dataclassesr   typingr   r   r   r   r%   torch.utils.checkpointr   torch.nnr   r	   r
   activationsr   modeling_outputsr   r   r   r   modeling_utilsr   pytorch_utilsr   r   utilsr   r   r   r   r   r   configuration_deitr   
get_loggerrL   loggerr  r  r  r2  r3  "DEIT_PRETRAINED_MODEL_ARCHIVE_LISTModuler   r+   rg   r   r   r   r   r   r   r   DEIT_START_DOCSTRINGr   r   r   r  r  r*  r/  r5   r5   r5   r6   <module>   st    
"%=(+3]	ik	