o
    h                     @   s  d Z ddlZddlmZ ddlmZmZ ddlZddlZddlm	Z	 ddl
mZ ddlmZmZmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddlmZ eeZ ddgZ!eG dd deZ"G dd de	j#Z$G dd de	j#Z%G dd de	j#Z&G dd de	j#Z'G dd de	j#Z(G dd  d e	j#Z)G d!d" d"e	j#Z*G d#d$ d$e	j#Z+G d%d& d&e	j#Z,G d'd( d(e	j#Z-G d)d* d*eZ.d+Z/d,Z0G d-d. d.e	j#Z1G d/d0 d0e	j#Z2e1e2d1Z3ed2e/G d3d4 d4e.Z4G d5d6 d6e	j#Z5ed7e/G d8d9 d9e.Z6dS ):zPyTorch TVP Model    N)	dataclass)OptionalTuple)nn   )ACT2FN)add_start_docstrings%add_start_docstrings_to_model_forwardreplace_return_docstrings)BaseModelOutputBaseModelOutputWithPoolingModelOutput)PreTrainedModel)prune_linear_layer)logging   )AutoBackbone   )	TvpConfigzIntel/tvp-basezIntel/tvp-base-ANetc                   @   s^   e Zd ZU dZdZeej ed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dS )TvpVideoGroundingOutputa  
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Temporal-Distance IoU loss for video grounding.
        logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the
            input texts.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
    Nlosslogitshidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r    r!   r!   Z/var/www/html/ai/venv/lib/python3.10/site-packages/transformers/models/tvp/modeling_tvp.pyr   ,   s   
 r   c                       s@   e Zd ZdZ fddZdd Zdd Zdd	 Zd
d Z  Z	S )TvpLossa~  
    This class computes the losses for `TvpForVideoGrounding`. The process happens in two steps: 1) we compute
    hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched
    ground-truth / prediction (supervise class and box).

    Args:
        losses (`List[str]`):
            List of all the losses to be applied.
    c                    sL   t    | j| j| jd| _|D ]}|| jvr td| dq|| _d S )NioudistancedurationzLoss z not supported)super__init__loss_iouloss_distanceloss_durationloss_map
ValueErrorlosses)selfr/   r   	__class__r!   r"   r)   O   s   


zTvpLoss.__init__c           	      C   sH   t ||t || }t ||t || }d|jdd|  }|S )z6
        Measure the intersection over union.
        r   r   min)r   r4   maxclamp)	r0   
start_timeend_timecandidates_start_timecandidates_end_timer'   interunionr%   r!   r!   r"   r*   \   s   zTvpLoss.loss_iouc           	      C   sT   t t ||d}t t ||d}t t ||t || |jdd}|S )z5
        Measure the distance of mid points.
        g       @g?r3   )r   divaddr5   r4   r6   )	r0   r7   r8   r9   r:   r'   mid_candidatesmid_groundtruthdistance_diffr!   r!   r"   r+   f   s   zTvpLoss.loss_distancec           	      C   sB   t ||}t ||}t t t |||}|jdd}|S )z5
        Measure the difference of duration.
        g?r3   )r   subsquarer=   r6   )	r0   r7   r8   r9   r:   r'   duration_candidatesduration_groundtruthduration_diffr!   r!   r"   r,   r   s
   zTvpLoss.loss_durationc              
   C   st   |\}}}t ||}|dddf  |dddf  }}i }	| jD ]}
|	|
| j|
 |||||i q%|	S )am  
        This performs the loss computation.

        Args:
            logits (`torch.FloatTensor`):
                The output logits of head module.
            labels (`List[torch.FloatTensor]`):
                List of tensors ([start, end, duration]), which contains start time, end time of the video corresponding to the text, and also the duration.
        Nr   r   )r   mulfloatr/   updater-   )r0   r   labelsr'   r7   r8   
candidatesr9   r:   losses_dictr   r!   r!   r"   forward}   s   

*
zTvpLoss.forward)
r   r   r   r   r)   r*   r+   r,   rM   __classcell__r!   r!   r1   r"   r#   D   s    

r#   c                       $   e Zd Z fddZdd Z  ZS )TvpVisionModelc              	      s@   t    t|j| _tj|jjd |j	dddddd| _
d S )Nr   r   F)kernel_sizestridepaddinggroupsbias)r(   r)   r   from_configbackbone_configbackboner   Conv2dhidden_sizeshidden_sizegrid_encoder_convr0   configr1   r!   r"   r)      s   

