class SoftEmbedding(nn.Module):
    """Word-embedding wrapper that prepends a learned soft prompt.

    The first ``n_tokens`` positions of each input sequence are treated as
    placeholders. They are dropped before the wrapped embedding lookup and
    replaced by a trainable ``(n_tokens, embedding_dim)`` parameter that is
    shared by every row of the batch.
    """

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True):
        """Wrap ``wte`` and create the learned prompt embedding.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            n_tokens (int, optional): number of soft-prompt tokens for the
                task. Defaults to 10.
            random_range (float, optional): half-width of the uniform range
                used to initialize the embedding when not initializing from
                the vocabulary. Defaults to 0.5.
            initialize_from_vocab (bool, optional): initialize from the first
                ``n_tokens`` rows of ``wte``. Defaults to True.

        Raises:
            ValueError: if ``initialize_from_vocab`` is set and ``n_tokens``
                exceeds the number of rows in ``wte``.
        """
        super().__init__()
        # Registered as a submodule; forward() delegates the real-token lookup.
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.Parameter(
            self.initialize_embedding(wte, n_tokens, random_range,
                                      initialize_from_vocab))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True) -> torch.Tensor:
        """Build the initial value of the learned prompt embedding.

        Args:
            wte (nn.Embedding): embedding whose weight supplies the width,
                dtype and device (and, optionally, the initial values).
            n_tokens (int, optional): number of prompt rows. Defaults to 10.
            random_range (float, optional): half-width of the uniform init
                range when ``initialize_from_vocab`` is False. Defaults to 0.5.
            initialize_from_vocab (bool, optional): copy the first
                ``n_tokens`` rows of ``wte.weight`` instead of sampling
                uniformly. Defaults to True.

        Returns:
            torch.Tensor: ``(n_tokens, embedding_dim)`` tensor, detached from
            ``wte``'s autograd graph, with ``wte.weight``'s dtype and device.

        Raises:
            ValueError: if ``initialize_from_vocab`` is set and ``n_tokens``
                exceeds the number of rows in ``wte.weight``.
        """
        weight = wte.weight
        if initialize_from_vocab:
            vocab_size = weight.size(0)
            if n_tokens > vocab_size:
                # Slicing would silently return fewer rows than requested,
                # producing a prompt shorter than the placeholders it replaces.
                raise ValueError(
                    f"n_tokens ({n_tokens}) exceeds the vocabulary size "
                    f"({vocab_size}) available for initialize_from_vocab")
            # detach() first so the copy itself is not tracked by autograd.
            return weight[:n_tokens].detach().clone()
        return torch.empty(
            n_tokens, weight.size(1), dtype=weight.dtype, device=weight.device
        ).uniform_(-random_range, random_range)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens`` with the learned prompt in front.

        Args:
            tokens (torch.long): 2-D ``(batch, seq_len)`` token ids. The first
                ``n_tokens`` columns are placeholders and are discarded.

        Returns:
            torch.float: ``(batch, n_tokens + max(seq_len - n_tokens, 0),
            embedding_dim)`` tensor. It holds the learned task-specific
            embedding followed by the encoding of the remaining tokens.
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        # expand() is a broadcast view (no copy); torch.cat materializes it,
        # and gradients from every batch row accumulate into the parameter.
        learned_embedding = self.learned_embedding.unsqueeze(0).expand(
            input_embedding.size(0), -1, -1)
        return torch.cat([learned_embedding, input_embedding], dim=1)