o
    h8                     @   s   d Z ddlZddlZddlmZ ddlmZmZm	Z	m
Z
mZmZ ddlZddlZddlmZmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZmZ ddlmZmZmZmZm Z m!Z!m"Z"m#Z#m$Z$ ddl%m&Z& e  rqddl'm(Z( e! rzddl)m*Z* e"+e,Z-dZ.dZ/g dZ0dgZ1eG dd deZ2G dd dej3Z4G dd dej3Z5G dd dej3Z6G dd dej3Z7G dd dej3Z8G dd  d ej3Z9G d!d" d"ej3Z:G d#d$ d$ej3Z;G d%d& d&ej3Z<G d'd( d(ej3Z=G d)d* d*ej3Z>G d+d, d,eZ?d-Z@d.ZAed/e@G d0d1 d1e?ZBG d2d3 d3ej3ZCed4e@G d5d6 d6e?ZDd7d8 ZEdWd;eFd<eFfd=d>ZGG d?d@ d@ej3ZHG dAdB dBej3ZIG dCdD dDej3ZJdEedFefdGdHZKdIedFefdJdKZLdLdM ZMdNdO ZNdPdQ ZOG dRdS dSePZQdTee fdUdVZRdS )Xz PyTorch YOLOS model.    N)	dataclass)DictListOptionalSetTupleUnion)Tensornn   )ACT2FN)BaseModelOutputBaseModelOutputWithPooling)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)	ModelOutputadd_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardis_scipy_availableis_vision_availableloggingreplace_return_docstringsrequires_backends   )YolosConfiglinear_sum_assignment)center_to_corners_formatr   zhustvl/yolos-small)r   iI  i  c                   @   s   e Zd ZU dZdZeej ed< dZ	ee
 ed< dZejed< dZejed< dZeee
  ed< dZeej ed< dZeeej  ed	< dZeeej  ed
< dS )YolosObjectDetectionOutputaG
  
    Output type of [`YolosForObjectDetection`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
            Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~YolosImageProcessor.post_process`] to retrieve the unnormalized bounding
            boxes.
        auxiliary_outputs (`list[Dict]`, *optional*):
            Optional, only returned when auxilary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
            and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
            `pred_boxes`) for each decoder layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nloss	loss_dictlogits
pred_boxesauxiliary_outputslast_hidden_statehidden_states
attentions)__name__
__module____qualname____doc__r!   r   torchFloatTensor__annotations__r"   r   r#   r$   r%   r   r&   r'   r   r(    r0   r0   ^/var/www/html/ai/venv/lib/python3.10/site-packages/transformers/models/yolos/modeling_yolos.pyr    D   s   
 !r    c                       s@   e Zd ZdZdeddf fddZdejdejfdd	Z  Z	S )
YolosEmbeddingszT
    Construct the CLS token, detection tokens, position and patch embeddings.

    configreturnNc                    s   t    ttdd|j| _ttd|j|j| _	t
|| _| jj}ttd||j d |j| _t|j| _t|| _|| _d S Nr   )super__init__r
   	Parameterr-   zeroshidden_size	cls_tokennum_detection_tokensdetection_tokensYolosPatchEmbeddingspatch_embeddingsnum_patchesposition_embeddingsDropouthidden_dropout_probdropout$InterpolateInitialPositionEmbeddingsinterpolationr3   )selfr3   r@   	__class__r0   r1   r7   w   s   



zYolosEmbeddings.__init__pixel_valuesc                 C   s   |j \}}}}| |}| \}}}| j|dd}	| j|dd}
tj|	||
fdd}| | j	||f}|| }| 
|}|S )Nr   dim)shaper?   sizer;   expandr=   r-   catrF   rA   rD   )rG   rJ   
batch_sizenum_channelsheightwidth
embeddingsseq_len_
cls_tokensr=   rA   r0   r0   r1   forward   s   

zYolosEmbeddings.forward
r)   r*   r+   r,   r   r7   r-   r	   rZ   __classcell__r0   r0   rH   r1   r2   q   s    r2   c                       0   e Zd Zd fddZd	dejfddZ  ZS )
rE   r4   Nc                       t    || _d S Nr6   r7   r3   rG   r3   rH   r0   r1   r7         

z-InterpolateInitialPositionEmbeddings.__init__i   i@  c                 C   s  |d d dd d f }|d d d f }|d d | j j d d d f }|d d d| j j d d f }|dd}|j\}}}| j jd | j j | j jd | j j }	}
||||	|
}|\}}|| j j || j j }}tjj	|||fddd}|
ddd}tj|||fdd}|S )Nr   r      bicubicFrO   modealign_cornersrL   )r3   r<   	transposerN   
image_size
patch_sizeviewr
   
functionalinterpolateflattenr-   rQ   )rG   	pos_embedimg_sizecls_pos_embeddet_pos_embedpatch_pos_embedrR   r:   rW   patch_heightpatch_widthrT   rU   new_patch_heigthnew_patch_widthscale_pos_embedr0   r0   r1   rZ      s$     z,InterpolateInitialPositionEmbeddings.forwardr4   Nrc   r)   r*   r+   r7   r-   r	   rZ   r\   r0   r0   rH   r1   rE          rE   c                       r]   )
 InterpolateMidPositionEmbeddingsr4   Nc                    r^   r_   r`   ra   rH   r0   r1   r7      rb   z)InterpolateMidPositionEmbeddings.__init__rc   c                 C   sH  |d d d d dd d f }|d d d f }|d d d d | j j d d d f }|d d d d d| j j d d f }|dd}|j\}}}}	| j jd | j j | j jd | j j }
}||| ||
|}|\}}|| j j || j j }}tjj	|||fddd}|
ddd |||| |}tj|||fdd}|S )	Nr   r   rd   r   re   Frf   rL   )r3   r<   ri   rN   rj   rk   rl   r
   rm   rn   ro   
