""" PyTorch SEW model."""

import math
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_sew import SEWConfig


logger = logging.get_logger(__name__)


_HIDDEN_STATES_START_POSITION = 1

# General docstring
_CONFIG_FOR_DOC = "SEWConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 512]

# CTC docstring
_CTC_EXPECTED_OUTPUT = (
    "'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'"
)
_CTC_EXPECTED_LOSS = 0.42

# Audio class docstring
_SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 9.52

SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "asapp/sew-tiny-100k",
    "asapp/sew-small-100k",
    "asapp/sew-mid-100k",
]


def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index as a dummy index to pad the vector so that all
        # batch rows have the same number of (possibly duplicated) span starts
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length`, in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
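# A minimal, comment-only sketch of the helper above (kept in comments so importing this
# module stays side-effect free). For `shape=(2, 100)`, `mask_prob=0.5` and `mask_length=10`,
# roughly `0.5 * 100 / 10 = 5` spans of 10 consecutive frames are drawn per row, subject to
# the probabilistic rounding and overlap caveats described in the docstring:
#
#     mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.5, mask_length=10)
#     assert mask.shape == (2, 100) and mask.dtype == np.bool_
#     # each row has at most 50 masked frames; overlapping spans can make it fewer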
class SEWNoLayerNormConvLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class SEWLayerNormConvLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)

        hidden_states = hidden_states.transpose(-2, -1)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states.transpose(-2, -1)

        hidden_states = self.activation(hidden_states)
        return hidden_states


class SEWGroupNormConvLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

        self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class SEWPositionalConvEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
            stride=config.squeeze_factor,
        )

        if is_deepspeed_zero3_enabled():
            import deepspeed

            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
            deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
            deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
        else:
            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)

        self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)

        return hidden_states


class SEWSamePadLayer(nn.Module):
    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0

    def forward(self, hidden_states):
        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, :, : -self.num_pad_remove]
        return hidden_states


class SEWUpsampling(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
        self.activation = ACT2FN[config.feat_extract_activation]
        self.squeeze_factor = config.squeeze_factor

    def forward(self, hidden_states):
        hidden_states = self.projection(hidden_states)
        hidden_states = self.activation(hidden_states)

        if self.squeeze_factor > 1:
            # transform embedding channels to sequence length
            bsz, src_len, src_embed_dim = hidden_states.size()
            tgt_len = src_len * self.squeeze_factor
            tgt_embed_dim = src_embed_dim // self.squeeze_factor
            hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
            hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)

        return hidden_states


class SEWFeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    def __init__(self, config):
        super().__init__()

        if config.feat_extract_norm == "group":
            conv_layers = [SEWGroupNormConvLayer(config, layer_id=0)] + [
                SEWNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
            ]
        elif config.feat_extract_norm == "layer":
            conv_layers = [SEWLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        self.conv_layers = nn.ModuleList(conv_layers)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        for param in self.parameters():
            param.requires_grad = False
        self._requires_grad = False

    def forward(self, input_values):
        hidden_states = input_values[:, None]

        # make sure hidden_states require grad for gradient_checkpointing
        if self._requires_grad and self.training:
            hidden_states.requires_grad = True

        for conv_layer in self.conv_layers:
            if self._requires_grad and self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(conv_layer.__call__, hidden_states)
            else:
                hidden_states = conv_layer(hidden_states)

        return hidden_states


class SEWFeatureExtractor(SEWFeatureEncoder):
    def __init__(self, config):
        super().__init__(config)
        warnings.warn(
            f"The class `{self.__class__.__name__}` has been deprecated and will be removed in Transformers v5. "
            f"Use `{self.__class__.__bases__[0].__name__}` instead.",
            FutureWarning,
        )


class SEWAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[SEWConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k, v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


class SEWFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.intermediate_dropout(hidden_states)

        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states)
        return hidden_states


class SEWEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = SEWAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = SEWFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class SEWEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = SEWPositionalConvEmbedding(config)
        self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.upsample = SEWUpsampling(config)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states[~attention_mask.bool()] = 0.0

            input_lengths = (attention_mask.long()).sum(-1)
            # apply pooling formula to get real output_lengths
            output_lengths = input_lengths // self.config.squeeze_factor
            max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
            attention_ids = (
                torch.arange(max_encoder_length, device=output_lengths.device)
                .view(1, -1)
                .expand(output_lengths.shape[0], -1)
            )
            attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
            )

        n_input_timesteps = hidden_states.shape[1]

        hidden_states = hidden_states.transpose(1, 2)
        position_embeddings = self.pos_conv_embed(hidden_states)
        pooled_hidden_states = self.pool(hidden_states)
        min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
        hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        output_attentions,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.upsample(hidden_states)
        if hidden_states.shape[1] < n_input_timesteps:
            hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class SEWPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SEWConfig
    base_model_prefix = "sew"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, SEWPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations make sure that all values before the output lengths indices are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask


SEW_START_DOCSTRING = r"""
    SEW was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech
    Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger,
    Yoav Artzi.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`SEWConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

SEW_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare SEW Model transformer outputting raw hidden-states without any specific head on top.",
    SEW_START_DOCSTRING,
)
class SEWModel(SEWPreTrainedModel):
    def __init__(self, config: SEWConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = SEWFeatureEncoder(config)
        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)

        self.project_features = config.conv_dim[-1] != config.hidden_size
        if self.project_features:
            self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
        self.feature_dropout = nn.Dropout(config.feat_proj_dropout)

        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        self.encoder = SEWEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = self.layer_norm(extract_features)

        if self.project_features:
            extract_features = self.feature_projection(extract_features)
        hidden_states = self.feature_dropout(extract_features)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)

        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    "SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).",
    SEW_START_DOCSTRING,
)
class SEWForCTC(SEWPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        super().__init__(config)

        self.sew = SEWModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
                " vocabulary size of the language model head. Please instantiate the model as follows:"
                " `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's"
                " configuration."
            )
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        """
        # `tie_weights` is re-purposed here to correctly load adapter layers for SEW so that we
        # do not have to introduce a new API to `PreTrainedModel`; SEW never ties input and output embeddings.
        target_lang = self.target_lang

        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
            logger.info("By default `target_lang` is set to 'eng'.")
        elif target_lang is not None:
            self.load_adapter(target_lang, force_load=True)

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.sew.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.sew.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.sew(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100 when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )


@add_start_docstrings(
    """
    SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
    Keyword Spotting.
    """,
    SEW_START_DOCSTRING,
)
class SEWForSequenceClassification(SEWPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of SEW adapters (config.add_adapter=True)"
            )
        self.sew = SEWModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters
        will not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.sew.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.sew.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_SEQ_CLASS_CHECKPOINT,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.sew(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )