virtex.modules.embedding
- class virtex.modules.embedding.WordAndPositionalEmbedding(vocab_size: int, hidden_size: int, dropout: float = 0.0, max_caption_length: int = 30, padding_idx: int = 0)
Bases: torch.nn.modules.module.Module
A Module for learned word embeddings and position embeddings of input tokens. Each token is mapped to a fixed-dimensional word embedding and to a corresponding positional embedding based on its index. These two embeddings are summed, then passed through layer normalization and an optional dropout.
- Parameters
vocab_size – Size of token vocabulary.
hidden_size – Size of token embedding vectors.
dropout – Probability for final dropout applied after layer normalization.
max_caption_length – Maximum length of input captions; this is used to create a fixed positional embedding lookup table.
padding_idx – Token index of the [PAD] token. The word embedding for these tokens is a vector of zeros (and is not trainable).
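The description above (word lookup plus position lookup, summed, then layer normalization and dropout) can be sketched in PyTorch as follows. This is a minimal illustration of the documented behaviour, not virtex's own source: the class name WordAndPositionalEmbeddingSketch is hypothetical, and using PyTorch's default LayerNorm settings (eps, elementwise affine) is an assumption.

```python
import torch
from torch import nn


class WordAndPositionalEmbeddingSketch(nn.Module):
    """Hypothetical sketch of the documented behaviour, NOT virtex's implementation."""

    def __init__(self, vocab_size: int, hidden_size: int, dropout: float = 0.0,
                 max_caption_length: int = 30, padding_idx: int = 0):
        super().__init__()
        # padding_idx: the [PAD] row is initialized to zeros and receives no gradient.
        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)
        # Fixed-size positional lookup table: one learned vector per position index.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        # Assumption: PyTorch default LayerNorm settings.
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch_size, seq_len) integer ids in [0, vocab_size).
        seq_len = tokens.size(1)
        position_indices = torch.arange(seq_len, device=tokens.device)
        # (batch_size, seq_len, hidden) + (seq_len, hidden) broadcasts over the batch.
        embeddings = self.words(tokens) + self.positions(position_indices)
        # Layer normalization first, then dropout.
        return self.dropout(self.layer_norm(embeddings))


module = WordAndPositionalEmbeddingSketch(vocab_size=100, hidden_size=16)
module.eval()  # disables dropout (a no-op here anyway, since dropout=0.0)
tokens = torch.randint(0, 100, (2, 30))
out = module(tokens)
print(out.shape)  # torch.Size([2, 30, 16])
```

Looking up the positional table once and relying on broadcasting avoids materializing a (batch_size, seq_len) index tensor. It also means every caption in the batch shares the same position vectors.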
- forward(tokens: torch.Tensor) → torch.Tensor
Get combined word and positional embeddings for input tokens.
- Parameters
tokens – A tensor of shape (batch_size, max_caption_length) containing a batch of caption tokens, with values in [0, vocab_size).
- Returns
A tensor of shape (batch_size, max_caption_length, hidden_size) containing the corresponding token embeddings.
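forward expects a rectangular batch of token ids. One way to build such a batch from variable-length captions is to right-pad each caption with padding_idx, as sketched below. The helper pad_captions is hypothetical and not part of virtex. Right-padding, and truncating captions longer than max_caption_length, are assumptions made for this illustration.

```python
import torch


def pad_captions(captions, max_caption_length=30, padding_idx=0):
    """Hypothetical helper: right-pad (or truncate) lists of token ids into a
    LongTensor of shape (batch_size, max_caption_length)."""
    batch = torch.full((len(captions), max_caption_length), padding_idx, dtype=torch.long)
    for i, caption in enumerate(captions):
        ids = caption[:max_caption_length]  # assumption: truncate over-long captions
        batch[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
    return batch


vocab_size = 100
captions = [[5, 17, 42, 3], [5, 9, 3]]
tokens = pad_captions(captions, max_caption_length=6)
# tokens.tolist() == [[5, 17, 42, 3, 0, 0], [5, 9, 3, 0, 0, 0]]

# Every id must fall in [0, vocab_size) before it is passed to forward.
assert 0 <= tokens.min().item() and tokens.max().item() < vocab_size
```

For a module constructed with max_caption_length=6, passing this tensor to forward would produce embeddings of shape (2, 6, hidden_size). The trailing 0 entries map to the all-zero [PAD] word embedding before the positional embedding is added.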