Language Model or Decoder with a Generate Function
FakeDecoder and FakeLM for Testing¶
import torch
from torch import nn
from fastcore.test import test_eq

class FakeDecoder(nn.Module):
    '''Takes a memory input; does not support past.'''
def __init__(self, tgt_vocab_size):
super().__init__()
self.tgt_vocab_size = tgt_vocab_size
def forward(self, tgt, memory, **kwargs):
'''
inputs: (tgt, memory)
tgt: (b, tgt_seq_len)
memory: (b, src_seq_len, embed_dim)
returns: logits, others
logits: (b, tgt_seq_len, tgt_vocab_size)
'''
assert tgt.shape[0] == memory.shape[0], (tgt.shape[0], memory.shape[0])
logits = torch.randn((*tgt.shape, self.tgt_vocab_size))
return logits, None
class FakeLM(nn.Module):
    '''Takes no memory input; supports past.'''
def __init__(self, tgt_vocab_size):
super().__init__()
self.tgt_vocab_size = tgt_vocab_size
def forward(self, tgt, past=None, **kwargs):
'''
        if past is None:
inputs: (tgt)
tgt: (b, tgt_seq_len)
returns: logits, presents, others
logits: (b, tgt_seq_len, tgt_vocab_size)
presents: List of (2, b, ...)
else:
inputs: (tgt, past)
tgt: (b, 1)
past: List of (2, bs, num_heads, tgt_seq_len-1, ..)
returns: logits, presents, others
logits: (b, tgt_seq_len, tgt_vocab_size)
presents: List of (2, bs, num_heads, tgt_seq_len, ..)
'''
if past is None:
b = tgt.shape[0]
tgt_seq_len = tgt.shape[1]
logits = torch.randn((b, tgt_seq_len, self.tgt_vocab_size))
presents = [torch.randn((2, b, 12, tgt_seq_len, 16))] * 6
else:
b = tgt.shape[0]
tgt_seq_len = past[0].shape[3]+1
logits = torch.randn((b, tgt_seq_len, self.tgt_vocab_size))
presents = [torch.randn((2, b, 12, tgt_seq_len, 16))] * 6
return logits, presents, None
bs = 3
tgt_seq_len = 10
tgt_vocab_size = 20
memory = torch.randn((bs, 9, 9))
past = [torch.randn((2, bs, 12, tgt_seq_len-1, 16))] * 6
pad_token_id=0
eos_token_id=tgt_vocab_size-1
bos_token_id=tgt_vocab_size-2
decoder = FakeDecoder(tgt_vocab_size)
lm = FakeLM(tgt_vocab_size)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
test_eq(decoder(tgt, memory)[0].shape, (bs, tgt_seq_len, tgt_vocab_size))
test_eq(lm(tgt)[0].shape, (bs, tgt_seq_len, tgt_vocab_size))
test_eq(lm(tgt)[1][0].shape, (2, bs, 12, tgt_seq_len, 16))
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
test_eq(lm(tgt, past)[0].shape, (bs, tgt_seq_len, tgt_vocab_size))
test_eq(lm(tgt, past)[1][0].shape, (2, bs, 12, tgt_seq_len, 16))
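A quick sketch (not part of the library) of the calling convention the FakeLM docstring describes: the first call runs over the whole prefix and returns presents, and each later call feeds only the newest token together with the cached past.
prefix = torch.randint(0, tgt_vocab_size-2, (bs, 4))
logits, presents, _ = lm(prefix)                        # full forward pass builds the cache
next_token = logits[:, -1, :].argmax(-1, keepdim=True)  # greedy pick of the next token, shape (bs, 1)
logits, presents, _ = lm(next_token, past=presents)     # incremental forward pass with the cached past
test_eq(presents[0].shape, (2, bs, 12, 5, 16))          # cache grew from 4 to 5 positions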
GeneratedLM¶
_generate_no_beam_search¶
max_length=20
generate_args = dict(
max_length=max_length,
do_sample=True,
temperature=0.1,
top_k=3,
top_p=0.5,
repetition_penalty=1.2,
)
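# Illustration only (not the library's internal code): with do_sample=True the settings
# above shape the next-token distribution roughly like this: temperature rescales the
# logits, then top_k keeps only the k largest of them before sampling.
example_logits = torch.tensor([2.0, 1.0, 0.5, -1.0]) / 0.1   # temperature=0.1 sharpens the distribution
topk_vals, topk_idx = example_logits.topk(3)                 # top_k=3
filtered = torch.full_like(example_logits, float('-inf'))
filtered[topk_idx] = topk_vals
probs = filtered.softmax(-1)                                 # only the top 3 tokens keep probability mass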
# With memory, Without past
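# GeneratedLM wraps a model with generation utilities; the arguments below are
# (model, vocab_size, pad_token_id, list of eos token ids) plus support_past.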
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
result = generated_decoder._generate_no_beam_search(tgt, **generate_args, model_otherargs=[memory])
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1
# Without memory, With past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_no_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
# Without memory, Without past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_no_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
build_model_otherargs_for_beam¶
model_otherargs = [torch.tensor([[1, 2, 3],
[4, 5, 6]])]
expected = [torch.tensor([[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6]])]
result = build_model_otherargs_for_beam(None, model_otherargs, 2)
test_eq(result, expected)
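The expansion above repeats each batch row num_beams times along the batch dimension; a one-line check with plain torch (a sketch of the effect, not necessarily how the function is implemented):
test_eq(torch.repeat_interleave(model_otherargs[0], 2, dim=0), expected[0])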
_generate_beam_search¶
max_length=20
generate_args = dict(
max_length=max_length,
do_sample=True,
temperature=0.1,
top_k=3,
top_p=0.5,
repetition_penalty=1.2,
length_penalty=1,
num_beams=4,
vocab_size=tgt_vocab_size,
)
# With memory, Without past
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
model_otherargs = generated_decoder.build_model_otherargs_for_beam([memory], generate_args['num_beams'])
result = generated_decoder._generate_beam_search(tgt, **generate_args, model_otherargs=model_otherargs)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1
# Without memory, With past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
# Without memory, Without past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
GenerateArgs¶
generate¶
generate_args = GenerateArgs(
do_sample=True,
num_beams=1,
)
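# With num_beams=1 and do_sample=True, generate presumably dispatches to the
# no-beam (sampling) search path exercised above; fields not set here keep
# their GenerateArgs defaults.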
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
result = generated_decoder.generate(tgt, generate_args, model_otherargs=[memory])
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
generate_args = GenerateArgs(
do_sample=True,
num_beams=2,
)
# With memory, Without past
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
model_otherargs = generated_decoder.build_model_otherargs_for_beam([memory], generate_args.num_beams)
result = generated_decoder.generate(tgt, generate_args, model_otherargs=model_otherargs)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1
# Without memory, With past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
# Without memory, Without past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
Test¶
Test that, with do_sample=False, GeneratedLM.generate returns the same result as Hugging Face's PreTrainedModel.generate.
# slow
from transformers import AutoModelWithLMHead, AutoTokenizer
gpt2_lm = AutoModelWithLMHead.from_pretrained('distilgpt2')
gpt2_lm.eval()
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
from fastai_transformers_utils.all import *
# slow
sentence = 'The dog is a'
tgt = torch.tensor([tokenizer.encode(sentence)])
generate_args = GenerateArgs(
max_length=20,
do_sample=False,
num_beams=1,
temperature=1.0,
repetition_penalty=1,
length_penalty=1.0,
)
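# do_sample=False with num_beams=1 means greedy decoding, so both implementations
# should pick exactly the same tokens and the decoded strings can be compared.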
generated_lm = GeneratedLM(gpt2_lm, tokenizer.vocab_size, gpt2_lm.config.pad_token_id, [gpt2_lm.config.eos_token_ids], True)
numeric_result = generated_lm.generate(tgt, generate_args)
result = tokenizer.decode(list(numeric_result[0]))
from dataclasses import asdict
huggingface_numeric_result = gpt2_lm.generate(tgt, **asdict(generate_args))
huggingface_result = tokenizer.decode(list(huggingface_numeric_result[0]))
test_eq(result, huggingface_result)