o
    hu                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlmZm	Z	m
Z
mZmZ d dlmZmZmZmZ d dlmZmZmZ ddlmZmZmZ ddlmZ ddlmZ dd	lmZm Z m!Z!m"Z"m#Z#m$Z$ d
dl%m&Z&m'Z' e$(e)Z*e" rzd dl+Z+e! rd dl,m-Z- dZ.d2ddZ/dZ0G dd dZ1G dd de1Z2G dd de1Z3de1fddZ4dd Z5ddddd d!d"d#d$d%d&
Z6d'd( Z7d)d* Z8d3d,d-Z9d.d/ Z:G d0d1 d1Z;dS )4    N)AnyDictListOptionalUnion)create_repohf_hub_downloadmetadata_updateupload_folder)RepositoryNotFoundErrorbuild_hf_headersget_session   )custom_object_saveget_class_from_dynamic_moduleget_imports)is_pil_image)AutoProcessor)CONFIG_NAMEcached_fileis_accelerate_availableis_torch_availableis_vision_availablelogging   )handle_agent_inputshandle_agent_outputs)send_to_devicetool_config.jsonc                 K   s   |d ur|S zt | tfddi| W dS  tyD   zt | tfddi| W Y dS  ty9   td|  d tyC   Y Y dS w  tyM   Y dS w )N	repo_typespacemodel`z9` does not seem to be a valid repo identifier on the Hub.)r   TOOL_CONFIG_FILEr   EnvironmentError	Exception)repo_idr   
hub_kwargs r(   M/var/www/html/ai/venv/lib/python3.10/site-packages/transformers/tools/base.pyget_repo_type7   s"   r*   zufrom transformers import launch_gradio_demo
from {module_name} import {class_name}

launch_gradio_demo({class_name})
c                   @   s   e Zd ZU dZdZeed< dZeed< ee ed< ee ed< dd	 Z	d
d Z
dd Zdd Ze			d!dedee dee defddZ				d"dededee deeeef  dedefddZedd  ZdS )#Toola  
    A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
    following class attributes:

    - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
      will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
      returns the text contained in the file'.
    - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
      `"text-classifier"` or `"image_generator"`.
    - **inputs** (`List[str]`) -- The list of modalities expected for the inputs (in the same order as in the call).
      Modalitiies should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo` or to make a
      nice space from your tool.
    - **outputs** (`List[str]`) -- The list of modalities returned but the tool (in the same order as the return of the
      call method). Modalitiies should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo`
      or to make a nice space from your tool.

    You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being
    usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
    instantiation.
    zThis is a tool that ...description nameinputsoutputsc                 O   s
   d| _ d S )NFis_initializedselfargskwargsr(   r(   r)   __init__m   s   
zTool.__init__c                 O   s   t dS )Nz-Write this method in your subclass of `Tool`.)NotImplementedr3   r(   r(   r)   __call__p   s   zTool.__call__c                 C   s
   d| _ dS )z
        Overwrite this method here for any operation that is expensive and needs to be executed before you start using
        your tool. Such as loading a big model.
        TNr1   r4   r(   r(   r)   setups   s   
z
Tool.setupc                 C   s  t j|dd | jdkrtd|  d| dt| |}| jj}|dd }| d| jj }t j	|d	}t j
|rZt|d
dd}t|}W d   n1 sTw   Y  ni }|| j| jd}t|ddd}|tj|dddd  W d   n1 sw   Y  t j	|d}	t|	ddd}|tj|| jjd W d   n1 sw   Y  t j	|d}
g }|D ]	}|t| qtt|}t|
ddd}|d	|d  W d   dS 1 sw   Y  dS )a  
        Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
        tool in `output_dir` as well as autogenerate:

        - a config file named `tool_config.json`
        - an `app.py` file so that your tool can be converted to a space
        - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
          code)

        You should only use this method to save tools that are defined in a separate module (not `__main__`).

        Args:
            output_dir (`str`): The folder in which you want to save your tool.
        T)exist_ok__main__z We can't save the code defining z in z{ as it's been defined in __main__. You have to put this code in a separate module so we can include it in the saved folder..r   rutf-8encodingN)
