mamkit.modules package
Submodules
mamkit.modules.rnn module
- class mamkit.modules.rnn.LSTMStack(input_size, lstm_weigths, return_hidden=True)
Bases: Module
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
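The note above is easiest to see with a toy model of the call mechanism. The sketch below is a simplified stand-in written only for illustration. It assumes a hypothetical ToyModule class with a one-argument forward and a simplified hook signature hook(module, input, output). It is not PyTorch's nn.Module implementation, and real PyTorch forward hooks receive the positional inputs as a tuple.

```python
from typing import Any, Callable, List

# Hypothetical, simplified hook type: hook(module, input, output).
Hook = Callable[[Any, Any, Any], None]


class ToyModule:
    """Toy stand-in for torch.nn.Module, for illustration only."""

    def __init__(self) -> None:
        self.training = True  # mirrors the `training: bool` attribute
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError("subclasses must override forward()")

    def __call__(self, x: Any) -> Any:
        # Calling the instance runs forward() and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


class Doubler(ToyModule):
    def forward(self, x: float) -> float:
        return 2 * x


seen: List[float] = []
model = Doubler()
model.register_forward_hook(lambda mod, inp, out: seen.append(out))
model(3)          # goes through __call__: the hook records 6
model.forward(4)  # bypasses __call__: the hook is silently skipped
print(seen)       # [6]
```

The design point is that __call__ is the single place where cross-cutting behavior such as hooks is attached. Calling forward() directly therefore skips that behavior.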
mamkit.modules.transformer module
- class mamkit.modules.transformer.CustomEncoder(d_model, ffn_hidden, n_head, n_layers, drop_prob)
Bases: Module
Encoder Class
- forward(embedding, text_mask, audio_mask)
- Parameters:
embedding – input tensor
text_mask – mask for text sequence
audio_mask – mask for audio sequence
- training: bool
- class mamkit.modules.transformer.CustomEncoderLayer(d_model, ffn_hidden, n_head, drop_prob)
Bases: Module
Encoder Layer Class
- forward(x, text_mask, audio_mask)
- Parameters:
x – input tensor
text_mask – mask for text sequence
audio_mask – mask for audio sequence
- training: bool
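The encoder layer's parameters (d_model, ffn_hidden, drop_prob) suggest the standard Transformer layer layout. That layout has a self-attention sublayer and a position-wise feed-forward sublayer (d_model to ffn_hidden to d_model). Each sublayer is wrapped in a residual connection followed by layer normalization. The sketch below assumes that post-norm layout; it is not taken from mamkit's source.

Some simplifications apply:
- Dropout and the text/audio masks are omitted.
- The attention sublayer is passed in as a plain function.
- The helper names (add_and_norm, position_wise_ffn, encoder_layer) are hypothetical.
- Plain Python lists are used so the sketch runs without PyTorch.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]  # [length][d_model]


def layer_norm(row: List[float], eps: float = 1e-12) -> List[float]:
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]


def add_and_norm(x: Matrix, sublayer_out: Matrix) -> Matrix:
    # Residual connection followed by layer normalization (post-norm).
    return [layer_norm([a + b for a, b in zip(xr, sr)]) for xr, sr in zip(x, sublayer_out)]


def matvec(v: List[float], w: Matrix) -> List[float]:
    # v has len(w) entries; w is laid out as [in_dim][out_dim].
    return [sum(v[i] * w[i][j] for i in range(len(v))) for j in range(len(w[0]))]


def position_wise_ffn(x: Matrix, w1: Matrix, w2: Matrix) -> Matrix:
    # d_model -> ffn_hidden (ReLU) -> d_model, applied to each position independently.
    return [matvec([max(0.0, h) for h in matvec(row, w1)], w2) for row in x]


def encoder_layer(x: Matrix, attention: Callable[[Matrix], Matrix],
                  w1: Matrix, w2: Matrix) -> Matrix:
    x = add_and_norm(x, attention(x))                     # sublayer 1: self-attention
    return add_and_norm(x, position_wise_ffn(x, w1, w2))  # sublayer 2: feed-forward


# Usage: d_model = 2, ffn_hidden = 3, identity function as placeholder attention.
w1 = [[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]]   # [d_model][ffn_hidden]
w2 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # [ffn_hidden][d_model]
x = [[1.0, 2.0], [3.0, -1.0]]              # [length][d_model]
y = encoder_layer(x, lambda m: m, w1, w2)  # same shape as x
```

Because both sublayers map [length][d_model] to [length][d_model], such layers can be stacked n_layers times, as the n_layers argument of CustomEncoder suggests.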
- class mamkit.modules.transformer.CustomMultiHeadAttention(d_model, n_head)
Bases: Module
Multi-Head Attention Class for Transformer
- concat(tensor)
Inverse function of self.split(tensor: torch.Tensor).
- Parameters:
tensor – [batch_size, head, length, d_tensor]
- forward(q, k, v, text_mask, audio_mask)
- Parameters:
q – query (decoder)
k – key (encoder)
v – value (encoder)
text_mask – mask for text sequence
audio_mask – mask for audio sequence
- split(tensor)
Split the tensor by the number of heads.
- Parameters:
tensor – [batch_size, length, d_model]
- training: bool
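The split and concat docstrings describe a reshape between [batch_size, length, d_model] and [batch_size, head, length, d_tensor]. The sketch below makes two assumptions. First, it uses the usual convention d_tensor = d_model // n_head. Second, head h takes the h-th contiguous slice of each d_model vector. It uses nested Python lists instead of torch.Tensor so the index bookkeeping is explicit. In PyTorch this is typically done with view and transpose, but whether mamkit does exactly that is an assumption.

```python
from typing import List

Tensor3 = List[List[List[float]]]        # [batch_size, length, d_model]
Tensor4 = List[List[List[List[float]]]]  # [batch_size, head, length, d_tensor]


def split(tensor: Tensor3, n_head: int) -> Tensor4:
    d_model = len(tensor[0][0])
    assert d_model % n_head == 0, "d_model must be divisible by n_head"
    d_tensor = d_model // n_head
    # Head h receives slice [h * d_tensor, (h + 1) * d_tensor) of every position.
    return [
        [[row[h * d_tensor:(h + 1) * d_tensor] for row in seq] for h in range(n_head)]
        for seq in tensor
    ]


def concat(tensor: Tensor4) -> Tensor3:
    n_head = len(tensor[0])
    length = len(tensor[0][0])
    # Re-join the per-head slices of each position in head order.
    return [
        [[v for h in range(n_head) for v in batch[h][pos]] for pos in range(length)]
        for batch in tensor
    ]


x = [[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]  # batch_size=1, length=2, d_model=4
heads = split(x, n_head=2)
# heads == [[[[1.0, 2.0], [5.0, 6.0]], [[3.0, 4.0], [7.0, 8.0]]]]
restored = concat(heads)  # equals x: concat undoes split
```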
- class mamkit.modules.transformer.CustomScaleDotProductAttention
Bases: Module
Compute scaled dot-product attention.
Query: the sentence currently being focused on (decoder).
Key: every sentence whose relationship with the query is checked (encoder).
Value: the same sentences as the key (encoder).
- forward(q, k, v, text_mask, audio_mask, e=1e-12)
- Parameters:
q – query (decoder)
k – key (encoder)
v – value (encoder)
text_mask – mask for text sequence
audio_mask – mask for audio sequence
e – epsilon value for masking
- training: bool
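Scaled dot-product attention computes softmax(Q K^T / sqrt(d_k)) V, where d_k is the key dimension. The sketch below shows that computation for a single head on Python lists. How mamkit combines text_mask and audio_mask is not documented here, so the sketch assumes a single boolean key_mask (True = keep). Following a common convention, it fills masked scores with -10000 before the softmax. Both the key_mask parameter and the fill value are assumptions, and the e parameter is not modeled.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]  # [length][dim]


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix,
                                 key_mask: Optional[List[bool]] = None) -> Matrix:
    d_k = len(k[0])
    out = []
    for q_row in q:
        # Similarity of this query with every key, scaled by sqrt(d_k).
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in k]
        if key_mask is not None:
            # Masked keys get a large negative score, so softmax gives them ~0 weight.
            scores = [s if keep else -10000.0 for s, keep in zip(scores, key_mask)]
        weights = softmax(scores)
        # Output row = attention-weighted sum of the value rows.
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
print(scaled_dot_product_attention(q, k, v, key_mask=[True, False]))  # [[10.0, 0.0]]
```

With the second key masked, all attention weight goes to the first key, so the output is exactly the first value row.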
- class mamkit.modules.transformer.LayerNorm(d_model, eps=1e-12)
Bases: Module
Layer Normalization Class
- forward(x)#
- Parameters:
x – input tensor
- training: bool
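Layer normalization rescales each d_model-dimensional vector to zero mean and unit variance. It then applies a learned elementwise scale and shift: y = gamma * (x - mean) / sqrt(var + eps) + beta. The small eps (1e-12 by default here) only guards against division by zero. The sketch below normalizes a single vector in plain Python. The names gamma and beta follow the usual convention and are assumptions about how mamkit names its learned parameters. The sketch uses the biased (population) variance, the standard choice for layer normalization.

```python
import math
from typing import List, Optional


def layer_norm(x: List[float], gamma: Optional[List[float]] = None,
               beta: Optional[List[float]] = None, eps: float = 1e-12) -> List[float]:
    d_model = len(x)
    # Defaults correspond to an untrained layer: scale 1, shift 0.
    gamma = gamma if gamma is not None else [1.0] * d_model
    beta = beta if beta is not None else [0.0] * d_model
    mean = sum(x) / d_model
    var = sum((v - mean) ** 2 for v in x) / d_model  # biased (population) variance
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


y = layer_norm([1.0, 2.0, 3.0, 4.0])  # mean ~0 and variance ~1 after normalization
z = layer_norm([1.0, 2.0], gamma=[2.0, 2.0], beta=[1.0, 1.0])  # scaled and shifted
```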
- class mamkit.modules.transformer.MulTA_CrossAttentionBlock(embedding_dim, d_ffn, num_heads=4, dropout_prob=0.1)
Bases: Module
Class for the cross-modal attention block.
- forward(elem_a, elem_b, attn_mask)
Forward pass of the model.
- Parameters:
elem_a – elements of modality A
elem_b – elements of modality B
attn_mask – attention mask to use
- training: bool
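In a cross-modal attention block, one modality supplies the queries and the other supplies the keys and values. Each element of modality A therefore becomes a weighted mix of modality-B elements. The sketch below shows only that attention core for one head on plain Python lists, under these assumptions:
- elem_a provides the queries, and elem_b provides the keys and values.
- attn_mask[i][j] is True when position i of A may attend to position j of B.

These semantics are not taken from mamkit. The learned projections, multiple heads (num_heads), the feed-forward part (d_ffn), and dropout are omitted.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]  # [length][embedding_dim]


def cross_attention(elem_a: Matrix, elem_b: Matrix,
                    attn_mask: Optional[List[List[bool]]] = None) -> Matrix:
    """Single-head cross attention: modality A attends over modality B."""
    d = len(elem_a[0])
    out = []
    for i, q in enumerate(elem_a):
        scores = [sum(x * y for x, y in zip(q, k)) / math.sqrt(d) for k in elem_b]
        if attn_mask is not None:
            # Disallowed positions get -inf, i.e. zero weight after the softmax.
            # Each row of attn_mask must allow at least one position of B.
            scores = [s if attn_mask[i][j] else float("-inf") for j, s in enumerate(scores)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        # One output row per element of A, built from the elements of B.
        out.append([sum(e / total * v[j] for e, v in zip(exps, elem_b)) for j in range(d)])
    return out


text = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # modality A: 3 positions, embedding_dim=2
audio = [[2.0, 0.0], [0.0, 2.0]]             # modality B: 2 positions
fused = cross_attention(text, audio)         # 3 rows, one per text position
only_first = cross_attention(text, audio, attn_mask=[[True, False]] * 3)
```

The output keeps the length of modality A, which is what lets such a block enrich one modality's sequence with information from the other.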
- class mamkit.modules.transformer.PositionalEncoding(d_model, dual_modality=False, dropout=0.1, max_len=5000)
Bases: Module
Positional Encoding for Transformer
- forward(x, is_first=True)
- Parameters:
x – input tensor (bs, sqlen, emb)
is_first – True for the first modality, False for the second modality
- training: bool
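The max_len=5000 default and the (bs, sqlen, emb) input shape match the sinusoidal positional encoding of the original Transformer, which uses two formulas:
- PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
- PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

Assuming mamkit uses that scheme (not confirmed here), the sketch below precomputes the table and adds its first sqlen rows to every sequence in a batch. Dropout is omitted. The dual_modality and is_first options are not modeled because this reference does not describe their behavior. The function names sinusoidal_table and add_positional_encoding are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]  # [position][d_model]


def sinusoidal_table(max_len: int, d_model: int) -> Matrix:
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):  # i plays the role of 2i in the formula
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:  # guard for odd d_model
                pe[pos][i + 1] = math.cos(angle)
    return pe


def add_positional_encoding(x: List[Matrix], pe: Matrix) -> List[Matrix]:
    # x has shape (bs, sqlen, emb); position `pos` of every sequence gets pe[pos] added.
    return [[[v + p for v, p in zip(tok, pe[pos])] for pos, tok in enumerate(seq)]
            for seq in x]


pe = sinusoidal_table(max_len=16, d_model=4)
batch = [[[0.0] * 4 for _ in range(3)]]  # bs=1, sqlen=3, emb=4, all zeros
out = add_positional_encoding(batch, pe)  # equals [pe[:3]] for an all-zero input
```

Precomputing the table once up to max_len and slicing it per batch avoids recomputing the sines and cosines on every forward call.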