contiguousr-   rQ   )rG   rp   rq   rr   rs   rt   depthrR   r:   rW   ru   rv   rT   rU   new_patch_heightrx   ry   r0   r0   r1   rZ      s,   &&z(InterpolateMidPositionEmbeddings.forwardrz   r{   r|   r0   r0   rH   r1   r~      r}   r~   c                       s6   e Zd ZdZ fddZdejdejfddZ  ZS )r>   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j}}|j|j}}t|tjj	r|n||f}t|tjj	r)|n||f}|d |d  |d |d   }|| _|| _|| _|| _
tj||||d| _d S )Nr   r   )kernel_sizestride)r6   r7   rj   rk   rS   r:   
isinstancecollectionsabcIterabler@   r
   Conv2d
projection)rG   r3   rj   rk   rS   r:   r@   rH   r0   r1   r7      s   
 zYolosPatchEmbeddings.__init__rJ   r4   c                 C   s<   |j \}}}}|| jkrtd| |ddd}|S )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.rd   r   )rN   rS   
ValueErrorr   ro   ri   )rG   rJ   rR   rS   rT   rU   rV   r0   r0   r1   rZ      s   
zYolosPatchEmbeddings.forward)	r)   r*   r+   r,   r7   r-   r	   rZ   r\   r0   r0   rH   r1   r>      s    r>   c                
       sv   e Zd Zdeddf fddZdejdejfddZ		dd
eej de	de
eejejf eej f fddZ  ZS )YolosSelfAttentionr3   r4   Nc                    s   t    |j|j dkr t|ds td|jf d|j d|j| _t|j|j | _| j| j | _t	j
|j| j|jd| _t	j
|j| j|jd| _t	j
|j| j|jd| _t	|j| _d S )Nr   embedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)bias)r6   r7   r:   num_attention_headshasattrr   intattention_head_sizeall_head_sizer
   Linearqkv_biasquerykeyvaluerB   attention_probs_dropout_probrD   ra   rH   r0   r1   r7      s   
zYolosSelfAttention.__init__xc                 C   s6   |  d d | j| jf }||}|ddddS )NrK   r   rd   r   r   )rO   r   r   rl   permute)rG   r   new_x_shaper0   r0   r1   transpose_for_scores  s   
z'YolosSelfAttention.transpose_for_scoresF	head_maskoutput_attentionsc                 C   s   |  |}| | |}| | |}| |}t||dd}|t| j	 }t
jj|dd}	| |	}	|d urA|	| }	t|	|}
|
dddd }
|
 d d | jf }|
|}
|rj|
|	f}|S |
f}|S )NrK   rL   r   rd   r   r   )r   r   r   r   r-   matmulri   mathsqrtr   r
   rm   softmaxrD   r   r   rO   r   rl   )rG   r'   r   r   mixed_query_layer	key_layervalue_layerquery_layerattention_scoresattention_probscontext_layernew_context_layer_shapeoutputsr0   r0   r1   rZ     s$   



zYolosSelfAttention.forwardNF)r)   r*   r+   r   r7   r-   r	   r   r   boolr   r   rZ   r\   r0   r0   rH   r1   r      s    r   c                       sF   e Zd ZdZdeddf fddZdejdejdejfd	d
Z  Z	S )YolosSelfOutputz
    The residual connection is defined in YolosLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    r3   r4   Nc                    s.   t    t|j|j| _t|j| _d S r_   )	r6   r7   r
   r   r:   denserB   rC   rD   ra   rH   r0   r1   r7   ?     
zYolosSelfOutput.__init__r'   input_tensorc                 C      |  |}| |}|S r_   r   rD   rG   r'   r   r0   r0   r1   rZ   D     

zYolosSelfOutput.forwardr[   r0   r0   rH   r1   r   9  s    $r   c                       s~   e Zd Zdeddf fddZdee ddfddZ			dd
ej	de
ej	 dedeeej	ej	f eej	 f fddZ  ZS )YolosAttentionr3   r4   Nc                    s*   t    t|| _t|| _t | _d S r_   )r6   r7   r   	attentionr   outputsetpruned_headsra   rH   r0   r1   r7   M  s   


zYolosAttention.__init__headsc                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rL   )lenr   r   r   r   r   r   r   r   r   r   r   r   union)rG   r   indexr0   r0   r1   prune_headsS  s   zYolosAttention.prune_headsFr'   r   r   c                 C   s4   |  |||}| |d |}|f|dd   }|S Nr   r   )r   r   )rG   r'   r   r   self_outputsattention_outputr   r0   r0   r1   rZ   e  s   zYolosAttention.forwardr   )r)   r*   r+   r   r7   r   r   r   r-   r	   r   r   r   r   rZ   r\   r0   r0   rH   r1   r   L  s    r   c                       s<   e Zd Zdeddf fddZdejdejfddZ  ZS )	YolosIntermediater3   r4   Nc                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r_   )r6   r7   r
   r   r:   intermediate_sizer   r   
hidden_actstrr   intermediate_act_fnra   rH   r0   r1   r7   u  s
   
zYolosIntermediate.__init__r'   c                 C   r   r_   )r   r   )rG   r'   r0   r0   r1   rZ   }  r   zYolosIntermediate.forward	r)   r*   r+   r   r7   r-   r	   rZ   r\   r0   r0   rH   r1   r   t  s    r   c                       sB   e Zd Zdeddf fddZdejdejdejfdd	Z  ZS )
YolosOutputr3   r4   Nc                    s.   t    t|j|j| _t|j| _	d S r_   )
r6   r7   r
   r   r   r:   r   rB   rC   rD   ra   rH   r0   r1   r7     r   zYolosOutput.__init__r'   r   c                 C   s    |  |}| |}|| }|S r_   r   r   r0   r0   r1   rZ     s   

zYolosOutput.forwardr   r0   r0   rH   r1   r     s    $r   c                       sl   e Zd ZdZdeddf fddZ		ddejd	eej d
e	de
eejejf eej f fddZ  ZS )
YolosLayerz?This corresponds to the Block class in the timm implementation.r3   r4   Nc                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   eps)r6   r7   chunk_size_feed_forwardseq_len_dimr   r   r   intermediater   r   r
   	LayerNormr:   layer_norm_epslayernorm_beforelayernorm_afterra   rH   r0   r1   r7     s   



zYolosLayer.__init__Fr'   r   r   c                 C   s`   | j | |||d}|d }|dd  }|| }| |}| |}| ||}|f| }|S )N)r   r   r   )r   r   r   r   r   )rG   r'   r   r   self_attention_outputsr   r   layer_outputr0   r0   r1   rZ     s   


zYolosLayer.forwardr   )r)   r*   r+   r,   r   r7   r-   r	   r   r   r   r   rZ   r\   r0   r0   rH   r1   r     s    r   c                       sb   e Zd Zdeddf fddZ				ddejd	eej d
ededede	e
ef fddZ  ZS )YolosEncoderr3   r4   Nc                    s   t     | _t fddt jD | _d| _d j	d  j	d   j
d    j } jrAtt jd d| jnd | _ jrNt | _d S d | _d S )Nc                    s   g | ]}t  qS r0   )r   ).0rX   r3   r0   r1   
<listcomp>      z)YolosEncoder.__init__.<locals>.<listcomp>Fr   r   rd   )r6   r7   r3   r
   
ModuleListrangenum_hidden_layerslayergradient_checkpointingrj   rk   r<   use_mid_position_embeddingsr8   r-   r9   r:   mid_position_embeddingsr~   rF   )rG   r3   
seq_lengthrH   r   r1   r7     s$   
 &	zYolosEncoder.__init__FTr'   r   r   output_hidden_statesreturn_dictc                 C   s
  |rdnd }|r
dnd }	| j jr| | j||f}
t| jD ]J\}}|r)||f }|d ur1|| nd }| jrC| jrC| |j	|||}n||||}|d }| j jr_|| j j
d k r_||
|  }|rh|	|d f }	q|rp||f }|s~tdd |||	fD S t|||	dS )Nr0   r   r   c                 s   s    | ]	}|d ur|V  qd S r_   r0   r   vr0   r0   r1   	<genexpr>      z'YolosEncoder.forward.<locals>.<genexpr>)r&   r'   r(   )r3   r   rF   r   	enumerater   r   training_gradient_checkpointing_func__call__r   tupler   )rG   r'   rT   rU   r   r   r   r   all_hidden_statesall_self_attentions$interpolated_mid_position_embeddingsilayer_modulelayer_head_masklayer_outputsr0   r0   r1   rZ     s@   


zYolosEncoder.forward)NFFT)r)   r*   r+   r   r7   r-   r	   r   r   r   r   r   rZ   r\   r0   r0   rH   r1   r     s&    
	r   c                   @   sB   e Zd ZdZeZdZdZdZde	e
je
je
jf ddfdd	ZdS )
YolosPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    vitrJ   Tmoduler4   Nc                 C   st   t |tjtjfr#|jjjd| jjd |j	dur!|j	j
  dS dS t |tjr8|j	j
  |jjd dS dS )zInitialize the weightsg        )meanstdNg      ?)r   r
   r   r   weightdatanormal_r3   initializer_ranger   zero_r   fill_)rG   r   r0   r0   r1   _init_weights  s   
z"YolosPreTrainedModel._init_weights)r)   r*   r+   r,   r   config_classbase_model_prefixmain_input_namesupports_gradient_checkpointingr   r
   r   r   r   r  r0   r0   r0   r1   r     s    &r   aG  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`YolosConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aM  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`YolosImageProcessor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z_The bare YOLOS Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Zddedef fddZdefddZd	ee	e
e	 f dd
fddZeeeeeeded	
	
	
	
	
ddeej deej dee dee dee deeef fddZ  ZS )
YolosModelTr3   add_pooling_layerc                    sX   t  | || _t|| _t|| _tj|j	|j
d| _|r#t|nd | _|   d S )Nr   )r6   r7   r3   r2   rV   r   encoderr
   r   r:   r   	layernormYolosPoolerpooler	post_init)rG   r3   r  rH   r0   r1   r7   L  s   

zYolosModel.__init__r4   c                 C   s   | j jS r_   )rV   r?   rG   r0   r0   r1   get_input_embeddingsY  s   zYolosModel.get_input_embeddingsheads_to_pruneNc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model.

        Args:
            heads_to_prune (`dict` of {layer_num: list of heads to prune in this layer}):
                See base class `PreTrainedModel`.
        N)itemsr	  r   r   r   )rG   r  r   r   r0   r0   r1   _prune_heads\  s   zYolosModel._prune_headsvision)
checkpointoutput_typer  modalityexpected_outputrJ   r   r   r   r   c              	   C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}|d u r&td| || j j}| |}| j||j	d |j	d ||||d}|d }| 
|}| jd urX| |nd }	|so|	d urd||	fn|f}
|
|dd   S t||	|j|jdS )Nz You have to specify pixel_valuesr   rK   )rT   rU   r   r   r   r   r   r   )r&   pooler_outputr'   r(   )r3   r   r   use_return_dictr   get_head_maskr   rV   r	  rN   r
  r  r   r'   r(   )rG   rJ   r   r   r   r   embedding_outputencoder_outputssequence_outputpooled_outputhead_outputsr0   r0   r1   rZ   g  s:   
	
zYolosModel.forward)T)NNNNN)r)   r*   r+   r   r   r7   r>   r  r   r   r   r  r   YOLOS_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r-   r	   r   r   rZ   r\   r0   r0   rH   r1   r  G  s<    	
r  c                       s*   e Zd Zdef fddZdd Z  ZS )r  r3   c                    s*   t    t|j|j| _t | _d S r_   )r6   r7   r
   r   r:   r   Tanh
activationra   rH   r0   r1   r7     s   
zYolosPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r%  )rG   r'   first_token_tensorr  r0   r0   r1   rZ     s   

zYolosPooler.forward)r)   r*   r+   r   r7   rZ   r\   r0   r0   rH   r1   r    s    r  zy
    YOLOS Model (consisting of a ViT encoder) with object detection heads on top, for tasks such as COCO detection.
    c                       s   e Zd Zdef fddZejjdd Ze	e
