Language Model or Decoder with Generate function

FakeDecoder and FakeLM for testing

# Assumed imports for the examples on this page (torch, and test_eq from fastcore).
import torch
from torch import nn
from fastcore.test import test_eq

class FakeDecoder(nn.Module):
    ''' Decoder that takes encoder memory; does not support past caching. '''
    def __init__(self, tgt_vocab_size):
        super().__init__()
        self.tgt_vocab_size = tgt_vocab_size
    def forward(self, tgt, memory, **kwargs):
        ''' 
        inputs: (tgt, memory)
            tgt: (b, tgt_seq_len)
            memory: (b, src_seq_len, embed_dim)
        returns: logits, others
            logits: (b, tgt_seq_len, tgt_vocab_size)
        '''
        assert tgt.shape[0] == memory.shape[0], (tgt.shape[0], memory.shape[0])
        logits = torch.randn((*tgt.shape, self.tgt_vocab_size))
        return logits, None
    
class FakeLM(nn.Module):
    ''' Language model without memory; supports past caching. '''
    def __init__(self, tgt_vocab_size):
        super().__init__()
        self.tgt_vocab_size = tgt_vocab_size
    def forward(self, tgt, past=None, **kwargs):
        ''' 
        if past==None:
            inputs: (tgt)
                tgt: (b, tgt_seq_len)
            returns: logits, presents, others
                logits: (b, tgt_seq_len, tgt_vocab_size)
                presents: List of (2, b, ...)
        else:
            inputs: (tgt, past)
                tgt: (b, 1)   
                past: List of (2, b, num_heads, tgt_seq_len-1, ...)
            returns: logits, presents, others
                logits: (b, tgt_seq_len, tgt_vocab_size)
                presents: List of (2, b, num_heads, tgt_seq_len, ...)
        '''
        if past is None:
            b = tgt.shape[0]
            tgt_seq_len = tgt.shape[1]
            logits = torch.randn((b, tgt_seq_len, self.tgt_vocab_size))
            presents = [torch.randn((2, b, 12, tgt_seq_len, 16))] * 6  # 6 fake layers of (2, b, num_heads, seq_len, head_dim)
        else:
            b = tgt.shape[0]
            tgt_seq_len = past[0].shape[3] + 1  # cached length plus the one new token
            logits = torch.randn((b, tgt_seq_len, self.tgt_vocab_size))
            presents = [torch.randn((2, b, 12, tgt_seq_len, 16))] * 6
        return logits, presents, None
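
To make the past/presents contract above concrete, here is a small illustrative decoding loop (not part of the library): the first call passes the whole prefix, and each later call passes only the newest token together with the cached presents.

# Illustrative only: how an incremental decoding step consumes past/presents.
lm_demo = FakeLM(tgt_vocab_size=20)
prefix = torch.randint(0, 20, (3, 5))                      # (b, tgt_seq_len)
logits, presents, _ = lm_demo(prefix)                      # first step: no past
next_token = logits[:, -1, :].argmax(-1, keepdim=True)     # greedy pick of the next token, (b, 1)
logits, presents, _ = lm_demo(next_token, past=presents)   # later steps: new token + cache
assert presents[0].shape[3] == 6                           # cache now covers 5 + 1 positions
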
bs = 3
tgt_seq_len = 10
tgt_vocab_size = 20
memory = torch.randn((bs, 9, 9))
past = [torch.randn((2, bs, 12, tgt_seq_len-1, 16))] * 6
pad_token_id=0
eos_token_id=tgt_vocab_size-1
bos_token_id=tgt_vocab_size-2
decoder = FakeDecoder(tgt_vocab_size)
lm = FakeLM(tgt_vocab_size)

tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
test_eq(decoder(tgt, memory)[0].shape, (bs, tgt_seq_len, tgt_vocab_size))
test_eq(lm(tgt)[0].shape, (bs, tgt_seq_len, tgt_vocab_size))
test_eq(lm(tgt)[1][0].shape, (2, bs, 12, tgt_seq_len, 16))

tgt = torch.randint(0, tgt_vocab_size-2, (bs, 1))  # with past, only the newest token is fed
test_eq(lm(tgt, past)[0].shape, (bs, tgt_seq_len, tgt_vocab_size))
test_eq(lm(tgt, past)[1][0].shape, (2, bs, 12, tgt_seq_len, 16))

GeneratedLM

class GeneratedLM[source]

GeneratedLM(lm, vocab_size, pad_token_id, eos_token_ids, support_past=False)

max_length=20
generate_args = dict(   
    max_length=max_length,
    do_sample=True,
    temperature=0.1,
    top_k=3,
    top_p=0.5,
    repetition_penalty=1.2,
)

# With memory, Without past
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
result = generated_decoder._generate_no_beam_search(tgt, **generate_args, model_otherargs=[memory])
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1

# Without memory, With past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_no_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len

# Without memory, Without past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_no_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len

build_model_otherargs_for_beam

GeneratedLM.build_model_otherargs_for_beam[source]

GeneratedLM.build_model_otherargs_for_beam(model_otherargs, num_beams)

model_otherargs: list of tensors with shape (bs, ...). Returns the list of expanded args with shape (bs*num_beams, ...), where each batch row is repeated num_beams times.

model_otherargs = [torch.tensor([[1, 2, 3], 
                                 [4, 5, 6]])]
expected = [torch.tensor([[1, 2, 3],
                          [1, 2, 3],
                          [4, 5, 6],
                          [4, 5, 6]])]
result = build_model_otherargs_for_beam(None, model_otherargs, 2)
test_eq(result, expected)
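
For reference, the same row-wise expansion can be written directly in PyTorch with repeat_interleave; this is only a sketch of the expected behaviour, not necessarily how the method is implemented.

# Sketch: each batch row is repeated num_beams times along dim 0.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
test_eq([x.repeat_interleave(2, dim=0)], expected)
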
max_length=20
generate_args = dict(   
    max_length=max_length,
    do_sample=True,
    temperature=0.1,
    top_k=3,
    top_p=0.5,
    repetition_penalty=1.2,
    length_penalty=1,
    num_beams=4,
    vocab_size=tgt_vocab_size,
)

# With memory, Without past
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
model_otherargs = generated_decoder.build_model_otherargs_for_beam([memory], generate_args['num_beams'])
result = generated_decoder._generate_beam_search(tgt, **generate_args, model_otherargs=model_otherargs)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1

# Without memory, With past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len

# Without memory, Without past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm._generate_beam_search(tgt, **generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len

GenerateArgs

class GenerateArgs[source]

GenerateArgs(max_length:int=20, do_sample:bool=False, num_beams:int=1, temperature:float=1.0, top_k:int=1, top_p:float=1.0, repetition_penalty:float=1.0, length_penalty:float=1.0)


generate

generate[source]

generate(tgt, generate_args:GenerateArgs=GenerateArgs(max_length=20, do_sample=False, num_beams=1, temperature=1.0, top_k=1, top_p=1.0, repetition_penalty=1.0, length_penalty=1.0), model_otherargs=[], model_otherkwargs={})

tgt: (b, tgt_seq_len). model_otherargs: other positional args that your model needs, e.g. memory from an encoder. model_otherkwargs: other keyword args that your model needs, e.g. attention masks. returns: (b, tgt_seq_len <= ? <= max_length)

generate_args = GenerateArgs( 
    do_sample=True, 
    num_beams=1, 
)

generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
result = generated_decoder.generate(tgt, generate_args, model_otherargs=[memory])
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1

generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len

generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
generate_args = GenerateArgs(
    do_sample=True, 
    num_beams=2, 
)

# With memory, Without past
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
model_otherargs = generated_decoder.build_model_otherargs_for_beam([memory], generate_args.num_beams)
result = generated_decoder.generate(tgt, generate_args, model_otherargs=model_otherargs)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= 1


# Without memory, With past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=True)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len


# Without memory, Without past
generated_lm = GeneratedLM(lm, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.randint(0, tgt_vocab_size-2, (bs, tgt_seq_len))
result = generated_lm.generate(tgt, generate_args)
test_eq(result.shape[0], bs)
assert result.shape[1] <= max_length and result.shape[1] >= tgt_seq_len
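
None of the examples above passes model_otherkwargs. Since FakeDecoder accepts arbitrary keyword arguments, a minimal sketch looks like the following, assuming generate forwards model_otherkwargs to the wrapped model; the tgt_mask name is made up for illustration.

# Sketch only: extra keyword args (e.g. masks) handed to the model via model_otherkwargs.
kw_generate_args = GenerateArgs(do_sample=True, num_beams=1)
generated_decoder = GeneratedLM(decoder, tgt_vocab_size, pad_token_id, [eos_token_id], support_past=False)
tgt = torch.zeros((bs, 1), dtype=torch.long).fill_(bos_token_id)
result = generated_decoder.generate(tgt, kw_generate_args, model_otherargs=[memory],
                                    model_otherkwargs={'tgt_mask': None})  # hypothetical kwarg name
test_eq(result.shape[0], bs)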

Test

Test that with do_sample=False, GeneratedLM.generate returns the same result as Hugging Face's PreTrainedModel.generate.

# slow
from transformers import AutoModelWithLMHead, AutoTokenizer
gpt2_lm = AutoModelWithLMHead.from_pretrained('distilgpt2')
gpt2_lm.eval()
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
from dataclasses import asdict  # used below to turn GenerateArgs into keyword arguments
from fastai_transformers_utils.all import *
# slow
sentence = 'The dog is a'
tgt = torch.tensor([tokenizer.encode(sentence)])

generate_args = GenerateArgs(   
    max_length=20,
    do_sample=False,
    num_beams=1,
    temperature=1.0,
    repetition_penalty=1,
    length_penalty=1.0,
)
generated_lm = GeneratedLM(gpt2_lm, tokenizer.vocab_size, gpt2_lm.config.pad_token_id, [gpt2_lm.config.eos_token_ids], True)
numeric_result = generated_lm.generate(tgt, generate_args)
result = tokenizer.decode(list(numeric_result[0]))

huggingface_numeric_result = gpt2_lm.generate(tgt, **asdict(generate_args))
huggingface_result = tokenizer.decode(list(huggingface_numeric_result[0]))

test_eq(result, huggingface_result)