tool_classr,   r.   wr   )indent	sort_keys
zapp.py)module_name
class_namezrequirements.txt)osmakedirs
__module__
ValueErrorr   	__class__split__name__pathjoinisfileopenjsonloadr,   r.   writedumpsAPP_FILE_TEMPLATEformatextendr   listset)r4   
output_dirmodule_filesrI   last_module	full_nameconfig_fileftool_configapp_filerequirements_fileimportsmoduler(   r(   r)   savez   s@   

"z	Tool.saveNFr&   model_repo_idtokenremotec                    s  |r|du rt  }||vrtd| d|| }g d  fdd| D }t|fi ||d< t|tfd|i|d	d	d
}|du}	|du rZt|tfd|i|d	d	d
}|du ret| dt|dd}
t	
|
}W d   n1 s{w   Y  |	sd|vrt| d|d }n|}|d }t||fd|i|}t|jdkr|d |_|j|d krt|j d |d |_t|jdkr|d |_|j|d krt|j d |d |_|rt|||dS ||fd|i|S )aw  
        Loads a tool defined on the Hub.

        Args:
            repo_id (`str`):
                The name of the repo on the Hub where your tool is defined.
            model_repo_id (`str`, *optional*):
                If your tool uses a model and you want to use a different model than the default, you can pass a second
                repo ID or an endpoint url to this argument.
            token (`str`, *optional*):
                The token to identify you on hf.co. If unset, will use the token generated when running
                `huggingface-cli login` (stored in `~/.huggingface`).
            remote (`bool`, *optional*, defaults to `False`):
                Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
            kwargs (additional keyword arguments, *optional*):
                Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
                `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
                others will be passed along to its init.
        N'Could not infer a default endpoint for :, you need to pass one using the `model_repo_id` argument.)	cache_dirforce_downloadresume_downloadproxiesrevisionr   	subfolderlocal_files_onlyc                    s   i | ]\}}| v r||qS r(   r(   ).0kvhub_kwargs_namesr(   r)   
<dictcomp>   s    z!Tool.from_hub.<locals>.<dictcomp>r   rl   F)%_raise_exceptions_for_missing_entries'_raise_exceptions_for_connection_errorszY does not appear to provide a valid configuration in `tool_config.json` or `config.json`.rA   rB   custom_toolzO does not provide a mapping to custom tools in its configuration `config.json`.rD   r   r.   z_ implements a different name in its configuration and class. Using the tool configuration name.r,   zm implements a different description in its configuration and class. Using the tool configuration description.rl   rD   )get_default_endpointsrN   itemsr*   r   r#   r   r$   rU   rV   rW   r   lenr.   loggerwarningrQ   r,   
RemoteTool)clsr&   rk   rl   rm   r6   	endpointsr'   resolved_config_fileis_tool_configreaderconfigr   rD   r(   rz   r)   from_hub   s   










zTool.from_hubUpload toolcommit_messageprivate	create_prreturnc                 C   s   t |||dddd}|j}t|ddgidd t (}| | td| d	d
t	
|  t|||||ddW  d   S 1 sFw   Y  dS )a  
        Upload the tool to the Hub.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your tool to. It should contain your organization name when
                pushing to a given organization.
            commit_message (`str`, *optional*, defaults to `"Upload tool"`):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether or not the repository created should be private.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.
        Tr    gradio)r&   rl   r   r<   r   	space_sdktagstoolr   z!Uploading the following files to z: ,)r&   r   folder_pathrl   r   r   N)r   r&   r	   tempfileTemporaryDirectoryrj   r   inforS   rK   listdirr
   )r4   r&   r   r   rl   r   repo_urlwork_dirr(   r(   r)   push_to_hub  s"   

"$zTool.push_to_hubc                 C   s    G dd dt }| j|_|| S )z8
        Creates a [`Tool`] from a gradio tool.
        c                       s   e Zd Z fddZ  ZS )z+Tool.from_gradio.<locals>.GradioToolWrapperc                    s   t    |j| _|j| _d S N)superr7   r.   r,   )r4   _gradio_toolrO   r(   r)   r7   Q  s   
z4Tool.from_gradio.<locals>.GradioToolWrapper.__init__)rQ   rM   __qualname__r7   __classcell__r(   r(   r   r)   GradioToolWrapperP  s    r   )r+   runr9   )gradio_toolr   r(   r(   r)   from_gradioJ  s   zTool.from_gradio)NNF)r   NNF)rQ   rM   r   __doc__r,   str__annotations__r.   r   r7   r9   r;   rj   classmethodr   boolr   r   r   staticmethodr   r(   r(   r(   r)   r+   Q   sT   
 6p
,r+   c                   @   s2   e Zd ZdZdddZdd Zdd Zd	d
 ZdS )r   at  
    A [`Tool`] that will make requests to an inference endpoint.

    Args:
        endpoint_url (`str`, *optional*):
            The url of the endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        tool_class (`type`, *optional*):
            The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
            the output should be converted to another type (like images).
    Nc                 C   s   || _ t||d| _|| _d S )Nrl   )endpoint_urlEndpointClientclientrD   )r4   r   rl   rD   r(   r(   r)   r7   i  s   
zRemoteTool.__init__c                 O   sB  |  }t|dkr| jdurbt| jtr| jj}n| jj}t|j	}dd |
 D }|d dkr9|dd }t|t|krSt| j dt| dt| d	t||D ]\}}|||< qXn&t|dkrltd
t|dkrt|d rd| j|d iS d|d iS |
 D ]\}	}
t|
r| j|
||	< qd|iS )aP  
        Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
        matched with the signature of the `tool_class` if it was provided at instantation. Images will be encoded into
        bytes.

        You can override this method in your custom class of [`RemoteTool`].
        r   Nc                 S   s*   g | ]\}}|j tjjtjjfvr|qS r(   )kindinspect_ParameterKindVAR_POSITIONALVAR_KEYWORD)rw   rx   pr(   r(   r)   
<listcomp>  s
    z-RemoteTool.prepare_inputs.<locals>.<listcomp>r4   r   z only accepts z arguments but z were given.z4A `RemoteTool` can only accept one positional input.r/   )copyr   rD   
issubclassPipelineToolencoder9   r   	signature
parametersr   rN   zipr   r   encode_image)r4   r5   r6   r/   call_methodr   r   argr.   keyvaluer(   r(   r)   prepare_inputsn  s<   


zRemoteTool.prepare_inputsc                 C   s   |S )z
        You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
        outputs of the endpoint.
        r(   r4   r0   r(   r(   r)   extract_outputs  s   zRemoteTool.extract_outputsc                 O   s   t |i |\}}| jd uo| jjdgk}| j|i |}t|tr/| jdi |d|i}n| j||d}t|trLt|dkrLt|d trL|d }t	|| jd urW| jjnd }| 
|S )Nimageoutput_image)r   r   r   r(   )r   rD   r0   r   
isinstancedictr   r]   r   r   r   )r4   r5   r6   r   r/   r0   r(   r(   r)   r9     s   
$
zRemoteTool.__call__)NNN)rQ   rM   r   r   r7   r   r   r9   r(   r(   r(   r)   r   Z  s    
+r   c                       sl   e Zd ZdZeZdZeZdZ							d fdd	Z	 fddZ
dd Zd	d
 Zdd Zdd Z  ZS )r   a0	  
    A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
    need to specify:

    - **model_class** (`type`) -- The class to use to load the model in this tool.
    - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      pre-processor
    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      post-processor (when different from the pre-processor).

    Args:
        model (`str` or [`PreTrainedModel`], *optional*):
            The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
            value of the class attribute `default_checkpoint`.
        pre_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
            unset.
        post_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
            unset.
        device (`int`, `str` or `torch.device`, *optional*):
            The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
            CPU otherwise.
        device_map (`str` or `dict`, *optional*):
            If passed along, will be used to instantiate the model.
        model_kwargs (`dict`, *optional*):
            Any keyword argument to send to the model instantiation.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        hub_kwargs (additional keyword arguments, *optional*):
            Any additional keyword argument to send to the methods that will load the data from the Hub.
    Nc           	         s   t  stdt std|d u r| jd u rtd| j}|d u r$|}|| _|| _|| _|| _|| _	|d u r9i n|| _
|d urE|| j
d< || _|| jd< t   d S )N/Please install torch in order to use this tool.z4Please install accelerate in order to use this tool.zHThis tool does not implement a default checkpoint, you need to pass one.
device_maprl   )r   ImportErrorr   default_checkpointrN   r!   pre_processorpost_processordevicer   model_kwargsr'   r   r7   )	r4   r!   r   r   r   r   r   rl   r'   r   r(   r)   r7     s*   


zPipelineTool.__init__c                    s   t | jtr| jj| jfi | j| _t | jtr)| jj| jfi | j| j| _| j	du r3| j| _	nt | j	trF| j
j| j	fi | j| _	| jdu r`| jdur\t| jj d | _nt | _| jdu rl| j| j t   dS )z^
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
        Nr   )r   r   r   pre_processor_classfrom_pretrainedr'   r!   model_classr   r   post_processor_classr   r   r]   hf_device_mapvaluesget_default_devicetor   r;   r:   r   r(   r)   r;     s    




