# Location of the exported GAN weights. The path is relative, so it resolves
# against the current working directory when it is used.
exported_model_path = Path('../models') / 'large_birds' / 'gan' / 'gan_export.pt'
# decode_img is fed a tanh-squashed batch of shape (2, 3, 64, 64) and is expected
# to return a batch of shape (2, 64, 64, 3) whose values all lie in 0~255.
tensor_img = torch.tanh(torch.randn(2, 3, 64, 64))
img = decode_img(tensor_img)
# Check shape and range separately so a failure says which property broke.
# The previous form compared a summed 0/1/2 mask against a tensor of 2s.
test_eq(img.shape, (2, 64, 64, 3))
assert ((img >= 0) & (img <= 255)).all(), 'decoded img has values outside 0~255'
# Run one forward pass on a dummy batch: 2 captions of 25 random token ids
# drawn from [0, 4). Then check the attention weights get_attn_w extracts from
# the model (presumably those cached by that pass, so the order matters).
model = Birds_Export(tokenizer.vocab_sz, 24, tokenizer.pad_id)
cap = torch.randint(low=0, high=4, size=(2, 25))
cap_len = torch.tensor((2, 3))  # presumably the valid length of each caption
_ = model(cap, cap_len)
attn_w = get_attn_w(model)
test_eq(tuple(attn_w.shape), (2, 25, 128, 128))
# Build a Birds_Export with the same constructor args as the attention check.
# The pretrained load below is commented out, so the weights are presumably
# untrained. This cell only asserts output shapes and runs the display helpers.
model = Birds_Export(tokenizer.vocab_sz, 24, tokenizer.pad_id)
# NOTE(review): swap in this line to test the exported weights once they exist.
# model = Birds_Export.from_pretrained(exported_model_path)
cap = 'a bird with a grey body and yellow feathers on its belly with a longish beak'
# predict takes a raw caption string and returns (image, attention weights).
img, attn_w = model.predict(cap)
test_eq(img.shape, (256, 256, 3))  # a single image, presumably (H, W, C)
# presumably one 128x128 attention map per token position, up to max_seq_len
test_eq(attn_w.shape, (tokenizer.max_seq_len, 128, 128))
show_pred(img, cap)
show_pred_withattn(img, attn_w, cap)
# Same model construction and caption as the predict check above, but this
# cell exercises the single pred_and_show entry point. Judging by the name, it
# predicts and displays in one call (TODO confirm); nothing is asserted here.
model = Birds_Export(tokenizer.vocab_sz, 24, tokenizer.pad_id)
# NOTE(review): swap in this line to test the exported weights once they exist.
# model = Birds_Export.from_pretrained(exported_model_path)
cap = 'a bird with a grey body and yellow feathers on its belly with a longish beak'
model.pred_and_show(cap)