eeed				ddejd	eee  d
ee dee dee deeef fddZ  ZS )YolosForObjectDetectionr3   c                    sX   t  | t|dd| _t|j|j|jd dd| _t|j|jddd| _| 	  d S )NF)r  r   r   )	input_dim
hidden_dim
output_dim
num_layers   )
r6   r7   r  r   YolosMLPPredictionHeadr:   
num_labelsclass_labels_classifierbbox_predictorr  ra   rH   r0   r1   r7     s   z YolosForObjectDetection.__init__c                 C   s$   dd t |d d |d d D S )Nc                 S   s   g | ]	\}}||d qS ))r#   r$   r0   )r   abr0   r0   r1   r     s    z9YolosForObjectDetection._set_aux_loss.<locals>.<listcomp>rK   )zip)rG   outputs_classoutputs_coordr0   r0   r1   _set_aux_loss  s   $z%YolosForObjectDetection._set_aux_loss)r  r  NrJ   labelsr   r   r   r4   c              
      s   |dur|n| j j}| j||||d}|d }|dd| j j dddf }| |}| | }	d\}
}|durt| j j| j j	| j j
d}g d}t|| j j| j j|d}|| j i }||d< |	|d	< | j jr|rr|jn|d
 }| |}| | }| ||}||d< |||d| j jd| j jd< | j jri }t| j jd D ] | fdd D  q| tfdd D }
|s|dur||	f| | }n||	f| }|
dur|
f| S |S t|
||	||j|j|jdS )a  
        labels (`List[Dict]` of len `(batch_size,)`, *optional*):
            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
            following 2 keys: `'class_labels'` and `'boxes'` (the class labels and bounding boxes of an image in the
            batch respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding
            boxes in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image,
            4)`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("hustvl/yolos-tiny")
        >>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # convert outputs (bounding boxes and class logits) to COCO API
        >>> target_sizes = torch.tensor([image.size[::-1]])
        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
        ...     0
        ... ]

        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(
        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
        ...         f"{round(score.item(), 3)} at location {box}"
        ...     )
        Detected remote with confidence 0.994 at location [46.96, 72.61, 181.02, 119.73]
        Detected remote with confidence 0.975 at location [340.66, 79.19, 372.59, 192.65]
        Detected cat with confidence 0.984 at location [12.27, 54.25, 319.42, 470.99]
        Detected remote with confidence 0.922 at location [41.66, 71.96, 178.7, 120.33]
        Detected cat with confidence 0.914 at location [342.34, 21.48, 638.64, 372.46]
        ```N)r   r   r   r   )NNN)
class_cost	bbox_cost	giou_cost)r7  boxescardinality)matchernum_classeseos_coeflossesr#   r$   r,  r%   r   )loss_ce	loss_bbox	loss_giouc                        i | ]\}}|d    |qS rX   r0   r   kr   r   r0   r1   
<dictcomp>=       z3YolosForObjectDetection.forward.<locals>.<dictcomp>c                 3   s(    | ]}|v r | |  V  qd S r_   r0   )r   rG  )r"   weight_dictr0   r1   r   ?  s   & z2YolosForObjectDetection.forward.<locals>.<genexpr>)r!   r"   r#   r$   r%   r&   r'   r(   )r3   r  r   r<   r/  r0  sigmoidYolosHungarianMatcherr8  r9  r:  	YolosLossr.  eos_coefficienttodeviceauxiliary_lossintermediate_hidden_statesr6  bbox_loss_coefficientgiou_loss_coefficientr   decoder_layersupdater  sumkeysr    r&   r'   r(   )rG   rJ   r7  r   r   r   r   r  r#   r$   r!   r%   r=  r@  	criterionoutputs_lossr   r4  r5  aux_weight_dictr   r0   )r   r"   rK  r1   rZ     sr   7 




zYolosForObjectDetection.forward)NNNN)r)   r*   r+   r   r7   r-   jitunusedr6  r   r   r   r    r"  r.   r   r   r   r   r   r   rZ   r\   r0   r0   rH   r1   r'    s.    



r'  c                 C   sX   |   } | d} d| | d }| d|d }d|d |d   }| | S )a  
    Compute the DICE loss, similar to generalized IOU for masks

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs (0 for the negative class and 1 for the positive
                 class).
    r   rd   rK   )rL  ro   rX  )inputstargets	num_boxes	numeratordenominatorr!   r0   r0   r1   	dice_lossU  s   
rd        ?rd   alphagammac           
      C   s|   |   }tjj| |dd}|| d| d|   }|d| |  }|dkr5|| d| d|   }	|	| }|d | S )a  
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (`torch.FloatTensor` of arbitrary shape):
            The predictions for each example.
        targets (`torch.FloatTensor` with the same shape as `inputs`)
            A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
            and 1 for the positive class).
        alpha (`float`, *optional*, defaults to `0.25`):
            Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
        gamma (`int`, *optional*, defaults to `2`):
            Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.

    Returns:
        Loss tensor
    none	reductionr   r   )rL  r
   rm    binary_cross_entropy_with_logitsr   rX  )
