o
    h4                     @  s
  d Z ddlmZ ddlZddlZddlZddlmZmZ ddl	Z	ddl	m
Z
 ddlmZ ddlmZmZmZmZmZmZ ddlmZ dd	lmZmZmZ g d
ZejejddZ dd Z!e de"de#dddej$dddZ%e dej$dddZ&e de#ddej$dd d!Z'e d"e#ddej$dd#d$Z(e d%ej$dd&d'Z)e d(e"de#dd)dej$dd*d+Z*e d,ej$	-ddd.d/Z+e d0e#dd)ej$dd1d2Z,e d3e!d4d5d6gd7e d8e!d9d:d6gd7e d;e!d<d=d6gd7e d>e!d?d5d@gd7e dAe!dBd:d@gd7e dCe!dDd=d@gd7e dEe!dFd:dGgd7ej$ddMdNZ-e dOe"dd-d-d-d-d-d-ej$ddPdQZ.e dRe#dd)ddej$dddSdTZ/e dUe#dd)ddej$ddVdWZ0e dXe#dd)dYej$dddZd[Z1e d\ej$dd]d^Z2e d_ej$dd`daZ3e dbej$ddcddZ4e deej$ddfdgZ5e dhej$ddidjZ6e dkej$ddldmZ7e dnej$dddodpZ8e dqej$ddrdsZ9e dtej$ddudvZ:e dwej$ddxdyZ;e dze"dej$dd{d|Z<e d}ej$dd~dZ=e de#dd)d)d)ej$dddZ>e de#dd)d)d)d)ej$dddZ?e de#ddd)d)d)dYej$ddddZ@e de#dd)d)dYej$ddddZAe de#dd)d)dYej$ddddZBe de#dd)ej$ddddZCe dej$dddZDe de#ddd)d)ej$ddddZEe de#ddd)d)ej$ddddZFe de#dd)d)ej$ddddZGej$dddZHe dej$ddddZIe de de dej$dddZJe de de dej$dddZKe dej$dddZLe dej$dddZMe dej$dddZNe dej$dddZOe de#dd)ej$dddZPe dej"dd-dej$ddddZQe dÃej$ddddńZRe dƃej$dddȄZSe dɃej$ddd˄ZTe d̃ej$ddd΄ZUe dσej$dddфZVe d҃ej$dddԄZWe dՃej$dddׄZXe d؃ej$dddڄZYej$ddd܄ZZej$dddބZ[ej$dddZ\e de#dddddej$dddZ]e dej$dddZ^e de"dd-d-e#dd)d)ej$dddZ_e de#dddddej$dddZ`e de#dddd)d)d)dd)d)	ej$dddZae de#ddddej$dddZbe dej$dddZce dej$						ddddZde dej$dd dZee dej$dddZfe dej$dddZge dej$dd	d
Zhe dej$dddZie dej$dddZjdS (  z(This file exports ONNX ops for opset 11.    )annotationsN)OptionalSequence)_C)_onnx)_type_utilserrorssymbolic_helpersymbolic_opset10symbolic_opset9utils)GLOBALS)	_beartype	jit_utilsregistration)9addappendarangeargsort
atleast_1d
atleast_2d
atleast_3dcatchunk	clamp_max	clamp_minclampconstant_pad_ndcumsumDeleteembedding_bagembedding_renormflattengatherhardtanhhstackim2col
index_fillindex
index_copy	index_putinsert
linalg_detlinalg_vector_normlogdetmasked_scattermasked_selectmmnarrownormalpadpixel_shufflepopprim_constant_chunkreflection_padrelu6	remainderreplication_padroundscatterselectsizesortsplit_with_sizessplitsqueezestacktopkunbind
unique_dim	unsqueezevstack   )opsetc                    s    fdd}|S )z_Returns a decorator that calls the decorated (higher-order) function with the given parameters.c                   s   |  i S N )fnargskwargsrM   Q/var/www/html/ai/venv/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py_apply\   s   z_apply_params.<locals>._applyrM   )rP   rQ   rS   rM   rO   rR   _apply_paramsY   s   rT   zaten::hardtanhTvfgjit_utils.GraphContextself_C.Valuemin_valfloatmax_valc                 C  s`   t j|t jj}| jdtj|| dd}| jdtj|| dd}tj	| d|||ddS )NConstantdtypevalue_tClip   opset_before)
r   JitScalarType
from_valueFLOAToptorchtensorr`   opset9_op_with_optional_float_cast)rW   rY   r[   r]   scalar_typerM   rM   rR   r$   b   s   r$   zaten::clampc                   s   t j fdd}tj|tjj}|tjjkr"|||}|||}t|r-t ||S t|r8t	 ||S t
|dkrQt
|dkrQtj d|||ddS t t	 |||S )Nc                   s*   | d urt | s jd| | dS | S )NCastto_i)r	   _is_nonerj   	onnx_type)rl   r`   rW   rM   rR   _cast_if_not_nonez   s   z clamp.<locals>._cast_if_not_noner   rc   rd   re   )r   beartyper   rg   rh   	UNDEFINEDr	   rs   r   r   _get_tensor_rankrm   rn   )rW   rY   minmaxrv   ro   rM   ru   rR   r   w   s$   




r   zaten::clamp_minc                 C  s^   | j d|tj| d}t|dkr%t| }tj	| d|||ddS tj	| d||ddS )Nrp   rq   r   rc   rd   re   Max
rj   r   rg   rh   rt   r	   ry   rm   unusedrn   )rW   rY   rz   r{   rM   rM   rR   r         
r   zaten::clamp_maxc                 C  s^   | j d|tj| d}t|dkr%t| }tj	| d|||ddS tj	| d||ddS )Nrp   rq   r   rc   rd   re   Minr}   )rW   rY   r{   rz   rM   rM   rR   r      r   r   zaten::relu6c                 C  sX   t j|t jj}| jdtjd| dd}| jdtjd| dd}t| |||S )Nr^   r   r_   ra      )	r   rg   rh   ri   rj   rk   rl   r`   r   )rW   inputro   r[   r]   rM   rM   rR   r9      s   r9   zaten::selectic                 C  s   | j d|||dS )NGatheraxis_irj   )rW   rY   dimr(   rM   rM   rR   r>      s   r>   zaten::index_putFc                   s  t |rt |}n|g}t  r$|g| ||g }jdg|R  S t |d}t|dkr2|S t|dkrtt|D ]}t || rQ	d|| ||< q>|d }|dd  D ]	}	t
||	}q\	d|  fdd|D }j	d	g|R d
di}nW|d }|}
t |
rt |}|d ur|dkrt
||
|S t |
}t |}|d ur|d ur||krt |
tt||}
t||
|S 	d| t |dg}t j	d|dgt|gtjgd}j	d	 |dd}t |}|d ur|dkrt
||d }t ||}tj|tjj}|tjjkr>tj|tjj}||kr=j	d|| d}n	|rGtd||rnj	d	d|tjdg| dd}	d|||}t||}|S 	d|||}|S )Nr*   br      NonZeroShapec                   s(   g | ]}t t| d dgqS )N)r	   _unsqueeze_helperrm   expand).0indbroadcast_index_shaperW   rM   rR   
<listcomp>   s    zindex_put.<locals>.<listcomp>Concatr   r   axesstartsendsr   rp   rq   z'self does not have a valid scalar type.ConstantOfShaper_   ra   	ScatterND) r	   _is_packed_list_unpack_listis_caffe2_aten_fallbackat
_parse_arglenrange_is_boolrj   rm   r   ry   masked_fillr   listr/   _slice_helpersysmaxsizer   _reshape_helperr   rg   rh   rx   rt   r   SymbolicValueErrorrk   rl   r`   )rW   rY   indices_list_valuevalues
accumulateindices_listrP   idx_r(   r   bool_inprank	mask_rank	self_ranksub_data_shapevalues_shapeself_scalar_typevalues_scalar_typezerosresultrM   r   rR   r*      s   
(






r*   zaten::pixel_shufflec                 C  s8   t |}|d ur|dkrt ddS | jd||ddS )N   r5   zonly support 4d inputDepthToSpaceCRD)blocksize_imode_s)r	   ry   _unimplementedrj   )rW   rY   upscale_factorr   rM   rM   rR   r5   S  s   
r5   zaten::upsample_nearest1dupsample_nearest1d   nearest)decoratezaten::upsample_nearest2dupsample_nearest2dr   zaten::upsample_nearest3dupsample_nearest3d   zaten::upsample_linear1dupsample_linear1dlinearzaten::upsample_bilinear2dupsample_bilinear2dzaten::upsample_trilinear3dupsample_trilinear3dzaten::upsample_bicubic2dupsample_bicubic2dcubicnamestrr   intinterpolate_modec                 C  s   t | ||S rL   )r	   _interpolate_helper)r   r   r   rM   rM   rR   _interpolate]  s   r   zaten::__interpolatec              	   C  s   t | ||||||S rL   )r	   __interpolate_helper)rW   r   r?   scale_factormodealign_cornersrecompute_scale_factor	antialiasrM   rM   rR   __interpolate~  s   r   zaten::gatherc                 C  sD   t |drt ddS t  r| d||||S | jd|||dS )Nr   r#   zsparse_grad == TrueGatherElementsr   )r	   _maybe_get_constr   r   r   rj   )rW   rY   r   r(   sparse_gradrM   rM   rR   r#     s
   r#   zaten::scatterc              	   C  s   t  r| jd||||ddS tj|}t |}t |r)| jd||||dS tj||kr?| jd|tj|	 d}| jd||t
| |||dS )Nr=   srcoverload_nameScatterElementsr   rp   rq   )r	   r   r   r   rg   rh   _maybe_get_scalar	_is_valuerj   rt   rm   	expand_as)rW   rY   r   r(   r   src_typerM   rM   rR   r=     s   

r=   zaten::cumsumnonec                 C  sn   | j dtj|tjdd}|r,|  dkr,t|dd}| j d|t	|
 d}n|}|  d	||}|S )
Nr^   r_   ra   zprim::Constantr   r`   rp   rq   CumSum)rj   rk   rl   r   nodekindr	   
_get_constr   rg   rt   )rW   rY   r   r`   
dim_tensorparsed_dtypecastcsumrM   rM   rR   r     s   r   zaten::masked_selectc                 C  s$   t | t | ||}| d||S )NGatherND)rm   nonzeror   rj   )rW   rY   maskr(   rM   rM   rR   r0     s   r0   zaten::masked_scatterc                 C  sr   t | t | ||}t| |tdg}tj| |tdgtdgt | |tdgd}| 	d|||S )Nr   r   r   r   )
rm   r   r   r	   r   rk   
LongTensorr   r?   rj   )rW   rY   r   sourcer(   rM   rM   rR   r/     s   

r/   z	aten::lenc                 C  sT   t |s|  dkr| d|S t| || jdtdgd}t | |dgS )Nzonnx::SplitToSequenceSequenceLengthr^   r   ra   )	r	   _is_tensor_listr   r   rj   r?   rk   r   _squeeze_helper)rW   rY   sz_0rM   rM   rR   _len  s   r   zaten::__getitem_c                 C  s0   t |r| d||S ddlm} || ||S )N
SequenceAtr   )
__getitem_)r	   r   rj   torch.onnx.symbolic_opset9r   )rW   rY   r   getitemrM   rM   rR   r     s   
r   zaten::_set_itemc                 C  s   |  d||}|  d|||S )NSequenceEraseSequenceInsertr   )rW   tensor_listr   rU   rM   rM   rR   	_set_item  s   r   zaten::appendc                 C     |  d||S Nr   r   )rW   rY   rl   rM   rM   rR   r        r   z	aten::addc                 C  sn   t |r/t |r/| }| dkrt ddS t |}|}|D ]	}| d||}q#|S t	| |||S )Nzprim::ListConstructr   z6does not support adding dynamic tensor list to anotherr   )
r	   r   r   r   r   r   r   rj   rm   r   )rW   rY   otheralphatensor_list_nodetensorsltrM   rM   rR   r     s   
r   zaten::insertc                 C  s   |  d|||S r  r   )rW   rY   posrl   rM   rM   rR   r+     s   r+   z	aten::popc                 C  r  Nr   r   rW   r   r   rM   rM   rR   r6     r  r6   zaten::Deletec                 C  r  r  r   r  rM   rM   rR   r     r  r   z	aten::catc                 C  s6   t |rt| ||S t |dd}| jd||dS )Nr   r   ConcatFromSequencer   )r	   r   rm   r   r   rj   r  rM   rM   rR   r   %  s   
r   zaten::stackc                 C  s8   t |rt| ||S t |dd}| jd||ddS )Nr   r   r  r   r   
new_axis_i)r	   r   rm   rD   r   rj   r  rM   rM   rR   rD   0  s   
rD   zaten::_unique2c           	      C  s$   | j d||dd\}}}}|||fS )NUniquer   )sorted_ioutputsr   )	rW   rY   sortedreturn_inversereturn_countsuindicesinverse_indicescountsrM   rM   rR   _unique2:  s   
r  zaten::unique_dimc           
      C  s&   | j d|||dd\}}}}	|||	fS )Nr  r   )r   r  r  r   )
rW   rY   r   r  r  r  r  r  r  r  rM   rM   rR   rG   D  s   

rG   z
aten::topkc              	   C  s   t j| ||||||dS )N)largestr  out)r	   _topk_helper)rW   rY   kr   r  r  r  rM   rM   rR   rE   P  s   rE   z
aten::sortc                 C  s   t j| ||||dS N)	decendingr  r	   _sort_helper)rW   rY   r   r   r  rM   rM   rR   r@   Y  s   r@   zaten::argsortc                 C  s   t j| ||||d\}}|S r  r!  )rW   rY   r   r   r  _r  rM   rM   rR   r   `  s   