zPipelineTool.setupc                 C   
   |  |S )zQ
        Uses the `pre_processor` to prepare the inputs for the `model`.
        )r   )r4   
raw_inputsr(   r(   r)   r        
zPipelineTool.encodec                 C   s<   t   | jdi |W  d   S 1 sw   Y  dS )z7
        Sends the inputs through the `model`.
        Nr(   )torchno_gradr!   )r4   r/   r(   r(   r)   forward!  s   
$zPipelineTool.forwardc                 C   r   )zG
        Uses the `post_processor` to decode the model output.
        )r   r   r(   r(   r)   decode(  r   zPipelineTool.decodec                 O   sf   t |i |\}}| js|   | j|i |}t|| j}| |}t|d}| |}t|| j	S )Ncpu)
r   r2   r;   r   r   r   r   r   r   r0   )r4   r5   r6   encoded_inputsr0   decoded_outputsr(   r(   r)   r9   .  s   


zPipelineTool.__call__)NNNNNNN)rQ   rM   r   r   r   r   r   r   r   r7   r;   r   r   r   r9   r   r(   r(   r   r)   r     s&    %%r   rD   c                    sZ   zddl }W n ty   tdw |    fdd}|j|| j| j| j jd  dS )z
    Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
    `inputs` and `outputs`.

    Args:
        tool_class (`type`): The class of the tool for which to launch the demo.
    r   Nz<Gradio should be installed in order to launch a gradio demo.c                     s    | i |S r   r(   )r5   r6   r   r(   r)   fnL  s   zlaunch_gradio_demo.<locals>.fn)r   r/   r0   titlearticle)r   r   	Interfacer/   r0   rQ   r,   launch)rD   grr   r(   r   r)   launch_gradio_demo=  s   
