o
    h:a                     @   s  d Z ddlZddlZddlmZ ddlmZmZmZmZ ddl	Z
ddlmZ ddlmZmZ ddlmZ dd	lmZ eeZeG d
d dZeG dd dZG dd dZG dd deZG dd deZG dd deZG dd deZG dd deZdS )zJ
Callbacks to use with the Trainer class and customize the training loop.
    N)	dataclass)DictListOptionalUnion)tqdm   )IntervalStrategy
has_length)TrainingArguments)loggingc                   @   s>  e Zd ZU dZdZee ed< dZe	ed< dZ
e	ed< dZe	ed< dZe	ed	< dZe	ed
< dZe	ed< dZe	ed< dZe	ed< dZeed< dZeeeef  ed< dZee ed< dZee ed< dZeed< dZeed< dZeed< dZeed< dZeeeeee	ef f ed< dd ZdefddZ e!defddZ"dS ) TrainerStatea  
    A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing
    and passed to the [`TrainerCallback`].

    <Tip>

    In all this class, one step is to be understood as one update step. When using gradient accumulation, one update
    step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
    step requires going through *n* batches.

    </Tip>

    Args:
        epoch (`float`, *optional*):
            Only set during training, will represent the epoch the training is at (the decimal part being the
            percentage of the current epoch completed).
        global_step (`int`, *optional*, defaults to 0):
            During training, represents the number of update steps completed.
        max_steps (`int`, *optional*, defaults to 0):
            The number of update steps to do during the current training.
        logging_steps (`int`, *optional*, defaults to 500):
            Log every X updates steps
        eval_steps (`int`, *optional*):
            Run an evaluation every X steps.
        save_steps (`int`, *optional*, defaults to 500):
            Save checkpoint every X updates steps.
        train_batch_size (`int`, *optional*):
            The batch size for the training dataloader. Only needed when
            `auto_find_batch_size` has been used.
        num_input_tokens_seen (`int`, *optional*, defaults to 0):
            The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
        total_flos (`float`, *optional*, defaults to 0):
            The total number of floating operations done by the model since the beginning of training (stored as floats
            to avoid overflow).
        log_history (`List[Dict[str, float]]`, *optional*):
            The list of logs done since the beginning of training.
        best_metric (`float`, *optional*):
            When tracking the best model, the value of the best metric encountered so far.
        best_model_checkpoint (`str`, *optional*):
            When tracking the best model, the value of the name of the checkpoint for the best model encountered so
            far.
        is_local_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
            several machines) main process.
        is_world_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the global main process (when training in a distributed fashion on several
            machines, this is only going to be `True` for one process).
        is_hyper_param_search (`bool`, *optional*, defaults to `False`):
            Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
            impact the way data will be logged in TensorBoard.
    Nepochr   global_step	max_stepsi  logging_steps
eval_steps
save_stepstrain_batch_sizenum_train_epochsnum_input_tokens_seen
total_floslog_historybest_metricbest_model_checkpointTis_local_process_zerois_world_process_zeroFis_hyper_param_search
trial_nametrial_paramsc                 C   s   | j d u r
g | _ d S d S N)r   self r#   S/var/www/html/ai/venv/lib/python3.10/site-packages/transformers/trainer_callback.py__post_init__k   s   

zTrainerState.__post_init__	json_pathc                 C   sX   t jt| dddd }t|ddd}|| W d   dS 1 s%w   Y  dS )	zDSave the content of this instance in JSON format inside `json_path`.   T)indent	sort_keys
wutf-8encodingN)jsondumpsdataclassesasdictopenwrite)r"   r&   json_stringfr#   r#   r$   save_to_jsono   s   "zTrainerState.save_to_jsonc                 C   sJ   t |ddd}| }W d   n1 sw   Y  | di t|S )z3Create an instance from the content of `json_path`.rr,   r-   Nr#   )r3   readr/   loads)clsr&   r6   textr#   r#   r$   load_from_jsonu   s   
zTrainerState.load_from_json)#__name__
__module____qualname____doc__r   r   float__annotations__r   intr   r   r   r   r   r   r   r   r   r   r   strr   r   r   boolr   r   r   r   r   r%   r7   classmethodr=   r#   r#   r#   r$   r   "   s0   
 4 r   c                   @   sf   e Zd ZU dZdZeed< dZeed< dZeed< dZ	eed< dZ
