""" PyTorch DPT (Dense Prediction Transformers) model.

This implementation is heavily inspired by OpenMMLab's implementation, found here:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.

"""


import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging
from ..auto import AutoBackbone
from .configuration_dpt import DPTConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "DPTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "Intel/dpt-large"
_EXPECTED_OUTPUT_SHAPE = [1, 577, 1024]


DPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "Intel/dpt-large",
    "Intel/dpt-hybrid-midas",
]


@dataclass
class BaseModelOutputWithIntermediateActivations(ModelOutput):
    """
    Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
    in the context of Vision models.

    Args:
        last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    """

    last_hidden_states: torch.FloatTensor = None
    intermediate_activations: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
    activations that can be used by the model at later stages.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    intermediate_activations: Optional[Tuple[torch.FloatTensor]] = None


class DPTViTHybridEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config, feature_size=None):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.backbone = AutoBackbone.from_config(config.backbone_config)
        feature_dim = self.backbone.channels[-1]
        if len(self.backbone.out_features) != 3:
            raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.out_features)}")
        self.residual_feature_map_index = [0, 1]  # take the output of the first and second backbone stage

        if feature_size is None:
            feat_map_shape = config.backbone_featmap_shape
            feature_size = feat_map_shape[-2:]
            feature_dim = feat_map_shape[1]
        else:
            feature_size = (
                feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
            )
            feature_dim = self.backbone.channels[-1]

        self.image_size = image_size
        self.patch_size = patch_size[0]
        self.num_channels = num_channels

        self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))

    def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
        posemb_tok = posemb[:, :start_index]
        posemb_grid = posemb[0, start_index:]

        old_grid_size = int(math.sqrt(len(posemb_grid)))

        posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
        posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)

        posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

        return posemb

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, return_dict: bool = False
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )

        position_embeddings = self._resize_pos_embed(
            self.position_embeddings, height // self.patch_size, width // self.patch_size
        )

        backbone_output = self.backbone(pixel_values)

        features = backbone_output.feature_maps[-1]

        # retrieve also the intermediate activations to use them at later stages
        output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]

        embeddings = self.projection(features).flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + position_embeddings

        if not return_dict:
            return (embeddings, output_hidden_states)

        # return hidden states and intermediate activations, in the form of a dataclass
        return BaseModelOutputWithIntermediateActivations(
            last_hidden_states=embeddings,
            intermediate_activations=output_hidden_states,
        )


class DPTViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = DPTViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
        posemb_tok = posemb[:, :start_index]
        posemb_grid = posemb[0, start_index:]

        old_grid_size = int(math.sqrt(len(posemb_grid)))

        posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
        posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)

        posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

        return posemb

    def forward(self, pixel_values, return_dict=False):
        batch_size, num_channels, height, width = pixel_values.shape

        # possibly interpolate position encodings to handle varying image sizes
        patch_size = self.config.patch_size
        position_embeddings = self._resize_pos_embed(
            self.position_embeddings, height // patch_size, width // patch_size
        )

        embeddings = self.patch_embeddings(pixel_values)

        batch_size, seq_len, _ = embeddings.size()

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + position_embeddings

        embeddings = self.dropout(embeddings)

        if not return_dict:
            return (embeddings,)

        return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings)


class DPTViTPatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