r   c                   C   sX   t d t stdtjj rtjj rt	dS tj
 r't	dS t	dS )Nu   `get_default_device` is deprecated and will be replaced with `accelerate`'s `PartialState().default_device` in version 4.38 of 🤗 Transformers. r   mpscudar   )r   r   r   r   r   backendsr   is_availableis_builtr   r   r(   r(   r(   r)   r   Y  s   



r   DocumentQuestionAnsweringToolImageCaptioningToolImageQuestionAnsweringToolImageSegmentationToolSpeechToTextToolTextSummarizationToolTextClassificationToolTextQuestionAnsweringToolTextToSpeechToolTranslationTool)
zdocument-question-answeringzimage-captioningzimage-question-answeringzimage-segmentationzspeech-to-textsummarizationztext-classificationztext-question-answeringztext-to-speechtranslationc                  C   sL   t dddd} t| ddd}t|}W d    |S 1 sw   Y  |S )Nz#huggingface-tools/default-endpointszdefault_endpoints.jsondatasetr   r@   rA   rB   )r   rU   rV   rW   )endpoints_filerd   r   r(   r(   r)   r   w  s   
r   c                 C   s   t  }| |v S r   )r   )task_or_repo_idr   r(   r(   r)   supports_remote~  s   r  Fc           
      K   s   | t v r?t |  }td}|j}t||}|r5|du r.t }	| |	vr*td|  d|	|  }t|||dS ||fd|i|S tj	| f|||d|S )a  
    Main function to quickly load a tool, be it on the Hub or in the Transformers library.

    Args:
        task_or_repo_id (`str`):
            The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
            are:

            - `"document-question-answering"`
            - `"image-captioning"`
            - `"image-question-answering"`
            - `"image-segmentation"`
            - `"speech-to-text"`
            - `"summarization"`
            - `"text-classification"`
            - `"text-question-answering"`
            - `"text-to-speech"`
            - `"translation"`

        model_repo_id (`str`, *optional*):
            Use this argument to use a different model than the default one for the tool you selected.
        remote (`bool`, *optional*, defaults to `False`):
            Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
        token (`str`, *optional*):
            The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
            login` (stored in `~/.huggingface`).
        kwargs (additional keyword arguments, *optional*):
            Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
            `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
            will be passed along to its init.
    transformersNrn   ro   r   rl   )rk   rl   rm   )