eed< dd	 Zd
d Zdd ZdS )TrainerControlaA  
    A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
    switches in the training loop.

    Args:
        should_training_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the training should be interrupted.

            If `True`, this variable will not be set back to `False`. The training will just stop.
        should_epoch_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the current epoch should be interrupted.

            If `True`, this variable will be set back to `False` at the beginning of the next epoch.
        should_save (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be saved at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_evaluate (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be evaluated at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_log (`bool`, *optional*, defaults to `False`):
            Whether or not the logs should be reported at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
    Fshould_training_stopshould_epoch_stopshould_saveshould_evaluate
should_logc                 C   
   d| _ dS )z<Internal method that resets the variable for a new training.FN)rI   r!   r#   r#   r$   _new_training      
zTrainerControl._new_trainingc                 C   rN   )z9Internal method that resets the variable for a new epoch.FN)rJ   r!   r#   r#   r$   
_new_epoch   rP   zTrainerControl._new_epochc                 C   s   d| _ d| _d| _dS )z8Internal method that resets the variable for a new step.FN)rK   rL   rM   r!   r#   r#   r$   	_new_step   s   
zTrainerControl._new_stepN)r>   r?   r@   rA   rI   rF   rC   rJ   rK   rL   rM   rO   rQ   rR   r#   r#   r#   r$   rH   }   s   
 rH   c                   @   s.  e Zd ZdZdededefddZdededefddZdededefd	d
Z	dededefddZ
dededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdS ) TrainerCallbacka  
    A class for objects that will inspect the state of the training loop at some events and take some decisions. At
    each of those events the following arguments are available:

    Args:
        args ([`TrainingArguments`]):
            The training arguments used to instantiate the [`Trainer`].
        state ([`TrainerState`]):
            The current state of the [`Trainer`].
        control ([`TrainerControl`]):
            The object that is returned to the [`Trainer`] and can be used to make some decisions.
        model ([`PreTrainedModel`] or `torch.nn.Module`):
            The model being trained.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for encoding the data.
        optimizer (`torch.optim.Optimizer`):
            The optimizer used for the training steps.
        lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
            The scheduler used for setting the learning rate.
        train_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for training.
        eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for training.
        metrics (`Dict[str, float]`):
            The metrics computed by the last evaluation phase.

            Those are only accessible in the event `on_evaluate`.
        logs  (`Dict[str, float]`):
            The values to log.

            Those are only accessible in the event `on_log`.

    The `control` object is the only one that can be changed by the callback, in which case the event that changes it
    should return the modified version.

    The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`.
    You can unpack the ones you need in the signature of the event using them. As an example, see the code of the
    simple [`~transformers.PrinterCallback`].

    Example:

    ```python
    class PrinterCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            _ = logs.pop("total_flos", None)
            if state.is_local_process_zero:
                print(logs)
    ```argsstatecontrolc                 K      dS )zS
        Event called at the end of the initialization of the [`Trainer`].
        Nr#   r"   rT   rU   rV   kwargsr#   r#   r$   on_init_end      zTrainerCallback.on_init_endc                 K   rW   )z<
        Event called at the beginning of training.
        Nr#   rX   r#   r#   r$   on_train_begin   r[   zTrainerCallback.on_train_beginc                 K   rW   )z6
        Event called at the end of training.
        Nr#   rX   r#   r#   r$   on_train_end   r[   zTrainerCallback.on_train_endc                 K   rW   )z<
        Event called at the beginning of an epoch.
        Nr#   rX   r#   r#   r$   on_epoch_begin   r[   zTrainerCallback.on_epoch_beginc                 K   rW   )z6
        Event called at the end of an epoch.
        Nr#   rX   r#   r#   r$   on_epoch_end   r[   zTrainerCallback.on_epoch_endc                 K   rW   )z
        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr#   rX   r#   r#   r$   on_step_begin      zTrainerCallback.on_step_beginc                 K   rW   )zU
        Event called at the end of an substep during gradient accumulation.
        Nr#   rX   r#   r#   r$   on_substep_end  r[   zTrainerCallback.on_substep_endc                 K   rW   )z
        Event called at the end of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr#   rX   r#   r#   r$   on_step_end  ra   zTrainerCallback.on_step_endc                 K   rW   )z9
        Event called after an evaluation phase.
        Nr#   rX   r#   r#   r$   on_evaluate  r[   zTrainerCallback.on_evaluatec                 K   rW   )z=
        Event called after a successful prediction.
        Nr#   )r"   rT   rU   rV   metricsrY   r#   r#   r$   
on_predict  r[   zTrainerCallback.on_predictc                 K   rW   )z7
        Event called after a checkpoint save.
        Nr#   rX   r#   r#   r$   on_save   r[   zTrainerCallback.on_savec                 K   rW   )z;
        Event called after logging the last logs.
        Nr#   rX   r#   r#   r$   on_log&  r[   zTrainerCallback.on_logc                 K   rW   )z7
        Event called after a prediction step.
        Nr#   rX   r#   r#   r$   on_prediction_step,  r[   z"TrainerCallback.on_prediction_stepN)r>   r?   r@   rA   r   r   rH   rZ   r\   r]   r^   r_   r`   rb   rc   rd   rf   rg   rh   ri   r#   r#   r#   r$   rS      s    1rS   c                   @   sb  e Zd ZdZdd Zdd Zdd Zdd	 Zed