r   zaten::roundc                 C  sz   t |s|S |dkr| d|S | d|| jdttd|d}| d|}| d|| jdttdd| dS )Nr   RoundMulr^   
   ra   r   )r	   _is_fprj   rk   rl   pow)rW   rY   decimalsmulr<   rM   rM   rR   r<   j  s   
$ r<   zaten::remainderc                 C  s4   t |s
t |rt| ||S | jd||ddS )NModr   )fmod_i)r	   r'  rm   r:   rj   )rW   r   r  rM   rM   rR   r:   y  s   r:   zaten::splitc              
     s  t ||sy jd|||d|d u rS t |rmtt ||krm fddt |D } jdtjdgtjdd} jdtj|gtjdd}g }t	|D ]}	 d	|||	 }
|
 d
|||
| |
}qQ|S  fddt	|D S t ||||S )NSplitToSequencer   c                   s   g | ]
}t  |d gqS r   )r	   r   )r   rU   ru   rM   rR   r     s    zsplit.<locals>.<listcomp>r^   r   r_   ra   AddSlicec                   s2   g | ]}  d  j dtj|gtjddqS )r   r^   r_   ra   )rj   rk   rl   long)r   r   rW   	split_outrM   rR   r     s    )r	   _is_split_staticrj   r   r   r   rk   rl   r1  r   r   rm   rB   )rW   rY   split_size_or_sizesr   _outputssplit_sizesstartaxisresr   endrM   r2  rR   rB     s0   

	rB   zaten::split_with_sizesc                 C  s   t | ||||S rL   )rB   )rW   rY   r7  r   r6  rM   rM   rR   rA     s   rA   zaten::unbindc              	   C  sB   |d u r| j d|| j dtjdtjdd|ddS t| |||S )Nr-  r^   r   r_   ra   r   )r   