class DPTViTSelfAttention(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class DPTViTSelfOutput(nn.Module):
    """
    The residual connection is defined in DPTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class DPTViTAttention(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.attention = DPTViTSelfAttention(config)
        self.output = DPTViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class DPTViTIntermediate(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class DPTViTOutput(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class DPTViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = DPTViTAttention(config)
        self.intermediate = DPTViTIntermediate(config)
        self.output = DPTViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


class DPTViTEncoder(nn.Module):
    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([DPTViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class DPTReassembleStage(nn.Module):
    """
    This class reassembles the hidden states of the backbone into image-like feature representations at various
    resolutions.

    This happens in 3 stages:
    1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
       `config.readout_type`.
    2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
    3. Resizing the spatial dimensions (height, width).

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.layers = nn.ModuleList()
        if config.is_hybrid:
            self._init_reassemble_dpt_hybrid(config)
        else:
            self._init_reassemble_dpt(config)

        self.neck_ignore_stages = config.neck_ignore_stages

    def _init_reassemble_dpt_hybrid(self, config):
        """
        For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
        implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
        for more details.
        """
        for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
            if i <= 1:
                self.layers.append(nn.Identity())
            elif i > 1:
                self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))

        if config.readout_type != "project":
            raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")

        # when using DPT-Hybrid the readout type is set to "project"
        self.readout_projects = nn.ModuleList()
        hidden_size = _get_backbone_hidden_size(config)
        for i in range(len(config.neck_hidden_sizes)):
            if i <= 1:
                self.readout_projects.append(nn.Sequential(nn.Identity()))
            elif i > 1:
                self.readout_projects.append(
                    nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
                )

    def _init_reassemble_dpt(self, config):
        for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
            self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))

        if config.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            hidden_size = _get_backbone_hidden_size(config)
            for _ in range(len(config.neck_hidden_sizes)):
                self.readout_projects.append(
                    nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
                )

    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
                List of hidden states from the backbone.
        """
        out = []

        for i, hidden_state in enumerate(hidden_states):
            if i not in self.neck_ignore_stages:
                # reshape to (batch_size, num_channels, height, width)
                cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
                batch_size, sequence_length, num_channels = hidden_state.shape
                if patch_height is not None and patch_width is not None:
                    hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
                else:
                    size = int(math.sqrt(sequence_length))
                    hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()

                feature_shape = hidden_state.shape
                if self.config.readout_type == "project":
                    # reshape to (batch_size, height*width, num_channels)
                    hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
                    readout = cls_token.unsqueeze(1).expand_as(hidden_state)
                    # concatenate the readout token to the hidden states and project
                    hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
                    # reshape back to (batch_size, num_channels, height, width)
                    hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
                elif self.config.readout_type == "add":
                    hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
                    hidden_state = hidden_state.reshape(feature_shape)
                hidden_state = self.layers[i](hidden_state)
            out.append(hidden_state)

        return out


def _get_backbone_hidden_size(config):
    if config.backbone_config is not None and config.is_hybrid is False:
        return config.backbone_config.hidden_size
    else:
        return config.hidden_size


class DPTReassembleLayer(nn.Module):
    def __init__(self, config, channels, factor):
        super().__init__()
        # projection
        hidden_size = _get_backbone_hidden_size(config)
        self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)

        # up/down sampling depending on factor
        if factor > 1:
            self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
        elif factor == 1:
            self.resize = nn.Identity()
        elif factor < 1:
            # so should downsample
            self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)

    def forward(self, hidden_state):
        hidden_state = self.projection(hidden_state)
        hidden_state = self.resize(hidden_state)
        return hidden_state


class DPTFeatureFusionStage(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(len(config.neck_hidden_sizes)):
            self.layers.append(DPTFeatureFusionLayer(config))

    def forward(self, hidden_states):
        # reversing the hidden_states, we start from the last
        hidden_states = hidden_states[::-1]

        fused_hidden_states = []
        # first layer only uses the last hidden_state
        fused_hidden_state = self.layers[0](hidden_states[0])
        fused_hidden_states.append(fused_hidden_state)
        # looping from the second layer to the last, fusing with the previous output
        for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
            fused_hidden_state = layer(fused_hidden_state, hidden_state)
            fused_hidden_states.append(fused_hidden_state)

        return fused_hidden_states


class DPTPreActResidualLayer(nn.Module):
    """
    ResidualConvUnit, pre-activate residual unit.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.use_batch_norm = config.use_batch_norm_in_fusion_residual
        use_bias_in_fusion_residual = (
            config.use_bias_in_fusion_residual
            if config.use_bias_in_fusion_residual is not None
            else not self.use_batch_norm
        )

        self.activation1 = nn.ReLU()
        self.convolution1 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias_in_fusion_residual,
        )

        self.activation2 = nn.ReLU()
        self.convolution2 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias_in_fusion_residual,
        )

        if self.use_batch_norm:
            self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
            self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        residual = hidden_state
        hidden_state = self.activation1(hidden_state)

        hidden_state = self.convolution1(hidden_state)

        if self.use_batch_norm:
            hidden_state = self.batch_norm1(hidden_state)

        hidden_state = self.activation2(hidden_state)
        hidden_state = self.convolution2(hidden_state)

        if self.use_batch_norm:
            hidden_state = self.batch_norm2(hidden_state)

        return hidden_state + residual


