""" PyTorch ConvBERT model."""

import math
import os
from operator import attrgetter
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, get_activation
from ...modeling_outputs import (
    BaseModelOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_convbert import ConvBertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "YituTech/conv-bert-base"
_CONFIG_FOR_DOC = "ConvBertConfig"

CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "YituTech/conv-bert-base",
    "YituTech/conv-bert-medium-small",
    "YituTech/conv-bert-small",
]


def load_tf_weights_in_convbert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from the TF model.
    init_vars = tf.train.list_variables(tf_path)
    tf_data = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        tf_data[name] = array

    param_mapping = {
        "embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings",
        "embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings",
        "embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings",
        "embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma",
        "embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta",
        "embeddings_project.weight": "electra/embeddings_project/kernel",
        "embeddings_project.bias": "electra/embeddings_project/bias",
    }
    if config.num_groups > 1:
        group_dense_name = "g_dense"
    else:
        group_dense_name = "dense"

    # Per-layer parameter names: PyTorch module path suffix -> TF variable suffix.
    per_layer_mapping = [
        ("attention.self.query.weight", "attention/self/query/kernel"),
        ("attention.self.query.bias", "attention/self/query/bias"),
        ("attention.self.key.weight", "attention/self/key/kernel"),
        ("attention.self.key.bias", "attention/self/key/bias"),
        ("attention.self.value.weight", "attention/self/value/kernel"),
        ("attention.self.value.bias", "attention/self/value/bias"),
        ("attention.self.key_conv_attn_layer.depthwise.weight", "attention/self/conv_attn_key/depthwise_kernel"),
        ("attention.self.key_conv_attn_layer.pointwise.weight", "attention/self/conv_attn_key/pointwise_kernel"),
        ("attention.self.key_conv_attn_layer.bias", "attention/self/conv_attn_key/bias"),
        ("attention.self.conv_kernel_layer.weight", "attention/self/conv_attn_kernel/kernel"),
        ("attention.self.conv_kernel_layer.bias", "attention/self/conv_attn_kernel/bias"),
        ("attention.self.conv_out_layer.weight", "attention/self/conv_attn_point/kernel"),
        ("attention.self.conv_out_layer.bias", "attention/self/conv_attn_point/bias"),
        ("attention.output.dense.weight", "attention/output/dense/kernel"),
        ("attention.output.LayerNorm.weight", "attention/output/LayerNorm/gamma"),
        ("attention.output.dense.bias", "attention/output/dense/bias"),
        ("attention.output.LayerNorm.bias", "attention/output/LayerNorm/beta"),
        ("intermediate.dense.weight", f"intermediate/{group_dense_name}/kernel"),
        ("intermediate.dense.bias", f"intermediate/{group_dense_name}/bias"),
        ("output.dense.weight", f"output/{group_dense_name}/kernel"),
        ("output.dense.bias", f"output/{group_dense_name}/bias"),
        ("output.LayerNorm.weight", "output/LayerNorm/gamma"),
        ("output.LayerNorm.bias", "output/LayerNorm/beta"),
    ]
    for j in range(config.num_hidden_layers):
        for pt_suffix, tf_suffix in per_layer_mapping:
            param_mapping[f"encoder.layer.{j}.{pt_suffix}"] = f"electra/encoder/layer_{j}/{tf_suffix}"

    for param in model.named_parameters():
        param_name = param[0]
        retriever = attrgetter(param_name)
        result = retriever(model)
        tf_name = param_mapping[param_name]
        value = torch.from_numpy(tf_data[tf_name])
        logger.info(f"TF: {tf_name}, PT: {param_name}")
        if tf_name.endswith("/kernel"):
            # TF stores dense kernels as (in, out); PyTorch expects (out, in).
            # Grouped dense kernels are an exception and keep their layout.
            if not tf_name.endswith("/intermediate/g_dense/kernel"):
                if not tf_name.endswith("/output/g_dense/kernel"):
                    value = value.T
        if tf_name.endswith("/depthwise_kernel"):
            value = value.permute(1, 2, 0)
        if tf_name.endswith("/pointwise_kernel"):
            value = value.permute(2, 1, 0)
        if tf_name.endswith("/conv_attn_key/bias"):
            value = value.unsqueeze(-1)
        result.data = value
    return model
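
# A minimal sketch of how the converter above is used. The checkpoint path is a
# placeholder, not a file shipped with this module, and `_example_convert_tf_checkpoint`
# is a hypothetical helper added here purely for illustration. In practice the same code
# path is also reached via `ConvBertModel.from_pretrained(..., from_tf=True)`, which calls
# the `load_tf_weights` hook defined on `ConvBertPreTrainedModel` below.
def _example_convert_tf_checkpoint():
    from transformers import ConvBertConfig, ConvBertModel

    config = ConvBertConfig()  # must match the architecture of the TF checkpoint
    model = ConvBertModel(config)
    # Direct call: copies every TF variable into the matching PyTorch parameter.
    load_tf_weights_in_convbert(model, config, "/path/to/tf_checkpoint/model.ckpt")
    return model
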
class ConvBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased so the name matches the TensorFlow variable
        # and TF checkpoints can be loaded directly.
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # When token_type_ids is not passed, fall back to the registered all-zeros buffer,
        # which lets users trace the model without providing token_type_ids.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class ConvBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ConvBertConfig
    load_tf_weights = load_tf_weights_in_convbert
    base_model_prefix = "convbert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class SeparableConv1D(nn.Module):
    """This class implements separable convolution, i.e. a depthwise and a pointwise layer"""

    def __init__(self, config, input_filters, output_filters, kernel_size, **kwargs):
        super().__init__()
        self.depthwise = nn.Conv1d(
            input_filters,
            input_filters,
            kernel_size=kernel_size,
            groups=input_filters,
            padding=kernel_size // 2,
            bias=False,
        )
        self.pointwise = nn.Conv1d(input_filters, output_filters, kernel_size=1, bias=False)
        self.bias = nn.Parameter(torch.zeros(output_filters, 1))

        self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
        self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.depthwise(hidden_states)
        x = self.pointwise(x)
        x += self.bias
        return x
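
# Shape sketch for `SeparableConv1D` (illustrative helper only, not part of the upstream
# module; the 768/384/9 sizes are arbitrary example values). The depthwise pass convolves
# each of the `input_filters` channels independently (`groups=input_filters`), then the 1x1
# pointwise pass mixes channels, so parameter count grows roughly as k*C + C*C' instead of
# the k*C*C' of a full convolution.
def _example_separable_conv1d():
    from transformers import ConvBertConfig

    config = ConvBertConfig()
    conv = SeparableConv1D(config, input_filters=768, output_filters=384, kernel_size=9)
    hidden = torch.randn(2, 768, 128)  # (batch, channels, sequence); note channels-first
    out = conv(hidden)
    assert out.shape == (2, 384, 128)  # padding k // 2 keeps the sequence length for odd k
    return out
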
eej dee	 de
ejeej f fddZ  ZS )ConvBertSelfAttentionc                    s`  t    |j|j dkrt|dstd|j d|j d|j|j }|dk r1|j| _d| _n|| _|j| _|j| _|j| j dkrHtd|j| j d | _| j| j | _	t
|j| j	| _t
|j| j	| _t
|j| j	| _t||j| j	| j| _t
| j	| j| j | _t
|j| j	| _t
j| jdgt| jd d dgd	| _t
|j| _d S )
Nr   r\   zThe hidden size (z6) is not a multiple of the number of attention heads ()r   z6hidden_size should be divisible by num_attention_headsr"   )r   r   )rX   rY   hidden_sizenum_attention_headsrw   
ValueError
head_ratioconv_kernel_sizeattention_head_sizeall_head_sizer   r   querykeyrL   r   key_conv_attn_layerconv_kernel_layerconv_out_layerUnfoldintunfoldre   attention_probs_dropout_probrg   )ro   r;   new_num_attention_headsrp   rM   rN   rY   '  s<   

zConvBertSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )Nr#   r   r"   r   r
   )rl   r   r   viewr7   )ro   r   new_x_shaperM   rM   rN   transpose_for_scoresN  s   
z*ConvBertSelfAttention.transpose_for_scoresNFr   attention_mask	head_maskencoder_hidden_statesoutput_attentionsrt   c                 C   sV  |  |}|d}|d ur| |}| |}	n
| |}| |}	| |dd}
|
dd}
| |}| |}| |	}t|
|}| 	|}t
|d| jdg}tj|dd}| |}t
||d| jg}|dd d}tjj|| jdgd| jd d dgdd}|dd
|d| j| j}t
|d| j| jg}t||}t
|d| jg}t||dd}|t| j }|d ur|| }tjj|dd}| |}|d ur|| }t||}|dddd }t
||d| j| jg}t||gd}| d d | j| j d f }|j| }|r&||f}|S |f}|S )	Nr   r   r"   r#   dim)r   dilationr   strider
   )r   rl   r   rL   r   	transposer   r3   multiplyr   reshaper   softmaxr   r   
contiguousr8   r   
functionalr   r   matmulmathsqrtrg   r7   r   catr   )ro   r   r   r   r   r   mixed_query_layer
batch_sizemixed_key_layermixed_value_layermixed_key_conv_attn_layerquery_layer	key_layervalue_layerconv_attn_layerr   r   attention_scoresattention_probscontext_layerconv_outnew_context_layer_shapeoutputsrM   rM   rN   r}   S  sh   











zConvBertSelfAttention.forwardNNNF)r~   r   r   rY   r   r3   r   r   r   boolr   r}   r   rM   rM   rp   rN   r   &  s(    'r   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )ConvBertSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S NrR   )rX   rY   r   r   r   r    rc   rd   re   rf   rg   rn   rp   rM   rN   rY     s   
zConvBertSelfOutput.__init__r   input_tensorrt   c                 C   &   |  |}| |}| || }|S r   r    rg   rc   ro   r   r   rM   rM   rN   r}        

zConvBertSelfOutput.forwardr~   r   r   rY   r3   r   r}   r   rM   rM   rp   rN   r     s    $r   c                       sx   e Zd Z fddZdd Z				ddejdeej d	eej d
eej dee	 de
ejeej f fddZ  ZS )ConvBertAttentionc                    s*   t    t|| _t|| _t | _d S r   )rX   rY   r   ro   r   outputsetpruned_headsrn   rp   rM   rN   rY     s   


zConvBertAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )lenr   ro   r   r   r   r   r   r   rL   r   r    r   union)ro   headsindexrM   rM   rN   prune_heads  s   zConvBertAttention.prune_headsNFr   r   r   r   r   rt   c           	      C   s8   |  |||||}| |d |}|f|dd   }|S )Nr   r   )ro   r   )	ro   r   r   r   r   r   self_outputsattention_outputr   rM   rM   rN   r}     s   zConvBertAttention.forwardr   )r~   r   r   rY   r   r3   r   r   r   r   r   r}   r   rM   rM   rp   rN   r     s(    r   c                       2   e Zd Z fddZdejdejfddZ  ZS )GroupedLinearLayerc                    sj   t    || _|| _|| _| j| j | _| j| j | _tt	
| j| j| j| _tt	
|| _d S r   )rX   rY   
input_sizeoutput_sizer/   group_in_dimgroup_out_dimr   r   r3   emptyr   r   )ro   r   r   r/   rp   rM   rN   rY     s   
zGroupedLinearLayer.__init__r   rt   c                 C   sr   t | d }t|d| j| jg}|ddd}t|| j}|ddd}t||d| j	g}|| j
 }|S )Nr   r#   r   r"   )listrl   r3   r   r/   r   r7   r   r   r   r   )ro   r   r   r   rM   rM   rN   r}     s   
zGroupedLinearLayer.forwardr   rM   rM   rp   rN   r     s    
r   c                       r   )ConvBertIntermediatec                    sf   t    |jdkrt|j|j| _nt|j|j|jd| _t	|j
tr-t|j
 | _d S |j
| _d S )Nr   r   r   r/   )rX   rY   r/   r   r   r   intermediate_sizer    r   r   
hidden_actstrr   intermediate_act_fnrn   rp   rM   rN   rY     s   

zConvBertIntermediate.__init__r   rt   c                 C   s   |  |}| |}|S r   )r    r  ro   r   rM   rM   rN   r}     s   

zConvBertIntermediate.forwardr   rM   rM   rp   rN   r    s    r  c                       r   )ConvBertOutputc                    sd   t    |jdkrt|j|j| _nt|j|j|jd| _tj	|j|j
d| _	t|j| _d S )Nr   r  rR   )rX   rY   r/   r   r   r  r   r    r   rc   rd   re   rf   rg   rn   rp   rM   rN   rY     s   

zConvBertOutput.__init__r   r   rt   c                 C   r   r   r   r   rM   rM   rN   r}     r   zConvBertOutput.forwardr   rM   rM   rp   rN   r    s    $r  c                       s   e Zd Z fddZ					ddejdeej deej deej d	eej d
ee de	ejeej f fddZ
dd Z  ZS )ConvBertLayerc                    sn   t    |j| _d| _t|| _|j| _|j| _| jr+| js&t|  dt|| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is added)rX   rY   chunk_size_feed_forwardseq_len_dimr   	attention
is_decoderadd_cross_attention	TypeErrorcrossattentionr  intermediater  r   rn   rp   rM   rN   rY     s   



zConvBertLayer.__init__NFr   r   r   r   encoder_attention_maskr   rt   c                 C   s   | j ||||d}|d }|dd  }	| jr<|d ur<t| ds'td|  d| |||||}
|
d }|	|
dd   }	t| j| j| j|}|f|	 }	|	S )N)r   r   r   r  z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`)	r  r  rw   AttributeErrorr  r   feed_forward_chunkr
  r  )ro   r   r   r   r   r  r   self_attention_outputsr   r   cross_attention_outputslayer_outputrM   rM   rN   r}   -  s6   	


zConvBertLayer.forwardc                 C   s   |  |}| ||}|S r   )r  r   )ro   r   intermediate_outputr  rM   rM   rN   r  U  s   
z ConvBertLayer.feed_forward_chunk)NNNNF)r~   r   r   rY   r3   r   r   r   r   r   r}   r  r   rM   rM   rp   rN   r	    s.    
(r	  c                       s   e Zd Z fddZ							ddejdeej deej d	eej d
eej dee dee dee de	e
ef fddZ  ZS )ConvBertEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS rM   )r	  ).0_r;   rM   rN   
<listcomp>_  s    z,ConvBertEncoder.__init__.<locals>.<listcomp>F)	rX   rY   r;   r   
ModuleListr0   r1   layergradient_checkpointingrn   rp   r  rN   rY   \  s   
 
zConvBertEncoder.__init__NFTr   r   r   r   r  r   output_hidden_statesreturn_dictrt   c	              
   C   s  |rdnd }	|r
dnd }
|r| j jrdnd }t| jD ]I\}}|r&|	|f }	|d ur.|| nd }| jrC| jrC| |j||||||}n	|||||||}|d }|rd|
|d f }
| j jrd||d f }q|rl|	|f }	|s{tdd ||	|
|fD S t	||	|
|dS )NrM   r   r   r"   c                 s   s    | ]	}|d ur|V  qd S r   rM   )r  vrM   rM   rN   	<genexpr>  s    z*ConvBertEncoder.forward.<locals>.<genexpr>)last_hidden_stater   
attentionscross_attentions)
r;   r  	enumerater  r   training_gradient_checkpointing_func__call__tupler   )ro   r   r   r   r   r  r   r!  r"  all_hidden_statesall_self_attentionsall_cross_attentionsilayer_modulelayer_head_masklayer_outputsrM   rM   rN   r}   b  sV   



zConvBertEncoder.forward)NNNNFFT)r~   r   r   rY   r3   r   r   r   r   r   r   r   r}   r   rM   rM   rp   rN   r  [  s8    		

r  c                       r   )ConvBertPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S r   )rX   rY   r   r   r   r    r   r  r  r   transform_act_fnrc   rd   rn   rp   rM   rN   rY     s   
z(ConvBertPredictionHeadTransform.__init__r   rt   c                 C   s"   |  |}| |}| |}|S r   )r    r5  rc   r  rM   rM   rN   r}     r   z'ConvBertPredictionHeadTransform.forwardr   rM   rM   rp   rN   r4    s    	r4  aK  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ConvBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CONVBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:


            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:


            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:


            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ConvBERT Model transformer outputting raw hidden-states without any specific head on top.",
    CONVBERT_START_DOCSTRING,
)
class ConvBertModel(ConvBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = ConvBertEmbeddings(config)

        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)

        self.encoder = ConvBertEncoder(config)
        self.config = config
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        hidden_states = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states)

        hidden_states = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return hidden_states


class ConvBertGeneratorPredictions(nn.Module):
    """Prediction module for the generator, made up of two dense layers."""

    def __init__(self, config):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)

    def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = get_activation("gelu")(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states
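
# A minimal usage sketch for the bare encoder defined above (illustrative helper, not part
# of the upstream module). It mirrors the generated code-sample docs; the checkpoint name
# is the one already referenced by `_CHECKPOINT_FOR_DOC`, and the input sentence is an
# arbitrary example.
def _example_convbert_model():
    from transformers import AutoTokenizer, ConvBertModel

    tokenizer = AutoTokenizer.from_pretrained("YituTech/conv-bert-base")
    model = ConvBertModel.from_pretrained("YituTech/conv-bert-base")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    outputs = model(**inputs)
    # (batch, sequence, hidden_size) hidden states from the last encoder layer
    return outputs.last_hidden_state
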
@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
    _tied_weights_keys = ["generator.lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.convbert = ConvBertModel(config)
        self.generator_predictions = ConvBertGeneratorPredictions(config)

        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.generator_lm_head

    def set_output_embeddings(self, word_embeddings):
        self.generator_lm_head = word_embeddings

    @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        generator_hidden_states = self.convbert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
        generator_sequence_output = generator_hidden_states[0]

        prediction_scores = self.generator_predictions(generator_sequence_output)
        prediction_scores = self.generator_lm_head(prediction_scores)

        loss = None
        # Masked language modeling softmax layer
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + generator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=generator_hidden_states.hidden_states,
            attentions=generator_hidden_states.attentions,
        )


class ConvBertClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

        self.config = config

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        x = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = ACT2FN[self.config.hidden_act](x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


@add_start_docstrings(
    """
    ConvBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    CONVBERT_START_DOCSTRING,
)
class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.convbert = ConvBertModel(config)
        self.classifier = ConvBertClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.convbert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    ConvBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    CONVBERT_START_DOCSTRING,
)
class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.convbert = ConvBertModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.convbert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        pooled_output = self.sequence_summary(sequence_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    ConvBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    CONVBERT_START_DOCSTRING,
)
class ConvBertForTokenClassification(ConvBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.convbert = ConvBertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.convbert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    ConvBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    CONVBERT_START_DOCSTRING,
)
class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.convbert = ConvBertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.convbert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
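
# End-to-end sketch of one of the heads above, here the masked-LM generator (illustrative
# helper, not part of the upstream module; the mask-filling sentence is arbitrary, and any
# `[MASK]`-bearing input works the same way):
def _example_convbert_for_masked_lm():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("YituTech/conv-bert-base")
    model = ConvBertForMaskedLM.from_pretrained("YituTech/conv-bert-base")
    inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # take the highest-scoring vocabulary id at each masked position
    mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
    predicted_id = logits[0, mask_index].argmax(dim=-1)
    return tokenizer.decode(predicted_id)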