keepdims_i)rj   rk   rl   r1  rm   rF   )rW   rY   r   r6  rM   rM   rR   rF     s   rF   c                 C  sz  t |st |rt |r| jd|ddd}t| || jdtdgd}t 	|}|du r<| d| d	|}n| jdtj|tj
d
d}| d| d|| jdtjdtj
d
d|}| jd|tjjd}| jd|| jd|tjdgtj
d
ddd}t | || jdtddgd}| jdt| |dgddgd}t | || jdtdgd}| jd|tjjd}|S )a!  Generate paddings in ONNX order based on pad in pytorch.

    Args:
        input: the input tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].
    r  r   r   r  r^   ra   NSizer   r_   Subr%     rp   rq   r   r   r   r   	Transposeperm_i)r	   r   _is_list_is_scalar_listrj   rm   r?   rk   rl   ry   int64_C_onnxTensorProtoDataTypeINT64r   opset10flip)rW   r   r4   pad_lenr   	extensionpaddings	padding_crM   rM   rR   _prepare_onnx_paddings  sF    
" rO  zaten::constant_pad_ndc                 C  s:   d}t |}t ||}t| ||}| jd||||dS )NconstantPadr   )r	   r   _if_scalar_type_asrO  rj   )rW   r   paddingvaluer   r4   rM   rM   rR   r     s
   