d Z	de
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefdd Zde
dedefd!d"Zde
dedefd#d$Zde
dedefd%d&Zde
dedefd'd(Zd)d* Zd+S ),CallbackHandlerz>Internal class that just calls the list of callbacks in order.c                 C   sj   g | _ |D ]}| | q|| _|| _|| _|| _d | _d | _tdd | j D s3t	
d| j  d S d S )Nc                 s   s    | ]}t |tV  qd S r    )
isinstanceDefaultFlowCallback.0cbr#   r#   r$   	<genexpr>A  s    z+CallbackHandler.__init__.<locals>.<genexpr>zThe Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You
should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list ofcallbacks is
:)	callbacksadd_callbackmodel	tokenizer	optimizerlr_schedulertrain_dataloadereval_dataloaderanyloggerwarningcallback_list)r"   rq   rs   rt   ru   rv   ro   r#   r#   r$   __init__6  s    zCallbackHandler.__init__c                 C   sh   t |tr| n|}t |tr|n|j}|dd | jD v r,td| dd | j  | j| d S )Nc                 S   s   g | ]}|j qS r#   )	__class__)rn   cr#   r#   r$   
<listcomp>L  s    z0CallbackHandler.add_callback.<locals>.<listcomp>zYou are adding a zH to the callbacks of this Trainer, but there is already one. The currentzlist of callbacks is
:)rk   typer~   rq   rz   r{   r|   append)r"   callbackro   cb_classr#   r#   r$   rr   I  s   
zCallbackHandler.add_callbackc                 C   sd   t |tr| jD ]}t ||r| j| |  S qd S | jD ]}||kr/| j| |  S qd S r    rk   r   rq   remover"   r   ro   r#   r#   r$   pop_callbackT  s   



zCallbackHandler.pop_callbackc                 C   sF   t |tr| jD ]}t ||r| j|  d S qd S | j| d S r    r   r   r#   r#   r$   remove_callback`  s   


