""" PyTorch Swinv2 Transformer model."""

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_swinv2 import Swinv2Config


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Swinv2Config"

# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/swinv2-tiny-patch4-window8-256"
_EXPECTED_OUTPUT_SHAPE = [1, 64, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/swinv2-tiny-patch4-window8-256"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/swinv2-tiny-patch4-window8-256",
]


@dataclass
class Swinv2EncoderOutput(ModelOutput):
    """
    Swinv2 encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlast_hidden_statehidden_states
attentionsreshaped_hidden_states)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   r   r   r   r    r%   r%   `/var/www/html/ai/venv/lib/python3.10/site-packages/transformers/models/swinv2/modeling_swinv2.pyr   B   s   
 r   c                   @   st   e Zd ZU dZdZejed< dZe	ej ed< dZ
e	eej  ed< dZe	eej  ed< dZe	eej  ed< dS )Swinv2ModelOutputaV  
    Swinv2 model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class Swinv2MaskedImageModelingOutput(ModelOutput):
    """
    Swinv2 masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlossreconstructionr   r   r   c                 C   s   t dt | jS )Nzlogits attribute is deprecated and will be removed in version 5 of Transformers. Please use the reconstruction attribute to retrieve the final output instead.)warningswarnFutureWarningr+   selfr%   r%   r&   logits   s
   z&Swinv2MaskedImageModelingOutput.logits)r   r   r    r!   r*   r   r"   r#   r$   r+   r   r   r   r   propertyr1   r%   r%   r%   r&   r)      s   
 r)   c                   @   st   e Zd ZU dZdZeej ed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dS )Swinv2ImageClassifierOutputa  
    Swinv2 outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nr*   r1   r   r   r   )r   r   r    r!   r*   r   r"   r#   r$   r1   r   r   r   r   r%   r%   r%   r&   r3      s   
 r3   c                 C   sR   | j \}}}}| ||| ||| ||} | dddddd d|||}|S )z2
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
    num_channels = windows.shape[-1]
    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
    return windows
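
# Illustrative round trip for the two helpers above (a sketch, not part of the
# modeling code): with a (2, 8, 8, 96) feature map and window_size=4,
# partitioning yields 2 * (8 // 4) * (8 // 4) = 8 windows of shape (4, 4, 96),
# and `window_reverse` restores the original tensor exactly:
#
#     feature = torch.randn(2, 8, 8, 96)
#     windows = window_partition(feature, 4)       # -> (8, 4, 4, 96)
#     restored = window_reverse(windows, 4, 8, 8)  # -> (2, 8, 8, 96)
#     assert torch.equal(feature, restored)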


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class Swinv2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class Swinv2Embeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = Swinv2PatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
    ) -> Tuple[torch.Tensor]:
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions


class Swinv2PatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions


class Swinv2PatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        height, width = input_dimensions
        # `dim` is height * width
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by width and height, if needed
        input_feature = self.maybe_pad(input_feature, height, width)
        # [batch_size, height/2, width/2, num_channels]
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        # [batch_size, height/2 * width/2, 4*num_channels]
        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)

        input_feature = self.reduction(input_feature)
        input_feature = self.norm(input_feature)

        return input_feature
    **r   c                       sj   e Zd Zddgf fdd	Zdd Z			ddejd	eej d
eej dee	 de
ej f
ddZ  ZS )Swinv2SelfAttentionr   c              
      s  t    || dkrtd| d| d|| _t|| | _| j| j | _t|tj	j
r0|n||f| _|| _ttdt|ddf | _ttjddd	d
tjd	dtjd|dd
| _tj| jd d  | jd tjd}tj| jd d  | jd tjd}tt||gddddd d}|d dkr|d d d d d d df  |d d   < |d d d d d d df  |d d   < n.|d d d d d d df  | jd d   < |d d d d d d df  | jd d   < |d9 }t|tt |d  t!d }| j"d|dd t| jd }	t| jd }
tt|	|
gdd}t#|d}|d d d d d f |d d d d d f  }|ddd }|d d d d df  | jd d 7  < |d d d d df  | jd d 7  < |d d d d df  d| jd  d 9  < |$d}| j"d|dd tj| j| j|j%d
| _&tj| j| jdd
| _'tj| j| j|j%d
| _(t)|j*| _+d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()
   r   r4   i   Tr   )inplaceFrK   ij)indexing   rx   relative_coords_table)
persistentr7   relative_position_index),rW   rX   r   num_attention_headsr   attention_head_sizeall_head_sizer   r   r   r   r>   pretrained_window_sizer   ri   r"   logoneslogit_scale
Sequentialr   ReLUcontinuous_position_bias_mlparangefloat32stackr   r;   r<   r{   signlog2absmathregister_bufferr   sumqkv_biasquerykeyvaluerq   attention_probs_dropout_probrs   )r0   rt   r   	num_headsr>   r   relative_coords_hrelative_coords_wr   coords_hcoords_wcoordscoords_flattenrelative_coordsr   rY   r%   r&   rX     s\   
"&$$
,...&,((,
zSwinv2SelfAttention.__init__c                 C   s6   |  d d | j| jf }||}|ddddS )Nr7   r   r4   r   r
   )ry   r   r   r:   r;   )r0   xnew_x_shaper%   r%   r&   transpose_for_scores  s   
z(Swinv2SelfAttention.transpose_for_scoresNFr   attention_mask	head_maskoutput_attentionsrJ   c                 C   s  |j \}}}| |}| | |}	| | |}
| |}tjj|ddtjj|	dddd }t	j
| jtdd }|| }| | jd| j}|| jd | jd | jd  | jd | jd  d}|ddd }d	t	| }||d }|d ur|j d }||| || j|||dd }||dd }|d| j||}tjj|dd}| |}|d ur|| }t	||
}|dddd
 }| d d | jf }||}|r||f}|S |f}|S )Nr7   r   g      Y@)maxr   r   r4      r
   )r9   r   r   r   r   r   r   	normalizer   r"   clampr   r   r   expr   r   r:   r   r   r>   r;   r<   sigmoidr{   softmaxrs   matmulry   r   )r0   r   r   r   r   r?   r   rB   mixed_query_layer	key_layervalue_layerquery_layerattention_scoresr   relative_position_bias_tablerelative_position_bias
mask_shapeattention_probscontext_layernew_context_layer_shapeoutputsr%   r%   r&   r\     sT   

&


zSwinv2SelfAttention.forwardNNF)r   r   r    rX   r   r"   r`   r   r#   boolr   r\   rb   r%   r%   rY   r&   r     s"    ;r   c                       s8   e Zd Z fddZdejdejdejfddZ  ZS )Swinv2SelfOutputc                    s*   t    t||| _t|j| _d S rV   )rW   rX   r   r   denserq   r   rs   r0   rt   r   rY   r%   r&   rX   *  s   
zSwinv2SelfOutput.__init__r   input_tensorrJ   c                 C      |  |}| |}|S rV   r   rs   )r0   r   r   r%   r%   r&   r\   /  s   

zSwinv2SelfOutput.forwardr   r   r    rX   r"   r`   r\   rb   r%   r%   rY   r&   r   )  s    $r   c                       sd   e Zd Zd fdd	Zdd Z			ddejd	eej d
eej dee	 de
ej f
ddZ  ZS )Swinv2Attentionr   c                    sL   t    t||||t|tjjr|n||fd| _t||| _	t
 | _d S )Nrt   r   r   r>   r   )rW   rX   r   r   r   r   r   r0   r   rS   setpruned_heads)r0   rt   r   r   r>   r   rY   r%   r&   rX   7  s   
	zSwinv2Attention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )lenr   r0   r   r   r   r   r   r   r   rS   r   r   union)r0   headsindexr%   r%   r&   prune_headsE  s   zSwinv2Attention.prune_headsNFr   r   r   r   rJ   c                 C   s6   |  ||||}| |d |}|f|dd   }|S )Nr   r   )r0   rS   )r0   r   r   r   r   self_outputsattention_outputr   r%   r%   r&   r\   W  s   zSwinv2Attention.forwardr   r   )r   r   r    rX   r   r"   r`   r   r#   r   r   r\   rb   r%   r%   rY   r&   r   6  s"    r   c                       2   e Zd Z fddZdejdejfddZ  ZS )Swinv2Intermediatec                    sJ   t    t|t|j| | _t|jt	rt
|j | _d S |j| _d S rV   )rW   rX   r   r   r   	mlp_ratior   r   
hidden_actra   r   intermediate_act_fnr   rY   r%   r&   rX   f  s
   
zSwinv2Intermediate.__init__r   rJ   c                 C   r   rV   )r   r  r[   r%   r%   r&   r\   n     

zSwinv2Intermediate.forwardr   r%   r%   rY   r&   r  e  s    r  c                       r  )Swinv2Outputc                    s4   t    tt|j| || _t|j| _	d S rV   )
rW   rX   r   r   r   r  r   rq   rr   rs   r   rY   r%   r&   rX   v  s   
zSwinv2Output.__init__r   rJ   c                 C   r   rV   r   r[   r%   r%   r&   r\   {  r  zSwinv2Output.forwardr   r%   r%   rY   r&   r	  u  s    r	  c                       s   e Zd Zd fdd	Zdd Zdd Zdd	 Z	
		ddejde	e
e
f deej dee dee de	ejejf fddZ  ZS )Swinv2Layerr   c                    s   t    |j| _|| _|j| _|| _| | t|||| jt|t	j
jr'|n||fd| _tj||jd| _|jdkrAt|jnt | _t||| _t||| _tj||jd| _d S )Nr   epsrF   )rW   rX   chunk_size_feed_forward
shift_sizer>   r   set_shift_and_window_sizer   r   r   r   r   	attentionr   ro   layer_norm_epslayernorm_beforedrop_path_raterU   IdentityrT   r  intermediater	  rS   layernorm_after)r0   rt   r   r   r   r  r   rY   r%   r&   rX     s(   

	zSwinv2Layer.__init__c                 C   s   t | jtjjr| jn| j| jf}t | jtjjr| jn| j| jf}t|d r/|d  n|d }||d kr;|n|d | _|t | jtjjrL| jn| j| jfkrXd| _d S |d | _d S Nr   )	r   r>   r   r   r   r  r"   	is_tensoritem)r0   r   target_window_sizetarget_shift_size
window_dimr%   r%   r&   r    s&   

"
z%Swinv2Layer.set_shift_and_window_sizec              	   C   s  | j dkrtjd||df|d}td| j t| j | j  t| j  d f}td| j t| j | j  t| j  d f}d}|D ]}|D ]}	||d d ||	d d f< |d7 }qDq@t|| j}
|
d| j| j }
|
d|
d }||dkt	d|dkt	d}|S d }|S )Nr   r   r   r7   r4   g      YrF   )
r  r"   rj   slicer>   rD   r:   r{   masked_fillr_   )r0   r@   rA   rK   img_maskheight_sliceswidth_slicescountheight_slicewidth_slicemask_windows	attn_maskr%   r%   r&   get_attn_mask  s.   

$zSwinv2Layer.get_attn_maskc                 C   sR   | j || j   | j  }| j || j   | j  }ddd|d|f}tj||}||fS r  )r>   r   r   r   )r0   r   r@   rA   	pad_right
pad_bottomr   r%   r%   r&   r     s
   zSwinv2Layer.maybe_padNFr   r   r   r   always_partitionrJ   c                 C   s  |s|  | n	 |\}}| \}}	}
|}|||||
}| |||\}}|j\}	}}}	| jdkrBtj|| j | j fdd}n|}t|| j	}|d| j	| j	 |
}| j
|||jd}|d urh||j}| j||||d}|d }|d| j	| j	|
}t|| j	||}| jdkrtj|| j| jfdd}n|}|d dkp|d dk}|r|d d d |d |d d f  }|||| |
}| |}|| | }| |}| |}|| | | }|r||d	 f}|S |f}|S )
Nr   )r   r4   )shiftsdimsr7   r   )r   r
   r6   r   )r  ry   r:   r   r9   r  r"   rollrD   r>   r'  rK   torL   r  rE   r<   r  rT   r  rS   r  )r0   r   r   r   r   r*  r@   rA   r?   r   channelsshortcutr   
height_pad	width_padshifted_hidden_stateshidden_states_windowsr&  attention_outputsr  attention_windowsshifted_windows
was_paddedlayer_outputlayer_outputsr%   r%   r&   r\     sN   

$


zSwinv2Layer.forward)r   r   NFF)r   r   r    rX   r  r'  r   r"   r`   r   r   r   r#   r   r\   rb   r%   r%   rY   r&   r
    s*    
r
  c                       sh   e Zd Z	d fdd	Z			ddejdeeef deej	 d	ee
 d
ee
 deej fddZ  ZS )Swinv2Stager   c	           	         s   t     | _| _t jtjjr jn j jft	
 fddt|D | _|d ur>|t	jd| _nd | _d| _d S )Nc              
      sH   g | ] }t  |d  dkrddgnd d  d d  gdqS )r4   r   r   )rt   r   r   r   r  r   )r
  ).0irt   r   r   r   r   r>   r%   r&   
<listcomp>   s    	*z(Swinv2Stage.__init__.<locals>.<listcomp>)r   r   F)rW   rX   rt   r   r   r>   r   r   r   r   
ModuleListrangeblocksro   
downsamplepointing)	r0   rt   r   r   depthr   rT   rD  r   rY   r?  r&   rX     s    

	
zSwinv2Stage.__init__NFr   r   r   r   r*  rJ   c                 C   s   |\}}t | jD ]\}}	|d ur|| nd }
|	|||
||}|d }q	|}| jd urE|d d |d d }}||||f}| ||}n||||f}|||f}|rZ||dd  7 }|S )Nr   r   r4   )	enumeraterC  rD  )r0   r   r   r   r   r*  r@   rA   r>  layer_modulelayer_head_maskr:  !hidden_states_before_downsamplingheight_downsampledwidth_downsampledr~   stage_outputsr%   r%   r&   r\   6  s"   



zSwinv2Stage.forwardr  r;  )r   r   r    rX   r"   r`   r   r   r   r#   r   r\   rb   r%   r%   rY   r&   r<    s&    &
r<  c                       s   e Zd Zd fdd	Z						ddejdeeef d	eej	 d
ee
 dee
 dee
 dee
 dee
 deeef fddZ  ZS )Swinv2Encoderr   r   r   r   c                    s   t    t j_ _jjd ur jdd td j	t
 jD t fddtjD _d_d S )Nc                 S   s   g | ]}|  qS r%   )r  )r=  r   r%   r%   r&   r@  ^  s    z*Swinv2Encoder.__init__.<locals>.<listcomp>r   c                    s   g | ]H}t  t jd |  d d |  d d |  f j|  j| t jd| t jd|d   |jd k rCtnd| dqS )r4   r   r   N)rt   r   r   rF  r   rT   rD  r   )r<  r   rk   depthsr   r   
num_layersr   )r=  i_layerrt   dprrg   pretrained_window_sizesr0   r%   r&   r@  `  s    *F)rW   rX   r   rP  rQ  rt   rU  r"   linspacer  r   r   rA  rB  layersgradient_checkpointing)r0   rt   rg   rU  rY   rS  r&   rX   X  s   
 
zSwinv2Encoder.__init__NFTr   r   r   r   output_hidden_states(output_hidden_states_before_downsamplingr*  return_dictrJ   c	                 C   s  |rdnd }	|r
dnd }
|rdnd }|r7|j \}}}|j|g||R  }|dddd}|	|f7 }	|
|f7 }
t| jD ]\}}|d urH|| nd }| jr[| jr[| |j||||}n||||||}|d }|d }|d }|d |d f}|r|r|j \}}}|j|g|d |d f|R  }|dddd}|	|f7 }	|
|f7 }
n'|r|s|j \}}}|j|g||R  }|dddd}|	|f7 }	|
|f7 }
|r||dd  7 }q<|st	dd	 ||	|fD S t
||	||
d
S )Nr%   r   r
   r   r4   r   r7   c                 s   s    | ]	}|d ur|V  qd S rV   r%   )r=  vr%   r%   r&   	<genexpr>  s    z(Swinv2Encoder.forward.<locals>.<genexpr>)r   r   r   r   )r9   r:   r;   rG  rW  rX  rI   _gradient_checkpointing_func__call__tupler   )r0   r   r   r   r   rY  rZ  r*  r[  all_hidden_statesall_reshaped_hidden_statesall_self_attentionsr?   r   r   reshaped_hidden_stater>  rH  rI  r:  rJ  r~   r%   r%   r&   r\   r  sf   





zSwinv2Encoder.forward)rO  )NFFFFT)r   r   r    rX   r"   r`   r   r   r   r#   r   r   r   r\   rb   r%   r%   rY   r&   rN  W  s6    
	

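
# With the default tiny configuration (embed_dim=96, depths=[2, 2, 6, 2]), the
# four stages above run at widths 96, 192, 384 and 768, so the final hidden
# size is embed_dim * 2 ** (num_layers - 1) = 768. For a 256x256 input with
# patch size 4 and three 2x2 merging steps, (256 / 4 / 8) ** 2 = 64 tokens
# remain -- matching _EXPECTED_OUTPUT_SHAPE = [1, 64, 768] at the top of this file.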


class Swinv2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Swinv2Config
    base_model_prefix = "swinv2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


SWINV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Swinv2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

SWINV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Swinv2 Model transformer outputting raw hidden-states without any specific head on top.",
    SWINV2_START_DOCSTRING,
)
class Swinv2Model(Swinv2PreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = Swinv2Embeddings(config, use_mask_token=use_mask_token)
        self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layers[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Swinv2ModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2ModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed; 1.0 in head_mask indicates we keep the head
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return Swinv2ModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    """Swinv2 Model with a decoder on top for masked image modeling, as proposed in
[SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    SWINV2_START_DOCSTRING,
)
class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True)

        num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Swinv2MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 256, 256]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swinv2(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return Swinv2MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    """
    Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    SWINV2_START_DOCSTRING,
)
class Swinv2ForImageClassification(Swinv2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.swinv2 = Swinv2Model(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=Swinv2ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swinv2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Swinv2ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
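

# Minimal usage sketch (mirrors the documented examples above; downloading the
# checkpoint requires network access):
#
#     from transformers import AutoImageProcessor, Swinv2Model
#     import torch
#     from PIL import Image
#     import requests
#
#     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#     image = Image.open(requests.get(url, stream=True).raw)
#     processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
#     model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
#     inputs = processor(images=image, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     print(outputs.last_hidden_state.shape)  # torch.Size([1, 64, 768])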