r   zaten::reflection_pad1dzaten::reflection_pad2dzaten::reflection_pad3dc                 C  "   d}t | ||}| jd|||dS )NreflectrQ  rR  rO  rj   rW   r   rT  r   rM  rM   rM   rR   r8        r8   zaten::replication_pad1dzaten::replication_pad2dzaten::replication_pad3dc                 C  rV  )NedgerQ  rR  rX  rY  rM   rM   rR   r;     rZ  r;   z	aten::padr   r4   r   rU  c                 C  sr   t |d}|dkrt| ||S |dkrt| ||S |dkr%t| |||S |dkr0t| ||S td| |)Ns	replicaterW  rP  circularzUnrecognized padding mode )	r	   r   r;   r8   r   rm   _pad_circularr   r   )rW   r   r4   r   rU  rM   rM   rR   r4     s   	zaten::linalg_detc                 C  s   |  d|S )NDetr   )rW   rY   rM   rM   rR   r,   -  s   r,   zaten::logdetc                 C  s   t | t| |S rL   )rm   logr,   )rW   r   rM   rM   rR   r.   3  s   r.   aten::arangec                 G  s  dd }t |dkrFtdd |D rFtj}| jdtj|d |dd	}| jdtj|d
 |dd	}| jdtjd
|dd	}| d|||S t |dksRt |dkrt |dkr[d }n||d
 }tj| |d |d\}}}}| jdtjd| dd	}	| jdtjd