zCallbackHandler.remove_callbackc                 C   s   d dd | jD S )Nr*   c                 s   s    | ]}|j jV  qd S r    )r~   r>   rm   r#   r#   r$   rp   k  s    z0CallbackHandler.callback_list.<locals>.<genexpr>)joinrq   r!   r#   r#   r$   r|   i  s   zCallbackHandler.callback_listrT   rU   rV   c                 C      |  d|||S )NrZ   
call_eventr"   rT   rU   rV   r#   r#   r$   rZ   m     zCallbackHandler.on_init_endc                 C      d|_ | d|||S )NFr\   )rI   r   r   r#   r#   r$   r\   p     zCallbackHandler.on_train_beginc                 C   r   )Nr]   r   r   r#   r#   r$   r]   t  r   zCallbackHandler.on_train_endc                 C   r   )NFr^   )rJ   r   r   r#   r#   r$   r^   w  r   zCallbackHandler.on_epoch_beginc                 C   r   )Nr_   r   r   r#   r#   r$   r_   {  r   zCallbackHandler.on_epoch_endc                 C   s"   d|_ d|_d|_| d|||S )NFr`   )rM   rL   rK   r   r   r#   r#   r$   r`   ~  s   zCallbackHandler.on_step_beginc                 C   r   )Nrb   r   r   r#   r#   r$   rb     r   zCallbackHandler.on_substep_endc                 C   r   )Nrc   r   r   r#   r#   r$   rc     r   zCallbackHandler.on_step_endc                 C      d|_ | jd||||dS )NFrd   re   )rL   r   r"   rT   rU   rV   re   r#   r#   r$   rd        zCallbackHandler.on_evaluatec                 C   s   | j d||||dS )Nrf   r   r   r   r#   r#   r$   rf     s   zCallbackHandler.on_predictc                 C   r   )NFrg   )rK   r   r   r#   r#   r$   rg     r   zCallbackHandler.on_savec                 C   r   )NFrh   )logs)rM   r   )r"   rT   rU   rV   r   r#   r#   r$   rh     r   zCallbackHandler.on_logc                 C   r   )Nri   r   r   r#   r#   r$   ri     r   z"CallbackHandler.on_prediction_stepc              
   K   sP   | j D ]"}t|||||f| j| j| j| j| j| jd|}|d ur%|}q|S )N)rs   rt   ru   rv   rw   rx   )rq   getattrrs   rt   ru   rv   rw   rx   )r"   eventrT   rU   rV   rY   r   resultr#   r#   r$   r     s&   

