import decimal

import numpy as np
import torch
from torch import nn
from torch.autograd import Function

from ...utils import logging


logger = logging.get_logger(__name__)


class QuantEmbedding(nn.Module):
    """
    Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        weight_bit=8,
        momentum=0.95,
        quant_mode=False,
    ):
        super().__init__()
        self.num_ = num_embeddings
        self.dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
        self.register_buffer("weight_scaling_factor", torch.zeros(1))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))

        self.weight_bit = weight_bit
        self.momentum = momentum
        self.quant_mode = quant_mode
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def forward(self, x, positions=None, incremental_state=None):
        if not self.quant_mode:
            return (
                nn.functional.embedding(
                    x,
                    self.weight,
                    self.padding_idx,
                    self.max_norm,
                    self.norm_type,
                    self.scale_grad_by_freq,
                    self.sparse,
                ),
                None,
            )

        w = self.weight
        w_transform = w.data.detach()
        w_min = w_transform.min().expand(1)
        w_max = w_transform.max().expand(1)

        self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
        )

        emb_int = nn.functional.embedding(
            x,
            self.weight_integer,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
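

# Illustrative usage sketch added for exposition; it is not part of the original module,
# and the `_example_*` names and tensor sizes are hypothetical. QuantEmbedding behaves
# like nn.Embedding but also returns the scaling factor of its output, so downstream
# quantized layers can recover the integer representation.
def _example_quant_embedding():
    emb = QuantEmbedding(100, 16, weight_bit=8, quant_mode=True)
    emb.weight.data.normal_()                  # stand-in for pretrained weights
    input_ids = torch.randint(0, 100, (2, 8))
    out, scale = emb(input_ids)                # out == emb.weight_integer[input_ids] * scale
    return out, scale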


class QuantAct(nn.Module):
    """
    Quantizes the given activation.

    Args:
        activation_bit (`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        channel_len (`int`, *optional*):
            Specify the channel length when *per_channel* is set to `True`.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
        super().__init__()

        self.activation_bit = activation_bit
        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.percentile = False
        self.act_function = SymmetricQuantFunction.apply

        if not self.per_channel:
            self.register_buffer("x_min", torch.zeros(1))
            self.register_buffer("x_max", torch.zeros(1))
            self.register_buffer("act_scaling_factor", torch.zeros(1))
            self.x_min -= 1e-5
            self.x_max += 1e-5
        else:
            raise NotImplementedError("per-channel mode is not currently supported for activation.")

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
            f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
            f"Act_max: {self.x_max.item():.2f})"
        )

    def forward(
        self,
        x,
        pre_act_scaling_factor=None,
        identity=None,
        identity_scaling_factor=None,
        specified_min=None,
        specified_max=None,
    ):
        x_act = x if identity is None else identity + x
        # collect running stats if training
        if self.training:
            assert not self.percentile, "percentile mode is not currently supported for activation."
            assert not self.per_channel, "per-channel mode is not currently supported for activation."
            x_min = x_act.data.min()
            x_max = x_act.data.max()

            assert (
                x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0
            ), "NaN detected when computing min/max of the activation"

            # Initialization
            if self.x_min.min() > -1.1e-7 and self.x_max.max() < 1.1e-7:
                self.x_min = self.x_min + x_min
                self.x_max = self.x_max + x_max

            # exponential moving average (EMA)
            # use momentum to prevent the quantized values from changing greatly every iteration
            elif self.act_range_momentum == -1:
                self.x_min = torch.min(self.x_min, x_min)
                self.x_max = torch.max(self.x_max, x_max)
            else:
                self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
                self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        if not self.quant_mode:
            return x_act, None

        x_min = self.x_min if specified_min is None else specified_min
        x_max = self.x_max if specified_max is None else specified_max

        self.act_scaling_factor = symmetric_linear_quantization_params(
            self.activation_bit, x_min, x_max, per_channel=self.per_channel
        )

        if pre_act_scaling_factor is None:
            # this is for the input quantization
            quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
        else:
            quant_act_int = FixedPointMul.apply(
                x,
                pre_act_scaling_factor,
                self.activation_bit,
                self.act_scaling_factor,
                identity,
                identity_scaling_factor,
            )

        correct_output_scale = self.act_scaling_factor.view(-1)

        return quant_act_int * correct_output_scale, self.act_scaling_factor


class QuantLinear(nn.Module):
    """
    Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        bias_bit (`int`, *optional*, defaults to `32`):
            Bitwidth for the quantized bias.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))
        self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
            self.register_buffer("bias_integer", torch.zeros_like(self.bias))

        self.weight_bit = weight_bit
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.bias_bit = bias_bit
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def __repr__(self):
        s = super().__repr__()
        s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})"
        return s

    def forward(self, x, prev_act_scaling_factor=None):
        if not self.quant_mode:
            return nn.functional.linear(x, weight=self.weight, bias=self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor,
        # i.e. the input activation shares one global quantization range
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer"
        )

        w = self.weight
        w_transform = w.data.detach()
        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
            bias_scaling_factor,
        )


class IntGELU(nn.Module):
    """
    Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`.

    Args:
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    """

    def __init__(self, quant_mode=True, force_dequant="none"):
        super().__init__()
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "gelu"]:
            logger.info("Force dequantize gelu")
            self.quant_mode = False

        if not self.quant_mode:
            self.activation_fn = nn.GELU()

        self.k = 1.4142
        self.const = 14  # dummy integer constant
        self.coeff = [-0.2888, -1.769, 1]  # a(x+b)**2 + c
        self.coeff[2] /= self.coeff[0]

    def int_erf(self, x_int, scaling_factor):
        b_int = torch.floor(self.coeff[1] / scaling_factor)
        c_int = torch.floor(self.coeff[2] / scaling_factor**2)
        sign = torch.sign(x_int)

        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
        scaling_factor = scaling_factor**2 * self.coeff[0]

        # avoid overflow
        y_int = floor_ste.apply(y_int / 2**self.const)
        scaling_factor = scaling_factor * 2**self.const

        return y_int, scaling_factor

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            return self.activation_fn(x), None

        x_int = x / scaling_factor
        sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)

        shift_int = 1.0 // sigmoid_scaling_factor

        x_int = x_int * (sigmoid_int + shift_int)
        scaling_factor = scaling_factor * sigmoid_scaling_factor / 2

        return x_int * scaling_factor, scaling_factor
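

# Illustrative usage sketch (not part of the original module; names are hypothetical).
# IntGELU takes an activation together with its scaling factor, e.g. from QuantAct,
# and approximates erf with the integer-only second-order polynomial used by I-BERT.
def _example_int_gelu():
    qact = QuantAct(activation_bit=8, quant_mode=True)
    gelu = IntGELU(quant_mode=True)
    x_q, scale = qact(torch.randn(2, 4))
    y, y_scale = gelu(x_q, scale)              # close to nn.functional.gelu(x_q)
    return y, y_scale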


class IntSoftmax(nn.Module):
    """
    Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`.

    Args:
        output_bit (`int`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    """

    def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.output_bit = output_bit
        self.max_bit = 32
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "softmax"]:
            logger.info("Force dequantize softmax")
            self.quant_mode = False

        self.act = QuantAct(16, quant_mode=self.quant_mode)
        self.x0 = -0.6931  # -ln2
        self.const = 30  # dummy integer constant
        self.coef = [0.35815147, 0.96963238, 1.0]  # ax**2 + bx + c
        self.coef[1] /= self.coef[0]
        self.coef[2] /= self.coef[0]

    def int_polynomial(self, x_int, scaling_factor):
        with torch.no_grad():
            b_int = torch.floor(self.coef[1] / scaling_factor)
            c_int = torch.floor(self.coef[2] / scaling_factor**2)
        z = (x_int + b_int) * x_int + c_int
        scaling_factor = self.coef[0] * scaling_factor**2
        return z, scaling_factor

    def int_exp(self, x_int, scaling_factor):
        with torch.no_grad():
            x0_int = torch.floor(self.x0 / scaling_factor)
        x_int = torch.max(x_int, self.const * x0_int)

        q = floor_ste.apply(x_int / x0_int)
        r = x_int - x0_int * q
        exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
        scaling_factor = exp_scaling_factor / 2**self.const
        return exp_int, scaling_factor

    def forward(self, x, scaling_factor):
        if not self.quant_mode:
            return nn.functional.softmax(x, dim=-1), None

        x_int = x / scaling_factor

        x_int_max, _ = x_int.max(dim=-1, keepdim=True)
        x_int = x_int - x_int_max
        exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)

        # avoid overflow
        exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
        exp_int = exp / exp_scaling_factor

        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
        factor = floor_ste.apply(2**self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
        scaling_factor = 1 / 2**self.output_bit
        return exp_int * scaling_factor, scaling_factor
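

# Illustrative usage sketch (not part of the original module; names are hypothetical).
# int_exp decomposes exp(x) as exp(r) * 2**(-q) with x = q * (-ln 2) + r, so the
# polynomial only has to cover the small interval of r; rows still sum to ~1.
def _example_int_softmax():
    qact = QuantAct(activation_bit=8, quant_mode=True)
    softmax = IntSoftmax(output_bit=8, quant_mode=True)
    scores_q, scale = qact(torch.randn(2, 5))
    probs, probs_scale = softmax(scores_q, scale)
    return probs.sum(dim=-1)                   # approximately torch.ones(2)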


class IntLayerNorm(nn.Module):
    """
    Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`.

    Args:
        output_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    """

    def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        self.weight = nn.Parameter(torch.zeros(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.quant_mode = quant_mode
        if force_dequant in ["nonlinear", "layernorm"]:
            logger.info("Force dequantize layernorm")
            self.quant_mode = False

        self.register_buffer("shift", torch.zeros(1))
        self.output_bit = output_bit
        self.max_bit = 32
        self.dim_sqrt = None
        self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)

    def set_shift(self, y_int):
        with torch.no_grad():
            y_sq_int = y_int**2
            var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
            shift = (torch.log2(torch.sqrt(var_int / 2**self.max_bit)).ceil()).max()
            shift_old = self.shift
            self.shift = torch.max(self.shift, shift)
            logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}")

    def overflow_fallback(self, y_int):
        """
        This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
        to avoid overflow in the subsequent runs.
        """
        self.set_shift(y_int)  # adjusts `self.shift`
        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
        return var_int

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            mean = x.mean(axis=2, keepdim=True)
            y = x - mean
            var = torch.mean(y**2, axis=2, keepdim=True)
            x = y / torch.sqrt(self.eps + var)
            x = x * self.weight + self.bias
            return x, None

        # compute sqrt of the feature dimension if it is the first run
        if self.dim_sqrt is None:
            n = torch.tensor(x.shape[2], dtype=torch.float)
            self.dim_sqrt = torch.sqrt(n).to(x.device)

        # Normalization: computes mean and variance (std)
        x_int = x / scaling_factor
        mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
        y_int = x_int - mean_int
        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)

        # overflow handling in training time
        if self.training:
            # if overflow is detected
            if var_int.max() >= 2**self.max_bit:
                var_int = self.overflow_fallback(y_int)
                assert var_int.max() < 2**self.max_bit + 0.1, (
                    "Error detected in overflow handling: "
                    "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
                )

        # to be replaced with an integer-sqrt kernel that produces the same output
        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2**self.shift
        factor = floor_ste.apply(2**31 / std_int)
        y_int = floor_ste.apply(y_int * factor / 2)
        scaling_factor = self.dim_sqrt / 2**30

        # scaling and shifting
        bias = self.bias.data.detach() / (self.weight.data.detach())
        bias_int = floor_ste.apply(bias / scaling_factor)

        y_int = y_int + bias_int
        scaling_factor = scaling_factor * self.weight
        x = y_int * scaling_factor

        return x, scaling_factor
tj|  |dj }|s<| ¡ }| ¡ }||fS )aÆ  
    Calculate the percentile max and min values in a given tensor

    Args:
        input (`torch.Tensor`):
            The target tensor to calculate percentile max and min.
        lower_percentile (`float`):
            If 0.1, the value of the smallest 0.1% of entries in the tensor is returned as the percentile min.
        upper_percentile (`float`):
            If 99.9, the value of the largest 0.1% of entries in the tensor is returned as the percentile max.
        output_tensor (`bool`, *optional*, defaults to `False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input*
    """
    input_length = input.shape[0]

    lower_index = round(input_length * (1 - lower_percentile * 0.01))
    upper_index = round(input_length * upper_percentile * 0.01)

    upper_bound = torch.kthvalue(input, k=upper_index).values

    if lower_percentile == 0:
        lower_bound = upper_bound * 0
    else:
        lower_bound = -torch.kthvalue(-input, k=lower_index).values

    if not output_tensor:
        lower_bound = lower_bound.item()
        upper_bound = upper_bound.item()
    return lower_bound, upper_bound
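

# Illustrative usage sketch (not part of the original module; names are hypothetical).
# Clipping the calibration range to the 0.1%/99.9% percentiles discards extreme
# outliers that would otherwise inflate the quantization scale.
def _example_percentile_range():
    x = torch.randn(10000)
    x[0] = 100.0                               # a single large outlier
    lower, upper = get_percentile_min_max(x, 0.1, 99.9)
    return lower, upper                        # roughly (-3.1, 3.1); the outlier is ignored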


def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.

    Args:
        input (`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (`torch.Tensor`):
            Scaling factor for quantization.
        zero_point (`torch.Tensor`):
            Shift for quantization.
        inplace (`bool`, *optional*, defaults to `False`):
            Whether to compute inplace or not.

    Returns:
        `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*.
    """
    # reshape scale and zero_point for convolutional weights and activations
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zero_point for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    else:
        scale = scale.view(-1)
        zero_point = zero_point.view(-1)
    # quantized = float / scale + zero_point
    if inplace:
        input.mul_(1.0 / scale).add_(zero_point).round_()
        return input
    return torch.round(1.0 / scale * input + zero_point)
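

# Illustrative worked example (not part of the original module; names are hypothetical)
# of the mapping q = round(x / scale + zero_point) implemented above.
def _example_linear_quantize():
    x = torch.tensor([0.07, -0.12, 0.2])
    scale = torch.tensor([0.1])
    zero_point = torch.tensor([0.0])
    return linear_quantize(x, scale, zero_point)   # tensor([ 1., -1.,  2.])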


def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
    """
    Compute the scaling factor with the given quantization range for symmetric quantization.

    Args:
        saturation_min (`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.

    Returns:
        `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and
        *saturation_max*.
    """
    # in this part, we do not need any gradient computation;
    # in order to enforce this, we put torch.no_grad()
    with torch.no_grad():
        n = 2 ** (num_bits - 1) - 1

        if per_channel:
            scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
            scale = torch.clamp(scale, min=1e-8) / n
        else:
            scale = max(saturation_min.abs(), saturation_max.abs())
            scale = torch.clamp(scale, min=1e-8) / n

    return scale
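

# Illustrative worked example (not part of the original module; names are hypothetical):
# with 8 bits, the larger magnitude of the two saturation bounds is mapped onto
# n = 2**(8 - 1) - 1 = 127.
def _example_symmetric_scale():
    scale = symmetric_linear_quantization_params(
        num_bits=8, saturation_min=torch.tensor([-0.5]), saturation_max=torch.tensor([1.27])
    )
    return scale                               # tensor([0.0100]), i.e. 1.27 / 127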
ûõr2   c                   @   ó(   e Zd ZdZedd„ ƒZedd„ ƒZdS )r    zw
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    """

    @staticmethod
    def forward(ctx, x, k, percentile_mode, scale):
        """
        Args:
            x (`torch.Tensor`):
                Floating point tensor to be quantized.
            k (`int`):
                Quantization bitwidth.
            percentile_mode (`bool`):
                Whether or not to use percentile calibration.
            scale (`torch.Tensor`):
                Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction
                requires pre-calculated scaling factor.

        Returns:
            `torch.Tensor`: Symmetric-quantized value of *x*.
        """
        zero_point = torch.tensor(0.0).to(scale.device)

        n = 2 ** (k - 1) - 1
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
        new_quant_x = torch.clamp(new_quant_x, -n, n - 1)

        ctx.scale = scale
        return new_quant_x

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.scale
        if len(grad_output.shape) == 4:
            scale = scale.view(-1, 1, 1, 1)
        # reshape scale for linear weights
        elif len(grad_output.shape) == 2:
            scale = scale.view(-1, 1)
        else:
            scale = scale.view(-1)

        return grad_output.clone() / scale, None, None, None, None


class floor_ste(Function):
    """
    Straight-through Estimator (STE) for torch.floor()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


class round_ste(Function):
    """
    Straight-through Estimator (STE) for torch.round()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()
dƒ}| |¡ qt |¡}t|ƒ| }t |¡ | j¡ |¡t |¡ | j¡ |¡fS )zü
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        inputs (`torch.Tensor`):
            Target scaling factor to decompose.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: mantissa and exponent
    """

    shape_of_input = inputs.size()

    # flatten the input to a 1-d tensor
    inputs = inputs.view(-1)

    output_m, output_e = np.frexp(inputs.cpu().numpy())
    tmp_m = []
    for m in output_m:
        int_m_shifted = int(
            decimal.Decimal(m * (2**max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
        )
        tmp_m.append(int_m_shifted)
    tmp_m = np.array(tmp_m)

    output_e = float(max_bit) - output_e

    return (
        torch.from_numpy(tmp_m).to(inputs.device).view(shape_of_input),
        torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
    )
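

# Illustrative worked example (not part of the original module; names are hypothetical):
# a scale s is rewritten as m / 2**e with an integer mantissa m, so multiplying by s
# becomes an integer multiply followed by a bit shift on integer-only hardware.
def _example_batch_frexp():
    s = torch.tensor([0.75], dtype=torch.float64)
    m, e = batch_frexp(s)
    return m, e                                # m == 0.75 * 2**31 and e == 31, so m / 2**e == 0.75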


class FixedPointMul(Function):
    """
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (`torch.Tensor`):
            Scaling factor of the input tensor *pre_act*.
        bit_num (`int`):
            Quantization bitwidth.
        z_scaling_factor (`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (`torch.Tensor`, *optional*):
            Identity tensor, if exists.
        identity_scaling_factor (`torch.Tensor`, *optional*):
            Scaling factor of the identity tensor *identity*, if exists.

    Returns:
        `torch.Tensor`: Output tensor (*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and
        *identity*), whose scale is rescaled to *z_scaling_factor*.
    """

    @staticmethod
    def forward(
        ctx,
        pre_act,
        pre_act_scaling_factor,
        bit_num,
        z_scaling_factor,
        identity=None,
        identity_scaling_factor=None,
    ):
        if len(pre_act_scaling_factor.shape) == 3:
            reshape = lambda x: x  # noqa: E731
        else:
            reshape = lambda x: x.view(1, 1, -1)  # noqa: E731
        ctx.identity = identity

        n = 2 ** (bit_num - 1) - 1

        with torch.no_grad():
            pre_act_scaling_factor = reshape(pre_act_scaling_factor)
            if identity is not None:
                identity_scaling_factor = reshape(identity_scaling_factor)

            ctx.z_scaling_factor = z_scaling_factor

            z_int = torch.round(pre_act / pre_act_scaling_factor)
            _A = pre_act_scaling_factor.type(torch.double)
            _B = (z_scaling_factor.type(torch.float)).type(torch.double)
            new_scale = _A / _B
            new_scale = reshape(new_scale)

            m, e = batch_frexp(new_scale)

            output = z_int.type(torch.double) * m.type(torch.double)
            output = torch.round(output / (2.0**e))

            if identity is not None:
                # needs addition of the identity activation
                wx_int = torch.round(identity / identity_scaling_factor)

                _A = identity_scaling_factor.type(torch.double)
                _B = (z_scaling_factor.type(torch.float)).type(torch.double)
                new_scale = _A / _B
                new_scale = reshape(new_scale)

                m1, e1 = batch_frexp(new_scale)
                output1 = wx_int.type(torch.double) * m1.type(torch.double)
                output1 = torch.round(output1 / (2.0**e1))

                output = output1 + output

            return torch.clamp(output.type(torch.float), -n - 1, n)

    @staticmethod
    def backward(ctx, grad_output):
        identity_grad = None
        if ctx.identity is not None:
            identity_grad = grad_output.clone() / ctx.z_scaling_factor
        # gradients line up positionally with the forward inputs (pre_act,
        # pre_act_scaling_factor, bit_num, z_scaling_factor, identity,
        # identity_scaling_factor); only the two activations receive a gradient
        return grad_output.clone() / ctx.z_scaling_factor, None, None, None, identity_grad, None