TASK_MAPPING	importlibimport_moduletoolsgetattrr   rN   r   r+   r   )
r  rk   rm   rl   r6   tool_class_namemain_moduletools_modulerD   r   r(   r(   r)   	load_tool  s     


r  c                    s    fdd}|S )z<
    A decorator that adds a description to a function.
    c                    s    | _ | j| _| S r   )r,   rQ   r.   )funcr,   r(   r)   inner  s   zadd_description.<locals>.innerr(   )r,   r  r(   r  r)   add_description  s   r  c                   @   s   e Zd Zddedee fddZedd Zedd	 Z				
ddee	ee
ee eee  f  dee
 dee dedef
ddZdS )r   Nr   rl   c                 C   s"   i t |dddi| _|| _d S )Nr   zContent-Typezapplication/json)r   headersr   )r4   r   rl   r(   r(   r)   r7     s   
zEndpointClient.__init__c                 C   s.   t  }| j|dd t| }|dS )NPNG)r[   rA   )ioBytesIOrj   base64	b64encodegetvaluer   )r   _bytesb64r(   r(   r)   r     s   
zEndpointClient.encode_imagec                 C   s8   t  stdddlm} t| }t|}||S )NzbThis tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`).r   )Image)	r   r   PILr  r  	b64decoder  r  rU   )	raw_imager  r  r  r(   r(   r)   decode_image  s   


zEndpointClient.decode_imageFr/   paramsdatar   r   c                 C   sL   i }|r||d< |r||d< t  j| j| j||d}|r"| |jS | S )Nr/   r   )r  rV   r#  )r   postr   r  r!  contentrV   )r4   r/   r"  r#  r   payloadresponser(   r(   r)   r9     s   zEndpointClient.__call__r   )NNNF)rQ   rM   r   r   r   r7   r   r   r!  r   r   r   bytesr   r   r9   r(   r(   r(   r)   r     s*    

r   r   )NFN)<r  r  r   r  rV   rK   r   typingr   r   r   r   r   huggingface_hubr   r   r	   r
   huggingface_hub.utilsr   r   r   dynamic_module_utilsr   r   r   image_utilsr   models.autor   utilsr   r   r   r   r   r   agent_typesr   r   
get_loggerrQ   r   r   accelerate.utilsr   r#   r*   rZ   r+   r   r   r   r   r  r   r  r  r  r   r(   r(   r(   r)   <module>   s^    

  W 
6