class DPTFeatureFusionLayer(nn.Module):
    """Feature fusion layer, merges feature maps from different stages.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
        align_corners (`bool`, *optional*, defaults to `True`):
            The align_corners setting for bilinear upsampling.
    """

    def __init__(self, config, align_corners=True):
        super().__init__()

        self.align_corners = align_corners

        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)

        self.residual_layer1 = DPTPreActResidualLayer(config)
        self.residual_layer2 = DPTPreActResidualLayer(config)

    def forward(self, hidden_state, residual=None):
        if residual is not None:
            if hidden_state.shape != residual.shape:
                residual = nn.functional.interpolate(
                    residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
                )
            hidden_state = hidden_state + self.residual_layer1(residual)

        hidden_state = self.residual_layer2(hidden_state)
        hidden_state = nn.functional.interpolate(
            hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )
        hidden_state = self.projection(hidden_state)

        return hidden_state


class DPTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DPTConfig
    base_model_prefix = "dpt"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


DPT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DPTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DPT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare DPT Model transformer outputting raw hidden-states without any specific head on top.",
    DPT_START_DOCSTRING,
)
class DPTModel(DPTPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # vit encoder
        if config.is_hybrid:
            self.embeddings = DPTViTHybridEmbeddings(config)
        else:
            self.embeddings = DPTViTEmbeddings(config)
        self.encoder = DPTViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = DPTViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        if self.config.is_hybrid:
            return self.embeddings
        else:
            return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndIntermediateActivations,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndIntermediateActivations]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # prepare head mask if needed (1.0 in head_mask indicates we keep the head)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, return_dict=return_dict)

        embedding_last_hidden_states = embedding_output[0] if not return_dict else embedding_output.last_hidden_states

        encoder_outputs = self.encoder(
            embedding_last_hidden_states,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:] + embedding_output[1:]

        return BaseModelOutputWithPoolingAndIntermediateActivations(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            intermediate_activations=embedding_output.intermediate_activations,
        )


class DPTViTPooler(nn.Module):
    def __init__(self, config: DPTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # we "pool" the model by simply taking the hidden state corresponding to the first token
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class DPTNeck(nn.Module):
    """
    DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
    input and produces another list of tensors as output. For DPT, it includes 2 stages:

    * DPTReassembleStage
    * DPTFeatureFusionStage.

    Args:
        config (dict): config dict.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
        if config.backbone_config is not None and config.backbone_config.model_type in ["swinv2"]:
            self.reassemble_stage = None
        else:
            self.reassemble_stage = DPTReassembleStage(config)

        self.convs = nn.ModuleList()
        for channel in config.neck_hidden_sizes:
            self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))

        # fusion
        self.fusion_stage = DPTFeatureFusionStage(config)

    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
                List of hidden states from the backbone.
        """
        if not isinstance(hidden_states, (tuple, list)):
            raise ValueError("hidden_states should be a tuple or list of tensors")

        if len(hidden_states) != len(self.config.neck_hidden_sizes):
            raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")

        # postprocess hidden states
        if self.reassemble_stage is not None:
            hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)

        features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]

        # fusion blocks
        output = self.fusion_stage(features)

        return output