zTvpVisionModel.__init__c                 C   s   |j \}}}}}||| |||}| |d d }| |}tjj|ddd}tjj|dd}|j dd  \}	}
}||||	|
|}|ddd	d
d}|S )Nfeature_mapsr   r   )rR   rS   T)inplacer   r      )	shapeviewrY   r]   r   
functional
max_pool2drelupermute)r0   pixel_values
batch_size
num_framesnum_channelsheightwidthgrid_feat_outputsgridnew_channel
new_height	new_widthr!   r!   r"   rM      s   
zTvpVisionModel.forwardr   r   r   r)   rM   rN   r!   r!   r1   r"   rP      s    rP   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )TvpVisualInputEmbeddingz;
    Takes input of both image and video (multi-frame)
    c                    sv   t    t|j|j| _t|j|j| _t|j	|j| _
td|j| _tj|j|jd| _t|j| _d S )Nr   eps)r(   r)   r   	Embeddingmax_position_embeddingsr\   position_embeddings max_grid_row_position_embeddingsrow_position_embeddings max_grid_col_position_embeddingscol_position_embeddingstoken_type_embeddings	LayerNormlayer_norm_eps
layer_normDropouthidden_dropout_probdropoutr^   r1   r!   r"   r)      s   
z TvpVisualInputEmbedding.__init__c                 C   s   |j \}}}}tj|tj|jd}| |}dt|j d  |d|f }||j|  }tj|tj|jd}	| |	}
|d||f}||
j|  S )z
        Args:
            grid: (batch_size, height, width, hidden_dim)
        Returns:
            grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
        dtypedevice)r   r   r   )	rd   r   arangelongr   r}   lenre   r   )r0   rq   rk   rn   ro   
hidden_dimrow_position_idsr}   	row_shapecol_position_idsr   	col_shaper!   r!   r"   add_2d_positional_embeddings   s   

z4TvpVisualInputEmbedding.add_2d_positional_embeddingsc                 C   s   |j \}}}}}|d}| |}||d|}|j dd }|j}	tj|tj|	d}
| |
}|| }| 	|}| 
|}|S )a}  
        Args:
            grid: Array of shape (batch_size, num_frames, height, width, num_channels).
                It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
                num_frames can be 1

        Returns:
            embeddings: The embedding of grid with size (batch_size, height*width, num_channels)

        r   rQ   Nr   )rd   meanr   re   r   r   zerosr   r   r   r   )r0   rq   rk   rl   rn   ro   rm   visual_tokensvisual_tokens_shaper   token_type_idsr   
embeddingsr!   r!   r"   rM      s   




zTvpVisualInputEmbedding.forward)r   r   r   r   r)   r   rM   rN   r!   r!   r1   r"   rv      s
    
rv   c                       s*   e Zd ZdZ fddZdddZ  ZS )TvpTextInputEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    sl   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _d S )N)padding_idxrw   )r(   r)   r   ry   
vocab_sizer\   pad_token_idword_embeddingsrz   r{   type_vocab_sizer   r   r   r   r   r   r   r^   r1   r!   r"   r)      s   
zTvpTextInputEmbeddings.__init__Nc                 C   s   |d ur	|  }n|  d d }|d }|d ur|jn|j}|d u r4tj|tj|d}|d|}|d u rAtj|tj|d}|d u rJ| |}| 	|}| 
|}	|| |	 }
| |
}
| |
}
|
S )NrQ   r   r   r   )sizer   r   r   r   	unsqueezeexpandr   r   r{   r   r   r   )r0   	input_idsr   position_idsinputs_embedsinput_shape
seq_lengthr   r{   r   r   r!   r!   r"   rM      s$   





zTvpTextInputEmbeddings.forward)NNNNr   r   r   r   r)   rM   rN   r!   r!   r1   r"   r      s    r   c                       sV   e Zd Z fddZdd Zdejdedefdd	Z	
	
	
dde	e
 fddZ  ZS )TvpAttentionc                    s   t    |j|j dkrt|dstd|j d|j |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	
|j|j| _t	j|j|jd| _t	|j| _t | _d S )Nr   embedding_sizezThe hidden size z4 is not a multiple of the number of attention heads rw   )r(   r)   r\   num_attention_headshasattrr.   intattention_head_sizeall_head_sizer   Linearquerykeyvaluer   attention_probs_dropout_probattn_dropoutdenser   r   r   r   r   setpruned_headsr^   r1   r!   r"   r)     s    
zTvpAttention.__init__c                    s   t |dkrd S t| j| j}t|| j }|D ]  t fdd| jD   d| < q|d	 
d}tt ||  }t| j|| _t| j|| _t| j|| _t| j|dd| _| jt | | _| j| j | _| j|| _d S )Nr   c                 3   s     | ]}| k r
d ndV  qdS )r   r   Nr!   ).0hheadr!   r"   	<genexpr>5  s    z+TvpAttention.prune_heads.<locals>.<genexpr>rQ   r   dim)r   r   onesr   r   r   r   sumre   
contiguouseqr   r   r   r   r   r   r   r   r<   )r0   headsmaskindexr!   r   r"   prune_heads.  s    
zTvpAttention.prune_headstensorsequence_lengthrk   c                 C   s    | ||| j| jdd S )Nr   r   )re   r   r   	transposer   )r0   r   r   rk   r!   r!   r"   _reshapeE  s   zTvpAttention._reshapeNoutput_attentionsc                 C   s   |j d d \}}| |}| |}| |}	| |||}
| |||}| |	||}t|
|dd}|t	| j
 }|d urG|| }tjj|dd}| |}|d ur\|| }t||}|dd }|||| j}| |}| |}| || }|r||f}|S |f}|S )Nr   rQ   r   r   )rd   r   r   r   r   r   matmulr   mathsqrtr   r   rf   softmaxr   r   reshaper   r   r   r   )r0   r   attention_mask	head_maskr   rk   r   mixed_query_layermixed_key_layermixed_value_layerquery_layer	key_layervalue_layerattention_scoresattention_probsattn_outputoutputsr!   r!   r"   rM   L  s2   





zTvpAttention.forwardNNN)r   r   r   r)   r   r   Tensorr   r   r   boolrM   rN   r!   r!   r1   r"   r     s    
r   c                       2   e Zd Z fddZdejdejfddZ  ZS )TvpIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S N)r(   r)   r   r   r\   intermediate_sizer   
isinstance
hidden_actstrr   intermediate_act_fnr^   r1   r!   r"   r)   |  s
   
zTvpIntermediate.__init__r   returnc                 C   s   |  |}| |}|S r   )r   r   )r0   r   r!   r!   r"   rM     s   

zTvpIntermediate.forwardr   r   r   r)   r   r   rM   rN   r!   r!   r1   r"   r   {  s    r   c                       s8   e Zd Z fddZdejdejdejfddZ  ZS )TvpOutputLayerc                    sB   t    t|j|j| _tj|j|jd| _	t
|j| _d S )Nrw   )r(   r)   r   r   r   r\   r   r   r   r   r   r   r   r^   r1   r!   r"   r)     s   
zTvpOutputLayer.__init__r   input_tensorr   c                 C   s&   |  |}| |}| || }|S r   )r   r   r   )r0   r   r   r!   r!   r"   rM     s   

zTvpOutputLayer.forwardr   r!   r!   r1   r"   r     s    $r   c                       s6   e Zd Z fddZ			ddee fddZ  ZS )TvpEncodeLayerc                    s,   t    t|| _t|| _t|| _d S r   )r(   r)   r   	attentionr   intermediater   outputr^   r1   r!   r"   r)     s   


zTvpEncodeLayer.__init__Nr   c           
      C   sJ   | j ||||d}|d }|dd  }| |}| ||}	|	f| }|S )N)r   r   r   )r   r   r   )
r0   r   r   r   r   self_attention_outputsattention_outputr   intermediate_outputlayer_outputr!   r!   r"   rM     s   

zTvpEncodeLayer.forwardr   )r   r   r   r)   r   r   rM   rN   r!   r!   r1   r"   r     s    	r   c                
       sT   e Zd Z fddZ					d
deej dee dee dee fdd	Z  Z	S )
TvpEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r!   )r   )r   _r_   r!   r"   
<listcomp>  s    z'TvpEncoder.__init__.<locals>.<listcomp>F)	r(   r)   r_   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingr^   r1   r   r"   r)     s   
 
zTvpEncoder.__init__Nr   r   output_hidden_statesreturn_dictc                 C   s  |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}d}d}t| jD ]:\}	}
|r2||f }| jrK| jrK| |
j	|||d urF||	 nd |}n	|
||||	 |}|d }|ra||d f }q'|ri||f }|s~|f}|ru||f }|r|||f }|S t
||r|nd |r|dS d dS )Nr!   r   r   )last_hidden_stater   r   )r_   r   r   r   	enumerater   r   training_gradient_checkpointing_func__call__r   )r0   r   r   r   r   r   r   all_hidden_statesall_attentionsilayer_modulelayer_outputsr   r!   r!   r"   rM     sL   	




zTvpEncoder.forward)NNNNN)
r   r   r   r)   r   r   r   r   rM   rN   r!   r!   r1   r"   r     s     	r   c                       r   )	TvpPoolerc                    s*   t    t|j|j| _t | _d S r   )r(   r)   r   r   r\   r   Tanh
activationr^   r1   r!   r"   r)     s   
zTvpPooler.__init__r   r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r  )r0   r   first_token_tensorpooled_outputr!   r!   r"   rM     s   

zTvpPooler.forwardr   r!   r!   r1   r"   r
    s    r
  c                   @   s$   e Zd ZdZeZdZdZdd ZdS )TvpPreTrainedModelzAn abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    modelTc                 C   s   t |tjtjfr|jjjd| jjd nt |tj	r(|j
j  |jjd t |tjr9|j
dur9|j
j  t |tjrXtjj|jddd |j
durZtj|j
d dS dS dS )	zInitialize the weights        )r   stdg      ?Nfan_outrh   )modenonlinearityr   )r   r   r   ry   weightdatanormal_r_   initializer_ranger   rV   zero_fill_rZ   initkaiming_normal_	constant_)r0   moduler!   r!   r"   _init_weights  s   
z TvpPreTrainedModel._init_weightsN)	r   r   r   r   r   config_classbase_model_prefixsupports_gradient_checkpointingr   r!   r!   r!   r"   r    s    r  aE  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`TvpConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a+  
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`TvpImageProcessor`]. See [`TvpImageProcessor.__call__`]
            for details.

        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                       (   e Zd ZdZ fddZdd Z  ZS )TvpFrameDownPadPrompterz>
    Pad frames extracted from videos only at the bottom.
    c              	      sb   |j dvr	tdt   |j| _|j| _|j| _|j | _ tt	
d|jd|j|jg| _d S )Nr>   replaceremove9`visual_prompter_apply` must be in (add, replace, remove)r   r   )visual_prompter_applyr.   r(   r)   visual_prompt_size	frame_nummax_img_sizer   	Parameterr   randnpad_downr^   r1   r!   r"   r)   Q  s   


z TvpFrameDownPadPrompter.__init__c                 C   s   | j dkr&tj| j| jg|j|jd}d|| j| j | jd d f< ||9 }| j dkrctj|jd |jd d| j| jg|jd}| j| j }| j	|d d d d d d || jd d f< ||
|j7 }|S )	Nr>   r   r  r(  r   r   r   r   )r*  r   r   r-  r   r   r+  r   rd   r0  to)r0   rj   visual_prompt_maskpromptstart_pointr!   r!   r"   rM   _  s   

*zTvpFrameDownPadPrompter.forwardr   r!   r!   r1   r"   r%  L  s    r%  c                       r$  )TvpFramePadPrompterz?
    Pad frames extracted from videos in the surroundings.
    c              
      s   |j dvr	tdt   |j| _|j| _|j | _ |j|jd  | _t	t
d|jd|j|jg| _t	t
d|jd|j|jg| _t	t
d|jd|j|jd  |jg| _t	t
d|jd|j|jd  |jg| _d S )Nr&  r)  r   r   r   )r*  r.   r(   r)   rl   r-  r+  	base_sizer   r.  r   r/  pad_upr0  pad_left	pad_rightr^   r1   r!   r"   r)   v  sB   


zTvpFramePadPrompter.__init__c                 C   s   | j dvrtd| j  | j dv r$tj| j| jg|j|jd}||9 }| j dv rctjd| jd| j	| j	|jd}tj
| j|| jgd	d
}tj
| j|| jgdd
}t
|d|g }|||j7 }|S )N)r>   r(  r'  z$Invalid visual_prompter_apply value )r'  r(  r   )r'  r>   r   r   r1  rc   r   r   )r*  r.   r   r   r-  r   r   r   rl   r7  catr9  r:  r8  r0  r   r2  )r0   rj   r3  baser4  r!   r!   r"   rM     s   


zTvpFramePadPrompter.forwardr   r!   r!   r1   r"   r6  q  s    'r6  )framedownpadframepadzmThe bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top.c                       s   e Zd Z fddZdd Zdd Zdd Zeee	e
ed		
	
	
	
	
	
	
ddeej deej deej deej dee dee dee fddZ  ZS )TvpModelc                    s   t  | || _t|| _t|| _t|| _t	|| _
t|| _ttdd|jg| _t|j| _|jtvr?tdt|j || _|   d S )Nr   
   z:`visual_prompter_type` must be in (framedownpad, framepad))r(   r)   r_   rP   vision_modelr   r   rv   visual_embeddingsr   encoderr
  poolerr   r.  r   r/  r\   text_promptr   r   r   visual_prompter_typeTVP_PROMPTER_CLASSES_MAPPINGr.   visual_prompter	post_initr^   r1   r!   r"   r)     s   





zTvpModel.__init__c                 C   s   | j jS r   r   r   )r0   r!   r!   r"   get_input_embeddings  s   zTvpModel.get_input_embeddingsc                 C   s   || j _d S r   rJ  )r0   r   r!   r!   r"   set_input_embeddings  s   zTvpModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )zPrunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        N)itemsrC  r   r   r   )r0   heads_to_pruner   r   r!   r!   r"   _prune_heads  s   zTvpModel._prune_headsoutput_typer!  Nr   rj   r   r   r   r   r   c                 C   sJ  |dur|n| j j}| | |}| j|d}| |}	|durQ||	jdd }
t	|jd dj
|j|jd}tj|||
gdd}| || 
|j}| j|jd dd}tj|||	gd	d}| j||| || j j|||d
}|r||jn|d }| |}| |}| |}|s||f|d	d  S t|||j|jdS )a(  
        Returns:

        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpModel

        >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
        ```N)r   r   r   r@  )r   r   rQ   r   r   )r   r   r   r   r   )r   pooler_outputr   r   )r_   r   rA  rH  r   rB  new_onesrd   r   r   r2  r   r   r;  get_extended_attention_maskr   rE  r   rC  get_head_maskr   r   rD  r   r   r   r   )r0   r   rj   r   r   r   r   r   text_embedding_outputvisual_embedding_outputvisual_attention_maskpt_maskrE  embedding_outputencoder_outputsr   r  r!   r!   r"   rM     sB   



zTvpModel.forward)NNNNNNN)r   r   r   r)   rK  rL  rO  r	   TVP_INPUTS_DOCSTRINGr
   r   r   r   r   
LongTensorr   r   rM   rN   r!   r!   r1   r"   r?    s:    
r?  c                       rO   )TvpVideoGroundingHeadc                    sL   t    t|j|jd | _t|jd d| _t | _t	 | _
d S )Nr   )r(   r)   r   r   r\   layer_0layer_1ReLUactivation_0Sigmoidactivation_1r^   r1   r!   r"   r)      s
   

zTvpVideoGroundingHead.__init__c                 C   s$   |  | |}| | |}|S r   )rb  r_  rd  r`  )r0   rR  r   r!   r!   r"   rM   '  s   zTvpVideoGroundingHead.forwardru   r!   r!   r1   r"   r^    s    r^  zb
    Tvp Model with a video grounding head on top computing IoU, distance, and duration loss.
    c                       s   e Zd Z fddZeeeeed								dde	e
j de	e
j de	e
j dee
j d	e	e
j d
e	e de	e de	e fddZ  ZS )TvpForVideoGroundingc                    s2   t  | || _t|| _t|| _|   d S r   )r(   r)   r_   r?  r  r^  video_grounding_headrI  r^   r1   r!   r"   r)   4  s
   

zTvpForVideoGrounding.__init__rP  Nr   rj   r   rJ   r   r   r   r   c	              	   C   s   |dur|n| j j}| j|||||||d}	|	d }
| |
}d}|durJtg d}|| j |||}|d | j j|d   | j j|d   }|s`|f|	dd  }	|dur^|f|	 }	|	S t	|||	j
|	jd	S )
a  
        labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
            The labels contains duration, start time, and end time of the video corresponding to the text.
        Returns:

        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding

        >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
        ```N)r   r   r   r   r   r$   r%   r&   r'   r   )r   r   r   r   )r_   r   r  rf  r#   r2  r   distance_loss_weightduration_loss_weightr   r   r   )r0   r   rj   r   rJ   r   r   r   r   r   rR  r   r   	criterion	loss_dictr!   r!   r"   rM   <  sD   	


zTvpForVideoGrounding.forward)NNNNNNNN)r   r   r   r)   r	   r\  r
   r   r   r   r   r]  r   r   r   r   rM   rN   r!   r!   r1   r"   re  -  s:    
	re  )7r   r   dataclassesr   typingr   r   r   torch.utils.checkpointr   activationsr   
file_utilsr   r	   r
   modeling_outputsr   r   r   modeling_utilsr   pytorch_utilsr   utilsr   autor   configuration_tvpr   
get_loggerr   logger!TVP_PRETRAINED_MODEL_ARCHIVE_LISTr   Moduler#   rP   rv   r   r   r   r   r   r   r
  r  TVP_START_DOCSTRINGr\  r%  r6  rG  r?  r^  re  r!   r!   r!   r"   <module>   sd   
PB$c?#%>g