| dd	}| d|	||S t |dkst |dkrt |dkrd }n||d }tj| |d |d
 |d |d\}
}}}| d|||S t |dkr||d }tj| |d |d
 |d\}}}}| jdtjd
| dd	}| d|||S t	ddt | dS )Nc                 S  s   t | d} | S )Nr   )r	   r   r_   rM   rM   rR   _get_arange_dtype<  s   z!arange.<locals>._get_arange_dtyper?  c                 s  s    | ]}t |tV  qd S rL   )
isinstancer   )r   valrM   rM   rR   	<genexpr>@  s    zarange.<locals>.<genexpr>r^   r   r_   ra   r   Ranger   )r;  r`   r      r   )r8  r;  stepr`   r   )r8  r;  r`   rb  zwith z
 arguments)
r   allrk   rE  rj   rl   r	   _arange_cast_helperr`   r   )rW   rP   rc  r`   r8  r;  delta_defaulttype_ri  start_defaultr#  rM   rM   rR   r   9  sj   
r   zaten::_dim_arangec                 C  sT   |  d|}| j d|| j dt|ddd}t r!|  d|S t| |dd d d S )	Nr   r   r^   ra   r   r   z_caffe2::Ranger   )rj   rk   rl   r	   r   r   )rW   liker   
like_shapestoprM   rM   rR   _dim_arange  s   rr  z
aten::size)quantize_outputc                 C  s"   |d u r
|  d|S t| ||S )Nr   )rj   r	   _size_helperrW   rY   r   rM   rM   rR   r?     s   r?   zaten::squeezec                 C  sx  |d u r
|  d|S t|st| ||gS t|dd}t|}|}|d ur1|dk r1||7 }t||}|dk r?|d u sC|d u r| j dt|gd}t	| ||}| j dtj
dtjdd}|  d	||}	tj| d
|	dd\}
\}}}t|||g}t|j| | d|}t|j| |
S |}|dkrtdt| d d t| d d d  |S t| ||gS )NSqueezer   r   r   r^   ra   r   r_   EqualIfr?  n_blocksIdentityz5This model contains a squeeze operation on dimension z. The size of z%this dimension in the given input is z. The model will zWbe exported without the squeeze node. If the model is intended to be used with dynamic z7input shapes, please export with dynamic_axes argument.)rj   r	   _is_constantr   r   ry   _get_tensor_dim_sizerk   rl   rt  onesrE  r   add_op_with_blocksr   _add_output_to_blockblockwarningswarnr   )rW   rY   r   
input_rankadjusted_dimdim_sizedim_constantr?   	const_onecondif_op
if_contextelse_contextr#  squeeze_	identity_rM   rM   rR   rC     sX   


rC   zaten::unsqueezec                 C  s(   t |rt |dd}t | ||gS )Nr   r   )r	   r|  r   r   ru  rM   rM   rR   rH     s   
rH   zaten::mmc                 C  s   | j d||dddS )NGemmg        g      ?)beta_falpha_fr   )rW   rY   r  rM   rM   rR   r1     s   r1   zaten::indexc                 C  s   t  r| jd||ddS t |rt |}n|g}t|dkrF|d }t |sFt |s9tj	
|tj	jkrFt| |}| d||S t| ||S )Nr(   Tensorr   r   r   r   )r	   r   r   r   r   r   rs   r   r   rg   rh   UINT8rm   r   rj   r(   )rW   rY   r(   r  rM   rM   rR   r(     s    


r(   zaten::index_fillc           	      C  st   t |d}t  r| jd|||d|dS t | |||\}}t |}t ||}t| ||d }t	| ||||S )Nr   r'   
int_Scalar)r   dim_i)
r	   r   r   r   _index_fill_reshape_helperr   rS  rm   r   r=   )	rW   rY   r   r(   rU  	dim_valueexpanded_index_shapeexpanded_indexexpanded_valuerM   rM   rR   r'     s"   	
r'   zaten::index_copyc                 C  sL   t |d}t  r| jd||||dS t | |||\}}t| ||||S )Nr   r)   )r  )r	   r   r   r   r  r=   )rW   rY   r   r(   r   r  r  r  rM   rM   rR   r)   
  s   r)   zaten::__rshift_c                 C     t j|t jjt j|kr| jd|t j| d}t j|t jjt jjkr3| jd||ddS | jdtjdtj	dd	}t
|sO| jd|tjjd}| d
||}| jd|t j| d}| d||}|S )Nrp   rq   BitShiftRIGHTdirection_sr^   r?  r_   ra   PowDivr   rg   rh   rx   rj   rt   r  rk   rl   float32r	   r'  rF  rG  ri   )rW   rY   r  twotwo_powrshiftrM   rM   rR   	__rshift_  2   

r  zaten::__lshift_c                 C  r  )Nrp   rq   r  LEFTr  r^   r?  r_   ra   r  r%  r  )rW   rY   r  r  r  lshiftrM   rM   rR   	__lshift_8  r  r  c                 C  s   |  d|| j dt|d d}|  d|| j dt||d  d}|  d| j dtdd|| j dt|d}td|| |}| j d|dd}t| |dg}t| || j dtd	dgd}	|  d||	}
|
S )
Nr/  r^   r?  ra   r>  r   rg  r   r   )rj   rk   rl   r   rH   r	   r   r   )rW   input_dkernel_size_d
dilation_d	padding_dstride_dblocks_dblocks_d_indiceskernel_gridkernel_mask
block_maskrM   rM   rR   _get_im2col_indices_along_dimZ  s0   
r  c                 C  s.   | j dtdd||gd d}|  d||S )Nr^   r   r?  ra   rQ  )rj   rk   r   )rW   r   	padding_h	padding_wr4   rM   rM   rR   _get_im2col_padded_input  s    r  c              
   C  s   t | || jdtdd}t | || jdtdd}| d|| jdt|| d}| jdt| |dgt| |dg| jdtdgdddS )	Nr^   r   ra   r   r%  r   r   r   )r?   rj   rk   rl   r	   r   )rW   r   kernel_hkernel_w	batch_dimchannel_dimchannel_unfoldedrM   rM   rR   _get_im2col_output_shape  s   r  zaten::im2colisc                 C  s  t | || jdtdd}t | || jdtdd}|d |d }}	|d |d }
}|d |d }}|d |d }}t| ||||
|}t| |||||	}t| |||}t| ||
|}| jd||dd}| jd||d	d}| jd
|g dd}t| ||S )Nr^   r?  ra   r   r   r   r   r   r   r@  )r   r   r?  r   r   r   rA  )	r?   rj   rk   rl   r  r  r  r	   r   )rW   r   kernel_sizedilationrT  strideinput_hinput_wstride_hstride_wr  r  
dilation_h
dilation_wr  r  blocks_row_indicesblocks_col_indicesoutput_shapepadded_inputoutputrM   rM   rR   r&     s$   r&   zaten::narrowc                 C  s"   |  d||}tj| ||||dS )Nr/  r   )rj   r	   r   )rW   r   r   r8  lengthr;  rM   rM   rR   r2     s   r2   zaten::flattenc                 C  s   t |}|dkr|S |dkr&|dks|d ur%||d kr%| jd||dS n|dkrB|dks8|d urB||d krB| jd||d dS |d u rLt dd	S |dk rT|| }t | ||||S )
Nr   r   Flattenr   r   r?  r   zfONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.)r	   ry   rj   r   _flatten_helper)rW   r   	start_dimend_dimr   rM   rM   rR   r"     s$   
r"   zaten::linalg_vector_normr   Optional[Sequence[int]]keepdimboolc                 C  s   |dkrH|d u rt | || jdtjdgtjdd}d}| d| d|| jdtdgd}| jd	|tj	|
 d
}t j| |||dS t| |||||S )Nr   r^   r   r_   ra   FNotrw  rp   rq   axes_ir<  )r	   r   rj   rk   rl   rE  r   r   rg   rh   rt   _reducesum_helperrm   r-   )rW   rY   ordr   r  r`   cond_oprM   rM   rR   r-     s$    r-   zaten::embedding_bagc
                 C  s  |r