r_  r`  ra  rf  rg  probce_lossp_tr!   alpha_tr0   r0   r1   sigmoid_focal_lossi  s   rp  c                       sh   e Zd ZdZ fddZdd Ze dd Zdd	 Z	d
d Z
dd Zdd Zdd Zdd Z  ZS )rN  a  
    This class computes the losses for YolosForObjectDetection/YolosForSegmentation. The process happens in two steps: 1)
    we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
    of matched ground-truth / prediction (supervise class and box).

    A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
    parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
    the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
    be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
    (`max_obj_id` + 1). For more details on this, check the following discussion
    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"


    Args:
        matcher (`YolosHungarianMatcher`):
            Module able to compute a matching between targets and proposals.
        num_classes (`int`):
            Number of object categories, omitting the special no-object category.
        eos_coef (`float`):
            Relative classification weight applied to the no-object category.
        losses (`List[str]`):
            List of all the losses to be applied. See `get_loss` for a list of all available losses.
    c                    sL   t    || _|| _|| _|| _t| jd }| j|d< | d| d S )Nr   rK   empty_weight)	r6   r7   r=  r>  r?  r@  r-   onesregister_buffer)rG   r=  r>  r?  r@  rq  rH   r0   r1   r7     s   

zYolosLoss.__init__c                 C   s   d|vrt d|d }| |}tdd t||D }tj|jdd | jtj|j	d}|||< t
j|dd|| j}	d	|	i}
|
S )
z
        Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes]
        r#   z#No logits were found in the outputsc                 S       g | ]\}\}}|d  | qS class_labelsr0   )r   trX   Jr0   r0   r1   r     rJ  z)YolosLoss.loss_labels.<locals>.<listcomp>Nrd   dtyperQ  r   rA  )KeyError_get_source_permutation_idxr-   rQ   r3  fullrN   r>  int64rQ  r
   rm   cross_entropyri   rq  )rG   r   r`  indicesra  source_logitsidxtarget_classes_otarget_classesrA  r@  r0   r0   r1   loss_labels  s   
zYolosLoss.loss_labelsc                 C   sf   |d }|j }tjdd |D |d}|d|jd d kd}tj|	 |	 }	d|	i}
|
S )z
        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
        r#   c                 S      g | ]}t |d  qS ru  r   r   r0   r0   r1   r         z.YolosLoss.loss_cardinality.<locals>.<listcomp>)rQ  rK   r   cardinality_error)
rQ  r-   	as_tensorargmaxrN   rX  r
   rm   l1_lossfloat)rG   r   r`  r  ra  r#   rQ  target_lengths	card_predcard_errr@  r0   r0   r1   loss_cardinality  s   zYolosLoss.loss_cardinalityc                 C   s   d|vrt d| |}|d | }tjdd t||D dd}tjj||dd}i }	| | |	d	< d
t	t
t|t| }
|
 | |	d< |	S )a<  
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
        are expected in format (center_x, center_y, w, h), normalized by the image size.
        r$   z#No predicted boxes found in outputsc                 S   rt  r;  r0   )r   rw  rX   r   r0   r0   r1   r     rJ  z(YolosLoss.loss_boxes.<locals>.<listcomp>r   rL   rh  ri  rB  r   rC  )r{  r|  r-   rQ   r3  r
   rm   r  rX  diaggeneralized_box_iour   )rG   r   r`  r  ra  r  source_boxestarget_boxesrB  r@  rC  r0   r0   r1   
loss_boxes  s   
zYolosLoss.loss_boxesc                 C   s   d|vrt d| |}| |}|d }|| }dd |D }t| \}	}
|	|}	|	| }	tjj|dddf |	j	dd ddd	}|ddd
f 
d}|	
d}	|	|j	}	t||	|t||	|d}|S )z
        Compute the losses related to the masks: the focal loss and the dice loss.

        Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
        
pred_masksz#No predicted masks found in outputsc                 S      g | ]}|d  qS )masksr0   r   rw  r0   r0   r1   r     r   z(YolosLoss.loss_masks.<locals>.<listcomp>Nr   bilinearFrf   r   r   )	loss_mask	loss_dice)r{  r|  _get_target_permutation_idxnested_tensor_from_tensor_list	decomposerP  r
   rm   rn   rN   ro   rl   rp  rd  )rG   r   r`  r  ra  
source_idx
target_idxsource_masksr  target_masksvalidr@  r0   r0   r1   
loss_masks  s(   





zYolosLoss.loss_masksc                 C   4   t dd t|D }t dd |D }||fS )Nc                 S   s    g | ]\}\}}t ||qS r0   r-   	full_like)r   r   sourcerX   r0   r0   r1   r     rJ  z9YolosLoss._get_source_permutation_idx.<locals>.<listcomp>c                 S   s   g | ]\}}|qS r0   r0   )r   r  rX   r0   r0   r1   r     r   r-   rQ   r   )rG   r  	batch_idxr  r0   r0   r1   r|       z%YolosLoss._get_source_permutation_idxc                 C   r  )Nc                 S   s    g | ]\}\}}t ||qS r0   r  )r   r   rX   targetr0   r0   r1   r     rJ  z9YolosLoss._get_target_permutation_idx.<locals>.<listcomp>c                 S   s   g | ]\}}|qS r0   r0   )r   rX   r  r0   r0   r1   r     r   r  )rG   r  r  r  r0   r0   r1   r    r  z%YolosLoss._get_target_permutation_idxc                 C   s@   | j | j| j| jd}||vrtd| d|| ||||S )N)r7  r<  r;  r  zLoss z not supported)r  r  r  r  r   )rG   r!   r   r`  r  ra  loss_mapr0   r0   r1   get_loss  s   zYolosLoss.get_lossc           
   
      s  dd |  D }| ||}tdd |D }tj|gtjtt| j	d}tj
|dd }i }| jD ]}|| ||||| q7d|v rt|d D ].\ }| ||}| jD ] }|d	krdq]| |||||}	 fd
d|	  D }	||	 q]qP|S )a  
        This performs the loss computation.

        Args:
             outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                losses applied, see each loss' doc.
        c                 S   s   i | ]\}}|d kr||qS )r%   r0   rF  r0   r0   r1   rI  -  s    z%YolosLoss.forward.<locals>.<dictcomp>c                 s   s    | ]	}t |d  V  qdS )rv  Nr  r  r0   r0   r1   r   3  r   z$YolosLoss.forward.<locals>.<genexpr>ry  r   minr%   r  c                    rD  rE  r0   rF  rH  r0   r1   rI  I  rJ  )r  r=  rX  r-   r  r  nextitervaluesrQ  clampitemr@  rW  r  r   )
rG   r   r`  outputs_without_auxr  ra  r@  r!   r%   l_dictr0   rH  r1   rZ   "  s&   "

zYolosLoss.forward)r)   r*   r+   r,   r7   r  r-   no_gradr  r  r  r|  r  r  rZ   r\   r0   r0   rH   r1   rN    s    
!rN  c                       s(   e Zd ZdZ fddZdd Z  ZS )r-  a  
    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
    height and width of a bounding box w.r.t. an image.

    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py

    c                    sJ   t    || _|g|d  }tdd t|g| ||g D | _d S )Nr   c                 s   s     | ]\}}t ||V  qd S r_   )r
   r   )r   nrG  r0   r0   r1   r   ]  s    z2YolosMLPPredictionHead.__init__.<locals>.<genexpr>)r6   r7   r+  r
   r   r3  layers)rG   r(  r)  r*  r+  hrH   r0   r1   r7   Y  s   
,zYolosMLPPredictionHead.__init__c                 C   s>   t | jD ]\}}|| jd k rtj||n||}q|S r5   )r   r  r+  r
   rm   relu)rG   r   r   r   r0   r0   r1   rZ   _  s   (zYolosMLPPredictionHead.forward)r)   r*   r+   r,   r7   rZ   r\   r0   r0   rH   r1   r-  P  s    r-  c                       s@   e Zd ZdZd
dededef fddZe dd	 Z  Z	S )rM  a  
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
    r   r8  r9  r:  c                    sT   t    t| dg || _|| _|| _|dkr$|dkr&|dkr(tdd S d S d S )Nscipyr   z#All costs of the Matcher can't be 0)r6   r7   r   r8  r9  r:  r   )rG   r8  r9  r:  rH   r0   r1   r7   w  s   
zYolosHungarianMatcher.__init__c                 C   s   |d j dd \}}|d ddd}|d dd}tdd	 |D }td
d	 |D }|dd|f  }	tj||dd}
tt|t| }| j|
 | j	|	  | j
|  }|||d }dd	 |D }dd	 t||dD }dd	 |D S )a  
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth
                 objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        r#   Nrd   r   r   rK   r$   c                 S   r  ru  r0   r   r0   r0   r1   r     r   z1YolosHungarianMatcher.forward.<locals>.<listcomp>c                 S   r  r  r0   r   r0   r0   r1   r     r   )pc                 S   r  r  r  r   r0   r0   r1   r     r  c                 S   s   g | ]
\}}t || qS r0   r   )r   r   cr0   r0   r1   r     s    c                 S   s0   g | ]\}}t j|t jd t j|t jd fqS ))rz  )r-   r  r~  )r   r   jr0   r0   r1   r     s   0 )rN   ro   r   r-   rQ   cdistr  r   r9  r8  r:  rl   cpur   split)rG   r   r`  rR   num_queriesout_probout_bbox
target_idstarget_bboxr8  r9  r:  cost_matrixsizesr  r0   r0   r1   rZ     s   zYolosHungarianMatcher.forward)r   r   r   )
r)   r*   r+   r,   r  r7   r-   r  rZ   r\   r0   r0   rH   r1   rM  f  s
    