zCallbackHandler.call_eventN)r>   r?   r@   rA   r}   rr   r   r   propertyr|   r   r   rH   rZ   r\   r]   r^   r_   r`   rb   rc   rd   rf   rg   rh   ri   r   r#   r#   r#   r$   rj   3  s,    	
rj   c                   @   s<   e Zd ZdZdededefddZdededefddZd	S )
rl   zx
    A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
    rT   rU   rV   c                 K   s   |j dkr|jrd|_|jtjkr|j |j dkrd|_|jtjkr3|j |j dkr3|j	|j kr3d|_
|jtjkrI|jdkrI|j |j dkrId|_|j |jkrRd|_|S )Nr   Tr   )r   logging_first_steprM   logging_strategyr	   STEPSr   evaluation_strategyr   
eval_delayrL   save_strategyr   rK   r   rI   rX   r#   r#   r$   rc     s   
zDefaultFlowCallback.on_step_endc                 K   sF   |j tjkr	d|_|jtjkr|j|jkrd|_|jtjkr!d|_	|S )NT)
r   r	   EPOCHrM   r   r   r   rL   r   rK   rX   r#   r#   r$   r_     s   z DefaultFlowCallback.on_epoch_endN)	r>   r?   r@   rA   r   r   rH   rc   r_   r#   r#   r#   r$   rl     s    rl   c                   @   sT   e Zd ZdZdd Zdd Zdd Zdd	d
Zdd Zdd Z	dddZ
dd ZdS )ProgressCallbackzU
    A [`TrainerCallback`] that displays the progress of training or evaluation.
    c                 C   s   d | _ d | _d S r    )training_barprediction_barr!   r#   r#   r$   r}     s   
zProgressCallback.__init__c                 K   s    |j rt|jdd| _d| _d S )NT)totaldynamic_ncolsr   )r   r   r   r   current_steprX   r#   r#   r$   r\     s   
zProgressCallback.on_train_beginc                 K   s*   |j r| j|j| j  |j| _d S d S r    )r   r   updater   r   rX   r#   r#   r$   rc     s   zProgressCallback.on_step_endNc                 K   sJ   |j r!t|r#| jd u rtt|| jd u dd| _| jd d S d S d S )NT)r   leaver   r   )r   r
   r   r   lenr   r   )r"   rT   rU   rV   rx   rY   r#   r#   r$   ri     s   
z#ProgressCallback.on_prediction_stepc                 K   (   |j r| jd ur| j  d | _d S d S r    r   r   closerX   r#   r#   r$   rd     
   


zProgressCallback.on_evaluatec                 K   r   r    r   rX   r#   r#   r$   rf     r   zProgressCallback.on_predictc                 K   s8   |j r| jd ur|dd }| jt| d S d S d S Nr   )r   r   popr4   rE   r"   rT   rU   rV   r   rY   _r#   r#   r$   rh   	  s   zProgressCallback.on_logc                 K   s   |j r| j  d | _d S d S r    )r   r   r   rX   r#   r#   r$   r]     s   

zProgressCallback.on_train_endr    )r>   r?   r@   rA   r}   r\   rc   ri   rd   rf   rh   r]   r#   r#   r#   r$   r     s    

r   c                   @   s   e Zd ZdZdddZdS )PrinterCallbackz?
    A bare [`TrainerCallback`] that just prints the logs.
    Nc                 K   s"   | dd }|jrt| d S d S r   )r   r   printr   r#   r#   r$   rh     s   zPrinterCallback.on_logr    )r>   r?   r@   rA   rh   r#   r#   r#   r$   r     s    r   c                   @   s@   e Zd ZdZddedee fddZdd	 Zd
d Z	dd Z
dS )EarlyStoppingCallbacka1  
    A [`TrainerCallback`] that handles early stopping.

    Args:
        early_stopping_patience (`int`):
            Use with `metric_for_best_model` to stop training when the specified metric worsens for
            `early_stopping_patience` evaluation calls.
        early_stopping_threshold(`float`, *optional*):
            Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
            specified metric must improve to satisfy early stopping conditions. `

    This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric
    in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the
    early stopping will not occur until the next save step.
    r           early_stopping_patienceearly_stopping_thresholdc                 C   s   || _ || _d| _d S )Nr   )r   r   early_stopping_patience_counter)r"   r   r   r#   r#   r$   r}   0  s   
zEarlyStoppingCallback.__init__c                 C   sX   |j rtjntj}|jd u s|||jr#t||j | jkr#d| _d S |  jd7  _d S )Nr   r   )greater_is_betternpgreaterlessr   absr   r   )r"   rT   rU   rV   metric_valueoperatorr#   r#   r$   check_metric_value6  s   


z(EarlyStoppingCallback.check_metric_valuec                 K   s8   |j sJ d|jd usJ d|jtjksJ dd S )Nz<EarlyStoppingCallback requires load_best_model_at_end = Truez?EarlyStoppingCallback requires metric_for_best_model is definedzAEarlyStoppingCallback requires IntervalStrategy of steps or epoch)load_best_model_at_endmetric_for_best_modelr   r	   NOrX   r#   r#   r$   r\   A  s   z$EarlyStoppingCallback.on_train_beginc                 K   sl   |j }|dsd| }||}|d u r!td| d d S | |||| | j| jkr4d|_d S d S )Neval_z@early stopping required metric_for_best_model, but did not find z so early stopping is disabledT)	r   
startswithgetrz   r{   r   r   r   rI   )r"   rT   rU   rV   re   rY   metric_to_checkr   r#   r#   r$   rd   J  s   




z!EarlyStoppingCallback.on_evaluateN)r   r   )r>   r?   r@   rA   rD   r   rB   r}   r   r\   rd   r#   r#   r#   r$   r     s    	r   )rA   r1   r/   r   typingr   r   r   r   numpyr   	tqdm.autor   trainer_utilsr	   r
   training_argsr   utilsr   
get_loggerr>   rz   r   rH   rS   rj   rl   r   r   r   r#   r#   r#   r$   <module>   s,   
Z1 }22