t jr
tdS |	d ur|	dkrtd| jdtdd}
| jd|
tj	j
d}
| jdtdgd}t| t| || jdtdddg}|s\||g}| jd	g|R d
di}tj| |dgdgtjgdgd}tj| |dgdgtjgdgd}t| || jdtdd}tj| d||
dd\}\}}|j}t|}t|}|jd||dd}|jd||dd}t||dg}t||dg}|d||||}|jd||dd}t|s|d||||}t||dg}|d||}|dkrtj||dgdd}n|dkr|jd|dgdd}n
|jd|dgdd}|jd|
tj	j
d}t|| t|| |  d d d fS )Nz7embedding_bag with scale_grad_by_freq for training moder   zembedding_bag with padding_idxr^   r   ra   rp   rq   r   r   )r   r   r   stepsLoopry  r   r   r0  r%  r  
ReduceMean	ReduceMax)r   export_trainingr	   _onnx_unsupportedRuntimeErrorrj   rk   rl   rF  rG  BOOLr   rt  r   r   r   r   r  r  r   _add_input_to_blockrs   r  r  r   r  )rW   embedding_matrixr  offsetsscale_grad_by_freqr   sparseper_sample_weightsinclude_last_offsetpadding_idxloop_conditionzeroindices_lenoffsets_startsoffsets_endsloop_lenlooploop_contextr#  
loop_blockblock_input_iterr  indices_startindices_endindices_row
embeddingsper_sample_weights_rowcond_outrM   rM   rR   r      s~   









r    zaten::embedding_renormc              	   C  s   |  d|}|  d||}t|}|dkrd}n|dkrd}n
td| d|| j ||dgdd	}|  d
|| j dtdd}	t|}|  d||	}
|  d||
}|  d|  d||||}|  d|t| |dg|S )Nr  r   r   ReduceL1r?  ReduceL2z8Unsupported: ONNX export of embedding_renorm with norm: z. Only 1. and 2. are supported.r  r/  r^   gHz>ra   r  r%  WhereGreaterr   )rj   r   r   r   rk   rl   r	   r   )rW   weightr  max_norm	norm_typeunique_indicespartial_weightnorm_ipartial_weight_normpartial_weight_norm_scalespartial_weight_renormrM   rM   rR   r!   {  s<   

r!   zaten::chunkc              
   C  s   | j d|  d||dd}|  d|| j dtjdgtjdd	}|  d
|  d|||}t| ||d |  d||  d||g}| j dg|R ddi}t| |||S )Nr   r   r   r   r>  r^   r   r_   ra   r  r/  r%  r   r   )rj   rk   rl   r1  rm   r   rB   )rW   rY   chunksr   r  chunk_size_s
chunk_size	chunk_vecrM   rM   rR   r     s   r   zaten::normalc	           
      C  sD   |d urt |st| ||d }t| || d|}	t| |	|S )NRandomNormalLike)r	   rs   rm   r   r*  rj   r   )
rW   meanstdsizes	generatorr`   layoutdevice
pin_memoryr   rM   rM   rR   r3     s   r3   zaten::atleast_1dtorch._C.Valuec              
   C  s   t |r?t |r?t |}g }|D ]"}|}t |}|dkr0t | || jdtdgd}|	| q| jdg|R  S t |}|dkrXt | || jdtdgd}|S )Nr   r^   r   ra   SequenceConstruct)
r	   r   r   r   ry   r   rj   rk   rl   r   rW   rY   r   new_tensor_listrl   
new_tensortensor_rankrM   rM   rR   r     s$   


r   zaten::atleast_2dc                 C  s   t |rNt |rNt |}g }|D ]1}|}t |}|dkr2t | || jdtddgd}n|dkr?t j	| |dgd}|
| q| jdg|R  S t |}|dkrjt | || jdtddgd}|S |dkrwt j	| |dgd}|S )Nr   r^   r   ra   r  r  r	   r   r   r   ry   r   rj   rk   rl   r   r   r  rM   rM   rR   r     s2   


r   zaten::atleast_3dc                 C  sP  t |ret |ret |}g }|D ]H}|}t |}|dkr2t | || jdtg dd}n$|dkrIt j	| |dgd}t j	| |dgd}n|dkrVt j	| |dgd}|
| q| jd	g|R  S t |}|dkrt | || jdtg dd}|S |dkrt j	| |dgd}t j	| |dgd}|S |dkrt j	| |dgd}|S )
Nr   r^   )r   r   r   ra   r   r  r   r?  r  r  r  rM   rM   rR   r   
  sH   


r   zprim::ConstantChunkc              
   C  s  |  d|}| j dtj|gtjdd}| j d||dd}| j dtjdgtjdd}| j dtj|gtjdd}| j dtj|d gtjdd}	|  d	||	}
|  d
|
|}g }t|D ]'}| j dtj|d gtjdd}|  d||}||  d|||| |}q]|S )Nr   r^   r_   ra   r   r   r   r   r/  r  r%  r0  )rj   rk   rl   r1  r   r   )rW   rY   r  r   input_shaper9  input_shape_dimr8  r  chunk_size_minus_1input_shape_dim_shift	chunk_dimr:  r   r(   r;  rM   rM   rR   r7   7  s"    r7   zaten::hstackr   c              
   C  s   t | |}| d|| jdtjdtjdd}| d|}| d|}| jdtjdtjdd}| d	||}tj| d
|ddd\}\}}	}
|jd|ddd}t|j	| |	jd|ddd}t|	j	| |
  }|S )Nr   r^   r   r_   ra   r   r=  r   rw  rx  r?  )rz  r  r  r  )r   rj   rk   rl   r1  r   r  r   r  r  r   r  )rW   r   first_tensorfirst_tensor_shapefirst_tensor_dimr  equal_to_oneif_op_greaterif_context_equalelse_context_equalr#  	result_ifresult_elser   rM   rM   rR   r%   M  s2   
r%   zaten::vstackc                 C  s   t | |}| jd|dddS )Nr  r   r  )r   rj   )rW   r   rM   rM   rR   rI   n  s   
rI   )rW   rX   rY   rZ   r[   r\   r]   r\   )rW   rX   )F)r   r   r   r   r   r   rL   r.  )r   N)
rW   rX   r   rZ   r4   rZ   r   rZ   rU  rZ   )rW   rX   r   r  r  r  )NNNNNN)rW   rX   rY   r  )rW   rX   r   rZ   )k__doc__
__future__r   	functoolsr   r  typingr   r   rk   r   torch._Cr   rF  
torch.onnxr   r   r	   r
   rI  r   rm   r   torch.onnx._globalsr   torch.onnx._internalr   r   r   __all__partialonnx_symbolic_onnx_symbolicrT   quantized_args
parse_argsrw   r$   r   r   r   r9   r>   r*   r5   r   r   r#   r=   r   r0   r/   r   r   r   r   r   r+   r6   r   r   rD   r  rG   rE   r@   r   r<   r:   rB   rA   rF   rO  r   r8   r;   r4   r,   r.   r   rr  r?   rC   rH   r1   r(   r'   r)   r  r  r  r  r  r&   r2   r"   r-   r    r!   r   r3   r   r   r   r7   r%   rI   rM   rM   rM   rR   <module>   s@    <	#




	
$9G

2
  +3^% +