# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch REALM model."""

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    ModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_realm import RealmConfig


logger = logging.get_logger(__name__)

_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
_CONFIG_FOR_DOC = "RealmConfig"

REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/realm-cc-news-pretrained-embedder",
    "google/realm-cc-news-pretrained-encoder",
    "google/realm-cc-news-pretrained-scorer",
    "google/realm-cc-news-pretrained-openqa",
    "google/realm-orqa-nq-openqa",
    "google/realm-orqa-nq-reader",
    "google/realm-orqa-wq-openqa",
    "google/realm-orqa-wq-reader",
    # See all REALM models at https://huggingface.co/models?filter=realm
]
|}g }g }	|D ] \}
}t	d|
 d|  |j
||
}||
 |	| q6t||	D ]\}
}t| tryd|
vryt	d|
 d	| jj d
 q\|
ds|
drt| tr|
dd}
|
dd}
|
ds|
drt| tr|
dd}
|
drt| trdnd}|
d| d}
|
d| d}
|
d| d}
|
d| d}
|
d| d}
|
dr*t| trdnd}|
d| d}
|
d| d }
|
d!| d"}
|
d#| d$}
|
d%| d}
|
d&| d$}
n"|
d'rLt| tr8dnd}|
d(| d }
|
d)| d"}
|
d*}
td+d, |
D rgt	dd*|
  q\| }|
D ]m}|d-|r{|d.|}n|g}|d d/ks|d d0krt|d1}n4|d d2ks|d d3krt|d4}n z	t||d }W n ty   t	dd*|
  Y qkw t|d5krt|d6 }|| }qk|d7d d8krt|d1}n
|d/kr| |}z|j!|j!ksJ d9|j! d:|j! d;W n t"y! } z| j#|j!|j!f7  _# d}~ww t	d<|
  t$%||_&q\| S )=z'Load tf checkpoints in a pytorch model.r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape readerz	Skipping z as it is not z's parameterbertclszbert/zreader/realm/zcls/zreader/cls/zrealm/ zreader/zreader/module/bert/zreader/module/cls/zreader/dense/zqa_outputs/dense_intermediate/zreader/dense_1/zqa_outputs/dense_output/zreader/layer_normalizationzqa_outputs/layer_normalizationzmodule/module/module/z	embedder/z!module/module/module/module/bert/zmodule/module/module/LayerNorm/zcls/LayerNorm/zmodule/module/module/dense/z
cls/dense/z,module/module/module/module/cls/predictions/zcls/predictions/zmodule/module/module/bert/z%module/module/module/cls/predictions/zmodule/module/zmodule/module/LayerNorm/zmodule/module/dense//c                 s   s    | ]}|d v V  qdS ))adam_vadam_mAdamWeightDecayOptimizerAdamWeightDecayOptimizer_1global_stepN ).0nr%   r%   ^/var/www/html/ai/venv/lib/python3.10/site-packages/transformers/models/realm/modeling_realm.py	<genexpr>|   s
    
z+load_tf_weights_in_realm.<locals>.<genexpr>z[A-Za-z]+_\d+z_(\d+)kernelgammaweightoutput_biasbetabias   r   i_embeddingszPointer shape z and array shape z mismatchedzInitialize PyTorch weight )'renumpy
tensorflowImportErrorloggererrorospathabspathinfotrainlist_variablesload_variableappendzip
isinstanceRealmReader	__class____name__
startswithRealmForOpenQAreplaceRealmKnowledgeAugEncoderRealmEmbeddersplitanyjoin	fullmatchgetattrAttributeErrorlenint	transposeshapeAssertionErrorargstorch
from_numpydata)modelconfigtf_checkpoint_pathr2   nptftf_path	init_varsnamesarraysnamerS   arrayreader_prefixembedder_prefixpointerm_namescope_namesnumer%   r%   r(   load_tf_weights_in_realm:   s   





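
# Illustrative conversion sketch (assumptions labelled): `load_tf_weights_in_realm` is
# normally reached through `from_pretrained(..., from_tf=True)`, but it can be called
# directly. The checkpoint path below is a placeholder, not a file shipped with the library:
#
#     from transformers import RealmConfig, RealmKnowledgeAugEncoder
#
#     config = RealmConfig()
#     model = RealmKnowledgeAugEncoder(config)
#     load_tf_weights_in_realm(model, config, "/path/to/tf_checkpoint.ckpt")  # hypothetical path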
class RealmEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class RealmSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys and values come from
        # an encoder; the attention mask needs to be such that the encoder's padding tokens
        # are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states.
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in RealmBertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class RealmSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class RealmAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = RealmSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = RealmSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class RealmIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class RealmOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class RealmLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = RealmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = RealmAttention(config, position_embedding_type="absolute")
        self.intermediate = RealmIntermediate(config)
        self.output = RealmOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is a tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention"
                    " layers by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class RealmEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([RealmLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class RealmPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


@dataclass
class RealmEmbedderOutput(ModelOutput):
    """
    Outputs of [`RealmEmbedder`] models.

    Args:
        projected_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):

            Projected score.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nprojected_scorer   r%  )rD   r   r   r   r;  rV   r   __annotations__r   r   r   r%  r%   r%   r%   r(   r:    s
   
 r:  c                   @   s<   e Zd ZU dZdZejed< dZejed< dZ	ejed< dS )RealmScorerOutputa'  
    Outputs of [`RealmScorer`] models.

    Args:
        relevance_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates)`):
            The relevance score of document candidates (before softmax).
        query_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):
            Query score derived from the query embedder.
        candidate_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates, config.retriever_proj_size)`):
            Candidate score derived from the embedder.
    Nrelevance_scorequery_scorecandidate_score)
rD   r   r   r   r>  rV   r   r<  r?  r@  r%   r%   r%   r(   r=    s
   
 r=  c                   @   s   e Zd ZU dZdZejed< dZejed< dZ	ejed< dZ
ejed< dZejed< dZejed< dZejed	< dZejed
< dZejed< dZeeej  ed< dZeeej  ed< dS )RealmReaderOutputa+	  
    Outputs of [`RealmReader`] models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Total loss.
        retriever_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Retriever loss.
        reader_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Reader loss.
        retriever_correct (`torch.BoolTensor` of shape `(config.searcher_beam_size,)`, *optional*):
            Whether or not an evidence block contains answer.
        reader_correct (`torch.BoolTensor` of shape `(config.reader_beam_size, num_candidates)`, *optional*):
            Whether or not a span candidate contains answer.
        block_idx (`torch.LongTensor` of shape `()`):
            The index of the retrieved evidence block in which the predicted answer is most likely.
        candidate (`torch.LongTensor` of shape `()`):
            The index of the retrieved span candidates in which the predicted answer is most likely.
        start_pos (`torch.IntTensor` of shape `()`):
            Predicted answer starting position in *RealmReader*'s inputs.
        end_pos (`torch.IntTensor` of shape `()`):
            Predicted answer ending position in *RealmReader*'s inputs.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlossretriever_lossreader_lossretriever_correctreader_correct	block_idx	candidate	start_posend_posr   r%  )rD   r   r   r   rB  rV   r   r<  rC  rD  rE  
BoolTensorrF  rG  r   rH  rI  int32rJ  r   r   r   r%  r%   r%   r%   r(   rA    s   
 #rA  c                   @   s,   e Zd ZU dZdZeed< dZej	ed< dS )RealmForOpenQAOutputz

    Outputs of [`RealmForOpenQA`] models.

    Args:
        reader_output (`dict`):
            Reader output.
        predicted_answer_ids (`torch.LongTensor` of shape `(answer_sequence_length)`):
            Predicted answer ids.
    Nreader_outputpredicted_answer_ids)
rD   r   r   r   rN  dictr<  rO  rV   r   r%   r%   r%   r(   rM    s   
 rM  c                       $   e Zd Z fddZdd Z  ZS )RealmPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S r   )rx   ry   r   r   r|   r   rA   r   r  r	   transform_act_fnr   r   r   r   r%   r(   ry     s   
z%RealmPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r   rS  r   r  r%   r%   r(   r     s   


z$RealmPredictionHeadTransform.forwardrD   r   r   ry   r   r   r%   r%   r   r(   rR    s    	rR  c                       rQ  )RealmLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)r/   )rx   ry   rR  	transformr   r   r|   r{   decoder	ParameterrV   r   r/   r   r   r%   r(   ry     s
   

zRealmLMPredictionHead.__init__c                 C   r  r   )rV  rW  r  r%   r%   r(   r   "  r  zRealmLMPredictionHead.forwardrT  r%   r%   r   r(   rU    s    rU  c                       rQ  )RealmOnlyMLMHeadc                    s   t    t|| _d S r   )rx   ry   rU  predictionsr   r   r%   r(   ry   )  s   
zRealmOnlyMLMHead.__init__c                 C   s   |  |}|S r   )rZ  )r   sequence_outputprediction_scoresr%   r%   r(   r   -  s   
zRealmOnlyMLMHead.forwardrT  r%   r%   r   r(   rY  (  s    rY  c                       rQ  )RealmScorerProjectionc                    s>   t    t|| _t|j|j| _tj	|j|j
d| _	d S r   )rx   ry   rU  rZ  r   r   r|   retriever_proj_sizer   r   r   r   r   r%   r(   ry   3  s   

zRealmScorerProjection.__init__c                 C   r  r   )r   r   r  r%   r%   r(   r   9  r  zRealmScorerProjection.forwardrT  r%   r%   r   r(   r]  2  s    r]  c                       rQ  )RealmReaderProjectionc                    sX   t    || _t|j|jd | _t|jd| _tj	|j|j
d| _t | _d S )Nr0   r   rn   )rx   ry   rZ   r   r   r|   span_hidden_sizedense_intermediatedense_outputr   reader_layer_norm_epslayer_normalizationReLUrelur   r   r%   r(   ry   @  s   
zRealmReaderProjection.__init__c                    s    fdd}t jfdd} |}|jddd\}}||\}}}	t j|d|d	}
t j|d|d	}|
| } |} |} |d}|||	|j	d
7 }|||fS )Nc                    s   j \}fdd t fddtjjD  \}}t|d}t|d}tjd|d}tjd|d}|| }|||fS )aK  
            Generate span candidates.

            Args:
                masks: <bool> [num_retrievals, max_sequence_len]

            Returns:
                starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
                whether spans locate in evidence block.
            """
            _, max_sequence_len = masks.shape

            def _spans_given_width(width):
                current_starts = torch.arange(max_sequence_len - width + 1, device=masks.device)
                current_ends = torch.arange(width - 1, max_sequence_len, device=masks.device)
                return current_starts, current_ends

            starts, ends = zip(*(_spans_given_width(w + 1) for w in range(self.config.max_span_width)))

            # [num_spans]
            starts = torch.cat(starts, 0)
            ends = torch.cat(ends, 0)

            # [num_retrievals, num_spans]
            start_masks = torch.index_select(masks, dim=-1, index=starts)
            end_masks = torch.index_select(masks, dim=-1, index=ends)
            span_masks = start_masks * end_masks

            return starts, ends, span_masks

        def mask_to_score(mask, dtype=torch.float32):
            return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

        # [reader_beam_size, max_sequence_len, span_hidden_size * 2]
        hidden_states = self.dense_intermediate(hidden_states)
        # [reader_beam_size, max_sequence_len, span_hidden_size]
        start_projection, end_projection = hidden_states.chunk(2, dim=-1)

        candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)

        candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
        candidate_end_projections = torch.index_select(end_projection, dim=1, index=candidate_ends)
        candidate_hidden = candidate_start_projections + candidate_end_projections

        # [reader_beam_size, num_candidates, span_hidden_size]
        candidate_hidden = self.relu(candidate_hidden)
        candidate_hidden = self.layer_normalization(candidate_hidden)
        # [reader_beam_size, num_candidates]
        reader_logits = self.dense_output(candidate_hidden).squeeze(-1)
        # [reader_beam_size, num_candidates]
        reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype)

        return reader_logits, candidate_starts, candidate_ends
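
# Worked example of the span enumeration above (illustrative values only): with
# max_sequence_len = 4 and max_span_width = 2, `span_candidates` considers the spans
# (0,0), (1,1), (2,2), (3,3) of width 1 and (0,1), (1,2), (2,3) of width 2, keeping a
# span only when both its start and end tokens lie inside the evidence block:
#
#     masks = torch.tensor([[0, 1, 1, 0]], dtype=torch.long)  # evidence tokens at 1 and 2
#     starts = torch.tensor([0, 1, 2, 3, 0, 1, 2])
#     ends = torch.tensor([0, 1, 2, 3, 1, 2, 3])
#     span_masks = masks[:, starts] * masks[:, ends]  # -> [[0, 1, 1, 0, 0, 1, 0]]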

REALM_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RealmConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

REALM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class RealmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RealmConfig
    load_tf_weights = load_tf_weights_in_realm
    base_model_prefix = "realm"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _flatten_inputs(self, *inputs):
        """Flatten inputs' shape to (-1, input_shape[-1])"""
        flattened_inputs = []
        for tensor in inputs:
            if tensor is None:
                flattened_inputs.append(None)
            else:
                input_shape = tensor.shape
                if len(input_shape) > 2:
                    tensor = tensor.view((-1, input_shape[-1]))
                flattened_inputs.append(tensor)
        return flattened_inputs


class RealmBertModel(RealmPreTrainedModel):
    """
    Same as the original BertModel but remove docstrings.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = RealmEmbeddings(config)
        self.encoder = RealmEncoder(config)

        self.pooler = RealmPooler(config) if add_pooling_layer else None

        # Weights initialization is mostly managed by other Realm models,
        # but we also have them initialized here to keep the consistency.
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
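
# Minimal smoke-test sketch for the backbone above (random weights, illustrative shapes;
# not taken from any checkpoint):
#
#     config = RealmConfig()
#     backbone = RealmBertModel(config)
#     input_ids = torch.randint(0, config.vocab_size, (2, 16))
#     outputs = backbone(input_ids)
#     outputs.last_hidden_state.shape  # torch.Size([2, 16, config.hidden_size])
#     outputs.pooler_output.shape      # torch.Size([2, config.hidden_size])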
@add_start_docstrings(
    "The embedder of REALM outputting projected score that will be used to calculate relevance score.",
    REALM_START_DOCSTRING,
)
class RealmEmbedder(RealmPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.realm = RealmBertModel(self.config)
        self.cls = RealmScorerProjection(self.config)
        self.post_init()

    def get_input_embeddings(self):
        return self.realm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.realm.embeddings.word_embeddings = value

    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmEmbedderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmEmbedderOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RealmEmbedder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
        >>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> projected_score = outputs.projected_score
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        realm_outputs = self.realm(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size, hidden_size]
        pooler_output = realm_outputs[1]
        # [batch_size, retriever_proj_size]
        projected_score = self.cls(pooler_output)

        if not return_dict:
            return (projected_score,) + realm_outputs[2:4]
        else:
            return RealmEmbedderOutput(
                projected_score=projected_score,
                hidden_states=realm_outputs.hidden_states,
                attentions=realm_outputs.attentions,
            )


@add_start_docstrings(
    "The scorer of REALM outputting relevance scores representing the score of document candidates (before softmax).",
    REALM_START_DOCSTRING,
)
class RealmScorer(RealmPreTrainedModel):
    r"""
    Args:
        query_embedder ([`RealmEmbedder`]):
            Embedder for input sequences. If not specified, it will use the same embedder as candidate sequences.
    """

    def __init__(self, config, query_embedder=None):
        super().__init__(config)

        self.embedder = RealmEmbedder(self.config)
        self.query_embedder = query_embedder if query_embedder is not None else self.embedder
        self.post_init()

    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmScorerOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        candidate_input_ids: Optional[torch.LongTensor] = None,
        candidate_attention_mask: Optional[torch.FloatTensor] = None,
        candidate_token_type_ids: Optional[torch.LongTensor] = None,
        candidate_inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmScorerOutput]:
        r"""
        candidate_input_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`):
            Indices of candidate input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        candidate_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        candidate_token_type_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        candidate_inputs_embeds (`torch.FloatTensor` of shape `(batch_size * num_candidates, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `candidate_input_ids` you can choose to directly pass an embedded
            representation. This is useful if you want more control over how to convert *candidate_input_ids* indices
            into associated vectors than the model's internal embedding lookup matrix.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, RealmScorer

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer")
        >>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)

        >>> # batch_size = 2, num_candidates = 2
        >>> input_texts = ["How are you?", "What is the item in the picture?"]
        >>> candidates_texts = [["Hello world!", "Nice to meet you!"], ["A cute cat.", "An adorable dog."]]

        >>> inputs = tokenizer(input_texts, return_tensors="pt")
        >>> candidates_inputs = tokenizer.batch_encode_candidates(candidates_texts, max_length=10, return_tensors="pt")

        >>> outputs = model(
        ...     **inputs,
        ...     candidate_input_ids=candidates_inputs.input_ids,
        ...     candidate_attention_mask=candidates_inputs.attention_mask,
        ...     candidate_token_type_ids=candidates_inputs.token_type_ids,
        ... )
        >>> relevance_score = outputs.relevance_score
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or input_embeds.")

        if candidate_input_ids is None and candidate_inputs_embeds is None:
            raise ValueError("You have to specify either candidate_input_ids or candidate_inputs_embeds.")

        query_outputs = self.query_embedder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size * num_candidates, candidate_seq_len]
        (flattened_input_ids, flattened_attention_mask, flattened_token_type_ids) = self._flatten_inputs(
            candidate_input_ids, candidate_attention_mask, candidate_token_type_ids
        )

        candidate_outputs = self.embedder(
            flattened_input_ids,
            attention_mask=flattened_attention_mask,
            token_type_ids=flattened_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=candidate_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size, retriever_proj_size]
        query_score = query_outputs[0]
        # [batch_size * num_candidates, retriever_proj_size]
        candidate_score = candidate_outputs[0]
        # [batch_size, num_candidates, retriever_proj_size]
        candidate_score = candidate_score.view(-1, self.config.num_candidates, self.config.retriever_proj_size)
        # [batch_size, num_candidates]
        relevance_score = torch.einsum("bd,bnd->bn", query_score, candidate_score)

        if not return_dict:
            return relevance_score, query_score, candidate_score

        return RealmScorerOutput(
            relevance_score=relevance_score, query_score=query_score, candidate_score=candidate_score
        )


@add_start_docstrings(
    "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood"
    " loss.",
    REALM_START_DOCSTRING,
)
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)
        self.realm = RealmBertModel(self.config)
        self.cls = RealmOnlyMLMHead(self.config)
        self.post_init()

    def get_input_embeddings(self):
        return self.realm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.realm.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(
        REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length")
    )
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        relevance_score: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mlm_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        relevance_score (`torch.FloatTensor` of shape `(batch_size, num_candidates)`, *optional*):
            Relevance score derived from RealmScorer, must be specified if you want to compute the masked language
            modeling loss.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        mlm_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid calculating joint loss on certain positions. If not specified, the loss will not be masked.
            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, RealmKnowledgeAugEncoder

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
        >>> model = RealmKnowledgeAugEncoder.from_pretrained(
        ...     "google/realm-cc-news-pretrained-encoder", num_candidates=2
        ... )

        >>> # batch_size = 2, num_candidates = 2
        >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

        >>> inputs = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        (flattened_input_ids, flattened_attention_mask, flattened_token_type_ids) = self._flatten_inputs(
            input_ids, attention_mask, token_type_ids
        )

        joint_outputs = self.realm(
            flattened_input_ids,
            attention_mask=flattened_attention_mask,
            token_type_ids=flattened_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size * num_candidates, joint_seq_len, hidden_size]
        joint_output = joint_outputs[0]
        # [batch_size * num_candidates, joint_seq_len, vocab_size]
        prediction_scores = self.cls(joint_output)
        # [batch_size, num_candidates]
        candidate_score = relevance_score

        masked_lm_loss = None
        if labels is not None:
            if candidate_score is None:
                raise ValueError(
                    "You have to specify `relevance_score` when `labels` is specified in order to compute loss."
                )

            batch_size, seq_length = labels.size()

            if mlm_mask is None:
                mlm_mask = torch.ones_like(labels, dtype=torch.float32)
            else:
                mlm_mask = mlm_mask.type(torch.float32)

            # Compute marginal log-likelihood
            loss_fct = CrossEntropyLoss(reduction="none")  # -100 index = padding token

            # [batch_size * num_candidates * joint_seq_len, vocab_size]
            mlm_logits = prediction_scores.view(-1, self.config.vocab_size)
            # [batch_size * num_candidates * joint_seq_len]
            mlm_targets = labels.tile(1, self.config.num_candidates).view(-1)
            # [batch_size, num_candidates, joint_seq_len]
            masked_lm_log_prob = -loss_fct(mlm_logits, mlm_targets).view(
                batch_size, self.config.num_candidates, seq_length
            )
            # [batch_size, num_candidates, 1]
            candidate_log_prob = candidate_score.log_softmax(-1).unsqueeze(-1)
            # [batch_size, num_candidates, joint_seq_len]
            joint_gold_log_prob = candidate_log_prob + masked_lm_log_prob
            # [batch_size, joint_seq_len]
            marginal_gold_log_probs = joint_gold_log_prob.logsumexp(1)
            # []
            masked_lm_loss = -torch.nansum(torch.sum(marginal_gold_log_probs * mlm_mask) / torch.sum(mlm_mask))

        if not return_dict:
            output = (prediction_scores,) + joint_outputs[2:4]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=joint_outputs.hidden_states,
            attentions=joint_outputs.attentions,
        )
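
# Numeric sketch of the marginal log-likelihood computed in `forward` above (made-up
# shapes, not tied to any checkpoint): per-candidate MLM log-probs are weighted by the
# candidate posterior and marginalized over candidates with `logsumexp`:
#
#     batch_size, num_candidates, seq_len = 2, 4, 8
#     masked_lm_log_prob = torch.randn(batch_size, num_candidates, seq_len)
#     relevance_score = torch.randn(batch_size, num_candidates)
#     candidate_log_prob = relevance_score.log_softmax(-1).unsqueeze(-1)
#     marginal = (candidate_log_prob + masked_lm_log_prob).logsumexp(1)  # [batch_size, seq_len]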
@add_start_docstrings("The reader of REALM.", REALM_START_DOCSTRING)
class RealmReader(RealmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.realm = RealmBertModel(config)
        self.cls = RealmOnlyMLMHead(config)
        self.qa_outputs = RealmReaderProjection(config)

        self.post_init()

    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("reader_beam_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmReaderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        relevance_score: Optional[torch.FloatTensor] = None,
        block_mask: Optional[torch.BoolTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        has_answers: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmReaderOutput]:
        r"""
        relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
            Relevance score, which must be specified if you want to compute the logits and marginal log loss.
        block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
            The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
            loss.
        start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        has_answers (`torch.BoolTensor` of shape `(searcher_beam_size,)`, *optional*):
            Whether or not the evidence block has answer(s).

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if relevance_score is None:
            raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
        if block_mask is None:
            raise ValueError("You have to specify `block_mask` to separate question block and evidence block.")
        if token_type_ids.size(1) < self.config.max_span_width:
            raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
        outputs = self.realm(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [reader_beam_size, joint_seq_len, hidden_size]
        sequence_output = outputs[0]

        # [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
        reader_logits, candidate_starts, candidate_ends = self.qa_outputs(
            sequence_output, block_mask[0 : self.config.reader_beam_size]
        )
        # [searcher_beam_size, 1]
        retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
        # [reader_beam_size, num_candidates]
        reader_logits += retriever_logits
        # []
        predicted_block_index = torch.argmax(torch.max(reader_logits, dim=1).values)
        # []
        predicted_candidate = torch.argmax(torch.max(reader_logits, dim=0).values)
        # [1]
        predicted_start = torch.index_select(candidate_starts, dim=0, index=predicted_candidate)
        # [1]
        predicted_end = torch.index_select(candidate_ends, dim=0, index=predicted_candidate)

        total_loss = None
        retriever_loss = None
        reader_loss = None
        retriever_correct = None
        reader_correct = None
        if start_positions is not None and end_positions is not None and has_answers is not None:

            def compute_correct_candidates(candidate_starts, candidate_ends, gold_starts, gold_ends):
                """Compute correct span."""
                # [reader_beam_size, num_answers, num_candidates]
                is_gold_start = torch.eq(
                    torch.unsqueeze(torch.unsqueeze(candidate_starts, 0), 0), torch.unsqueeze(gold_starts, -1)
                )
                is_gold_end = torch.eq(
                    torch.unsqueeze(torch.unsqueeze(candidate_ends, 0), 0), torch.unsqueeze(gold_ends, -1)
                )

                # [reader_beam_size, num_candidates]
                return torch.any(torch.logical_and(is_gold_start, is_gold_end), 1)

            def marginal_log_loss(logits, is_correct):
                """Loss based on the negative marginal log-likelihood."""

                def mask_to_score(mask, dtype=torch.float32):
                    return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

                # []
                log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1)
                log_denominator = torch.logsumexp(logits, dim=-1)
                return log_denominator - log_numerator

            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # `-1` is reserved for no answer.
            ignored_index = sequence_output.size(1)
            start_positions = start_positions.clamp(-1, ignored_index)
            end_positions = end_positions.clamp(-1, ignored_index)

            retriever_correct = has_answers
            any_retriever_correct = torch.any(retriever_correct)

            reader_correct = compute_correct_candidates(
                candidate_starts=candidate_starts,
                candidate_ends=candidate_ends,
                gold_starts=start_positions[0 : self.config.reader_beam_size],
                gold_ends=end_positions[0 : self.config.reader_beam_size],
            )
            any_reader_correct = torch.any(reader_correct)

            retriever_loss = marginal_log_loss(relevance_score, retriever_correct)
            reader_loss = marginal_log_loss(reader_logits.view(-1), reader_correct.view(-1))
            retriever_loss *= any_retriever_correct.type(torch.float32)
            reader_loss *= any_reader_correct.type(torch.float32)

            total_loss = (retriever_loss + reader_loss).mean()

        if not return_dict:
            output = (predicted_block_index, predicted_candidate, predicted_start, predicted_end) + outputs[2:]
            return (
                ((total_loss, retriever_loss, reader_loss, retriever_correct, reader_correct) + output)
                if total_loss is not None
                else output
            )

        return RealmReaderOutput(
            loss=total_loss,
            retriever_loss=retriever_loss,
            reader_loss=reader_loss,
            retriever_correct=retriever_correct,
            reader_correct=reader_correct,
            block_idx=predicted_block_index,
            candidate=predicted_candidate,
            start_pos=predicted_start,
            end_pos=predicted_end,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


REALM_FOR_OPEN_QA_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token (should not be used in this model by design).

            [What are token type IDs?](../glossary#token-type-ids)
        answer_ids (`list` of shape `(num_answers, answer_length)`, *optional*):
            Answer ids for computing the marginal log-likelihood loss. Indices should be in `[-1, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-1` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "`RealmForOpenQA` for end-to-end open domain question answering.",
    REALM_START_DOCSTRING,
)
class RealmForOpenQA(RealmPreTrainedModel):
    def __init__(self, config, retriever=None):
        super().__init__(config)
        self.embedder = RealmEmbedder(config)
        self.reader = RealmReader(config)
        self.register_buffer(
            "block_emb",
            torch.zeros(()).new_empty(
                size=(config.num_block_records, config.retriever_proj_size),
                dtype=torch.float32,
                device=torch.device("cpu"),
            ),
        )
        self.retriever = retriever

        self.post_init()

    @property
    def searcher_beam_size(self):
        if self.training:
            return self.config.searcher_beam_size
        return self.config.reader_beam_size

    def block_embedding_to(self, device):
        """Send `self.block_emb` to a specific device.

        Args:
            device (`str` or `torch.device`):
                The device to which `self.block_emb` will be sent.
        """
        self.block_emb = self.block_emb.to(device)

    @add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
    @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        answer_ids: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmForOpenQAOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import RealmForOpenQA, RealmRetriever, AutoTokenizer

        >>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
        >>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)

        >>> question = "Who is the pioneer in modern computer science?"
        >>> question_ids = tokenizer([question], return_tensors="pt")
        >>> answer_ids = tokenizer(
        ...     ["alan mathison turing"],
        ...     add_special_tokens=False,
        ...     return_token_type_ids=False,
        ...     return_attention_mask=False,
        ... ).input_ids

        >>> reader_output, predicted_answer_ids = model(**question_ids, answer_ids=answer_ids, return_dict=False)
        >>> predicted_answer = tokenizer.decode(predicted_answer_ids)
        >>> loss = reader_output.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and input_ids.shape[0] != 1:
            raise ValueError("The batch_size of the inputs must be 1.")

        question_outputs = self.embedder(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
        )
        # [1, projection_size]
        question_projection = question_outputs[0]

        # CPU computation starts.
        # [1, block_emb_size]
        batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device))
        # [1, searcher_beam_size]
        _, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1)
        # [searcher_beam_size]
        retrieved_block_ids = retrieved_block_ids.squeeze()
        # [searcher_beam_size, projection_size]
        retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids)
        # CPU computation ends.

        # Retrieve possible answers
        has_answers, start_pos, end_pos, concat_inputs = self.retriever(
            retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len
        )

        concat_inputs = concat_inputs.to(self.reader.device)
        block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device)
        block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool))

        if has_answers is not None:
            has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
            start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
            end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)

        # [searcher_beam_size]
        retrieved_logits = torch.einsum(
            "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device)
        )

        reader_output = self.reader(
            input_ids=concat_inputs.input_ids[0 : self.config.reader_beam_size],
            attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
            token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
            relevance_score=retrieved_logits,
            block_mask=block_mask,
            has_answers=has_answers,
            start_positions=start_pos,
            end_positions=end_pos,
            return_dict=True,
        )

        predicted_block = concat_inputs.input_ids[reader_output.block_idx]
        predicted_answer_ids = predicted_block[reader_output.start_pos : reader_output.end_pos + 1]

        if not return_dict:
            return reader_output, predicted_answer_ids

        return RealmForOpenQAOutput(
            reader_output=reader_output,
            predicted_answer_ids=predicted_answer_ids,
        )