rM  rw  r4   c                 C   sD   |   r| jtjtjfv r| S |  S | jtjtjfv r| S |  S r_   )	is_floating_pointrz  r-   float32float64r  int32r~  r   )rw  r0   r0   r1   _upcast  s   r  r;  c                 C   sH   t | } | dddf | dddf  | dddf | dddf   S )a  
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    Nrd   r   r   r   )r  r  r0   r0   r1   box_area  s   @r  c           
      C   s   t | }t |}t| d d d d df |d d d df }t| d d d dd f |d d dd f }|| jdd}|d d d d df |d d d d df  }|d d d f | | }|| }	|	|fS )Nrd   r   r  r   )r  r-   maxr  r  )
boxes1boxes2area1area2left_topright_bottomwidth_heightinterr   iour0   r0   r1   box_iou  s   ..,r  c                 C   s*  | ddddf | ddddf k  std|  |ddddf |ddddf k  s:td| t| |\}}t| dddddf |ddddf }t| dddddf |ddddf }|| jdd}|dddddf |dddddf  }||| |  S )z
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    Nrd   z<boxes1 must be in [x0, y0, x1, y1] (corner) format, but got z<boxes2 must be in [x0, y0, x1, y1] (corner) format, but got r   r  r   )allr   r  r-   r  r  r  )r  r  r  r   top_leftbottom_rightr  arear0   r0   r1   r    s   ,	,..,r  c                 C   sB   | d }| dd  D ]}t |D ]\}}t|| |||< qq
|S r   )r   r  )the_listmaxessublistr   r  r0   r0   r1   _max_by_axis  s   r  c                   @   s6   e Zd Zdee fddZdd Zdd Zdd	 Zd
S )NestedTensormaskc                 C   s   || _ || _d S r_   tensorsr  )rG   r  r  r0   r0   r1   r7     s   
zNestedTensor.__init__c                 C   s4   | j |}| j}|d ur||}nd }t||S r_   )r  rP  r  r  )rG   rQ  cast_tensorr  	cast_maskr0   r0   r1   rP    s   
zNestedTensor.toc                 C   s   | j | jfS r_   r  r  r0   r0   r1   r    s   zNestedTensor.decomposec                 C   s
   t | jS r_   )r   r  r  r0   r0   r1   __repr__  s   
zNestedTensor.__repr__N)	r)   r*   r+   r   r	   r7   rP  r  r  r0   r0   r0   r1   r    s
    	r  tensor_listc                 C   s   | d j dkrqtdd | D }t| g| }|\}}}}| d j}| d j}tj|||d}	tj|||ftj|d}
t	| |	|
D ].\}}}|d |j
d d |j
d d |j
d f | d|d |j
d d |j
d f< qAntd	t|	|
S )
Nr   r   c                 S   s   g | ]}t |jqS r0   )listrN   )r   imgr0   r0   r1   r     s    z2nested_tensor_from_tensor_list.<locals>.<listcomp>ry  r   rd   Fz(Only 3-dimensional tensors are supported)ndimr  r   rz  rQ  r-   r9   rr  r   r3  rN   copy_r   r  )r  max_sizebatch_shaperR   rS   rT   rU   rz  rQ  tensorr  r  pad_imgmr0   r0   r1   r    s   

2"
r  )re  rd   )Sr,   collections.abcr   r   dataclassesr   typingr   r   r   r   r   r   r-   torch.utils.checkpointr	   r
   activationsr   modeling_outputsr   r   modeling_utilsr   pytorch_utilsr   r   utilsr   r   r   r   r   r   r   r   r   configuration_yolosr   scipy.optimizer   transformers.image_transformsr   
get_loggerr)   loggerr"  r!  r#  #YOLOS_PRETRAINED_MODEL_ARCHIVE_LISTr    Moduler2   rE   r~   r>   r   r   r   r   r   r   r   r   YOLOS_START_DOCSTRINGr   r  r  r'  rd  r  rp  rN  r-  rM  r  r  r  r  r  objectr  r  r0   r0   r0   r1   <module>   s    ,
,+!"=(*NW   HO	