class DPTDepthEstimationHead(nn.Module):
    """
    Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples the
    predictions to the input resolution after the first convolutional layer (details can be found in the paper's
    supplementary material).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.projection = None
        if config.add_projection:
            self.projection = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # use last features
        hidden_states = hidden_states[self.config.head_in_index]

        if self.projection is not None:
            hidden_states = self.projection(hidden_states)
            hidden_states = nn.ReLU()(hidden_states)

        predicted_depth = self.head(hidden_states)

        predicted_depth = predicted_depth.squeeze(dim=1)

        return predicted_depth


@add_start_docstrings(
    """
    DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
    """,
    DPT_START_DOCSTRING,
)
class DPTForDepthEstimation(DPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.backbone = None
        if config.backbone_config is not None and config.is_hybrid is False:
            self.backbone = AutoBackbone.from_config(config.backbone_config)
        else:
            self.dpt = DPTModel(config, add_pooling_layer=False)

        # Neck
        self.neck = DPTNeck(config)

        # Depth estimation head
        self.head = DPTDepthEstimationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
        >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     predicted_depth = outputs.predicted_depth

        >>> # interpolate to original size
        >>> prediction = torch.nn.functional.interpolate(
        ...     predicted_depth.unsqueeze(1),
        ...     size=image.size[::-1],
        ...     mode="bicubic",
        ...     align_corners=False,
        ... )

        >>> # visualize the prediction
        >>> output = prediction.squeeze().cpu().numpy()
        >>> formatted = (output * 255 / np.max(output)).astype("uint8")
        >>> depth = Image.fromarray(formatted)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if self.backbone is not None:
            outputs = self.backbone.forward_with_filtered_kwargs(
                pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
            )
            hidden_states = outputs.feature_maps
        else:
            outputs = self.dpt(
                pixel_values,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=True,  # we need the intermediate hidden states
                return_dict=return_dict,
            )
            hidden_states = outputs.hidden_states if return_dict else outputs[1]
            # only keep certain features based on config.backbone_out_indices
            # note that the hidden_states also include the initial embeddings
            if not self.config.is_hybrid:
                hidden_states = [
                    feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
                ]
            else:
                backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
                backbone_hidden_states.extend(
                    feature
                    for idx, feature in enumerate(hidden_states[1:])
                    if idx in self.config.backbone_out_indices[2:]
                )

                hidden_states = backbone_hidden_states

        patch_height, patch_width = None, None
        if self.config.backbone_config is not None and self.config.is_hybrid is False:
            _, _, height, width = pixel_values.shape
            patch_size = self.config.backbone_config.patch_size
            patch_height = height // patch_size
            patch_width = width // patch_size

        hidden_states = self.neck(hidden_states, patch_height, patch_width)

        predicted_depth = self.head(hidden_states)

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        if not return_dict:
            if output_hidden_states:
                output = (predicted_depth,) + outputs[1:]
            else:
                output = (predicted_depth,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


class DPTSemanticSegmentationHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Dropout(config.semantic_classifier_dropout),
            nn.Conv2d(features, config.num_labels, kernel_size=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # use last features
        hidden_states = hidden_states[self.config.head_in_index]

        logits = self.head(hidden_states)

        return logits


class DPTAuxiliaryHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Dropout(0.1, False),
            nn.Conv2d(features, config.num_labels, kernel_size=1),
        )

    def forward(self, hidden_states):
        logits = self.head(hidden_states)

        return logits


@add_start_docstrings(
    """
    DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
    """,
    DPT_START_DOCSTRING,
)
class DPTForSemanticSegmentation(DPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.dpt = DPTModel(config, add_pooling_layer=False)

        # Neck
        self.neck = DPTNeck(config)

        # Segmentation head(s)
        self.head = DPTSemanticSegmentationHead(config)
        self.auxiliary_head = DPTAuxiliaryHead(config) if config.use_auxiliary_head else None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
        >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.dpt(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features based on config.backbone_out_indices
        # note that the hidden_states also include the initial embeddings
        if not self.config.is_hybrid:
            hidden_states = [
                feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
            ]
        else:
            backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
            backbone_hidden_states.extend(
                feature
                for idx, feature in enumerate(hidden_states[1:])
                if idx in self.config.backbone_out_indices[2:]
            )

            hidden_states = backbone_hidden_states

        hidden_states = self.neck(hidden_states=hidden_states)

        logits = self.head(hidden_states)

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(hidden_states[-1])

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                # upsample logits to the images' original size
                upsampled_logits = nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                if auxiliary_logits is not None:
                    upsampled_auxiliary_logits = nn.functional.interpolate(
                        auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                    )
                # compute weighted loss
                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                main_loss = loss_fct(upsampled_logits, labels)
                auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
                loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )