mamkit.modules package#
Submodules#
mamkit.modules.rnn module#
- class mamkit.modules.rnn.LSTMStack(input_size, lstm_weigths, return_hidden=True)#
Bases: Module
- forward(x)#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
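A minimal usage sketch. That each entry of lstm_weigths gives the hidden size of one stacked LSTM layer, and the output semantics under return_hidden=True, are assumptions read off the signature, not confirmed behaviour:

```python
import torch
from mamkit.modules.rnn import LSTMStack

# Assumption: lstm_weigths lists the hidden size of each stacked LSTM layer.
stack = LSTMStack(input_size=768, lstm_weigths=[128, 64], return_hidden=True)

x = torch.randn(8, 50, 768)   # (batch_size, seq_len, input_size)
out = stack(x)                # assumed: final hidden representation of the stack
```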
mamkit.modules.transformer module#
- class mamkit.modules.transformer.CustomEncoder(d_model, ffn_hidden, n_head, n_layers, drop_prob)#
Bases: Module
Multimodal Transformer encoder composed of n_layers stacked encoder layers.
- forward(embedding, text_mask, audio_mask)#
- Parameters:
embedding – input tensor
text_mask – mask for text sequence
audio_mask – mask for audio sequence
- training: bool#
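A minimal usage sketch of the encoder. The layout of embedding (text and audio tokens concatenated along the sequence dimension) and the mask convention (boolean, True at valid positions) are assumptions:

```python
import torch
from mamkit.modules.transformer import CustomEncoder

encoder = CustomEncoder(d_model=256, ffn_hidden=512, n_head=4,
                        n_layers=2, drop_prob=0.1)

text_len, audio_len = 30, 20
# Assumption: the input concatenates text and audio tokens along the
# sequence dimension; each mask flags the valid positions of its modality.
embedding = torch.randn(8, text_len + audio_len, 256)
text_mask = torch.ones(8, text_len, dtype=torch.bool)
audio_mask = torch.ones(8, audio_len, dtype=torch.bool)

encoded = encoder(embedding, text_mask, audio_mask)
```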
- class mamkit.modules.transformer.CustomEncoderLayer(d_model, ffn_hidden, n_head, drop_prob)#
Bases: Module
Single Transformer encoder layer.
- forward(x, text_mask, audio_mask)#
- Parameters:
x – input tensor
text_mask – mask for text sequence
audio_mask – mask for audio sequence
- training: bool#
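The layer follows the same calling convention as the full encoder; a single-layer sketch under the same (assumed) input layout and mask convention:

```python
import torch
from mamkit.modules.transformer import CustomEncoderLayer

layer = CustomEncoderLayer(d_model=256, ffn_hidden=512, n_head=4, drop_prob=0.1)

x = torch.randn(8, 50, 256)  # assumed: 30 text + 20 audio positions concatenated
out = layer(x,
            torch.ones(8, 30, dtype=torch.bool),   # text_mask
            torch.ones(8, 20, dtype=torch.bool))   # audio_mask
```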
- class mamkit.modules.transformer.CustomMultiHeadAttention(d_model, n_head)#
Bases: Module
Multi-head attention for the Transformer.
- concat(tensor)#
Inverse of self.split(tensor).
- Parameters:
tensor – [batch_size, head, length, d_tensor]
- forward(q, k, v, text_mask, audio_mask)#
- Parameters:
q – query (decoder)
k – key (encoder)
v – value (encoder)
text_mask – mask for text sequence
audio_mask – mask for audio sequence
- split(tensor)#
Split the tensor by the number of attention heads, producing [batch_size, head, length, d_tensor].
- Parameters:
tensor – [batch_size, length, d_model]
- training: bool#
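The documented shapes imply the usual head split/merge round trip; a sketch, where d_tensor = d_model // n_head is an assumption inferred from the shapes rather than stated by the docs:

```python
import torch
from mamkit.modules.transformer import CustomMultiHeadAttention

mha = CustomMultiHeadAttention(d_model=256, n_head=4)

x = torch.randn(8, 50, 256)   # (batch_size, length, d_model)
heads = mha.split(x)          # (batch_size, head, length, d_tensor), d_tensor = 256 // 4 (assumed)
restored = mha.concat(heads)  # concat is documented as the inverse of split
assert restored.shape == x.shape
```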
- class mamkit.modules.transformer.CustomScaleDotProductAttention#
Bases: Module
Compute scaled dot-product attention.
Query: the sequence we currently focus on (decoder)
Key: every sequence against which the query's relationship is checked (encoder)
Value: the same sequence as Key (encoder)
- forward(q, k, v, text_mask, audio_mask, e=1e-12)#
- Parameters:
q – query (decoder)
k – key (encoder)
v – value (encoder)
text_mask – mask for text sequence
audio_mask – mask for audio sequence
e – epsilon value for masking
- training: bool#
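For reference, a generic PyTorch sketch of scaled dot-product attention with a single mask. How the class above combines the separate text and audio masks (and how it uses e) is not reproduced here:

```python
import math
import torch

def sdpa_sketch(q, k, v, mask=None):
    # score_ij = (q_i · k_j) / sqrt(d_k): scaled similarity of query i and key j
    d_k = q.size(-1)
    score = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (batch, head, len_q, len_k)
    if mask is not None:
        # hide padded positions before the softmax (mask convention: True = keep)
        score = score.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(score, dim=-1)
    return attn @ v, attn
```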
- class mamkit.modules.transformer.LayerNorm(d_model, eps=1e-12)#
Bases: Module
Layer normalization module.
- forward(x)#
- Parameters:
x – input tensor
- training: bool#
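For reference, what a from-scratch layer normalization with this signature typically computes: normalize over the last (feature) dimension, then apply a learnable gain and bias. Parameter names here are illustrative; the internals of the class above are not confirmed:

```python
import torch
import torch.nn as nn

class LayerNormSketch(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable gain
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable bias
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
```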
- class mamkit.modules.transformer.MulTA_CrossAttentionBlock(embedding_dim, d_ffn, num_heads=4, dropout_prob=0.1)#
Bases: Module
Cross-modal attention block, in which one modality attends to the other.
- forward(elem_a, elem_b, attn_mask)#
Forward pass of the model.
- Parameters:
elem_a – elements of modality A
elem_b – elements of modality B
attn_mask – attention mask to use
- training: bool#
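A minimal usage sketch in which modality A attends over modality B; the direction of attention and the mask convention (here, flagging positions of modality B) are assumptions:

```python
import torch
from mamkit.modules.transformer import MulTA_CrossAttentionBlock

block = MulTA_CrossAttentionBlock(embedding_dim=256, d_ffn=512,
                                  num_heads=4, dropout_prob=0.1)

elem_a = torch.randn(8, 30, 256)  # modality A (e.g., text)
elem_b = torch.randn(8, 20, 256)  # modality B (e.g., audio)
# Assumption: attn_mask refers to positions of modality B, the attended modality.
attn_mask = torch.zeros(8, 20, dtype=torch.bool)

out = block(elem_a, elem_b, attn_mask)
```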
- class mamkit.modules.transformer.PositionalEncoding(d_model, dual_modality=False, dropout=0.1, max_len=5000)#
Bases: Module
Positional Encoding for Transformer
- forward(x, is_first=True)#
- Parameters:
x – input tensor of shape (batch_size, seq_len, embedding_dim)
is_first – True for the first modality, False for the second
- training: bool#
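A minimal usage sketch for the dual-modality case. That is_first selects which modality's positional offsets are applied when dual_modality=True is an assumption read off the two signatures:

```python
import torch
from mamkit.modules.transformer import PositionalEncoding

pos_enc = PositionalEncoding(d_model=256, dual_modality=True,
                             dropout=0.1, max_len=5000)

text = torch.randn(8, 30, 256)    # (batch_size, seq_len, embedding_dim)
audio = torch.randn(8, 20, 256)   # (batch_size, seq_len, embedding_dim)

text = pos_enc(text, is_first=True)     # first modality
audio = pos_enc(audio, is_first=False)  # second modality
```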