# Path where the exported GAN model is expected to live (used by the
# commented-out `from_pretrained` calls further down).
exported_model_path = Path('../models/large_anime_heads/gan/gan_export.pt')

# Fake generator output: tanh squashes values into [-1, 1], matching the
# range the real GAN produces.
tensor_img = torch.tanh(torch.randn(2, 3, 64, 64))
img = decode_img(tensor_img)

# Every decoded pixel must land in the displayable 0..255 range; the decoded
# layout is (batch, H, W, C). Each position scores 2 only when both bounds hold.
in_bounds = (img >= 0).long() + (img <= 255).long()
test_eq(in_bounds, torch.full((2, 64, 64, 3), 2).long())
# Forward-pass smoke test on an un-trained export wrapper.
model = Anime_Export(len(tokenizer.vocab), 24, tokenizer.pad_id)

# Random token ids standing in for a tokenized caption: batch of 2,
# sequence length 20, with explicit per-sample lengths.
caption_ids = torch.randint(0, 10, (2, 20))
caption_lengths = torch.tensor([2, 2])
_ = model(caption_ids, caption_lengths)

# After a forward pass the stored attention weights should give one
# 32x32 map per caption token: (batch, seq_len, 32, 32).
attn_w = get_attn_w(model)
test_eq(attn_w.shape, (2, 20, 32, 32))

`Anime_Export.predict` [source]

Anime_Export.predict(cap)

Given a caption such as `'white hair yellow eyes'`, returns `img` with shape `(64, 64, 3)` and `attn_w` with shape `(2, 32, 32)`.

# Round-trip a raw caption string through predict() and visualize the result.
model = Anime_Export(len(tokenizer.vocab), 24, tokenizer.pad_id)
# model = Anime_Export.from_pretrained(exported_model_path)
cap = 'white hair yellow eyes'

prediction, attention = model.predict(cap)
# Image comes back as a single (H, W, C) frame; attention as (2, 32, 32).
test_eq(prediction.shape, (64, 64, 3))
test_eq(attention.shape, (2, 32, 32))

show_pred(prediction, cap)
show_pred_withattn(prediction, attention, cap)

`Anime_Export.pred_and_show` [source]

`Anime_Export.pred_and_show(cap, with_attn=False)`

cap: 'white hair yellow eyes'

# One-call demo: caption in, rendered prediction out.
cap = 'blonde hair purple eyes'
model = Anime_Export(len(tokenizer.vocab), 24, tokenizer.pad_id)
# To use trained weights instead of a fresh model:
# model = Anime_Export.from_pretrained(exported_model_path)
model.pred_and_show(cap)