# Root directories of the tiny demo datasets used throughout this script.
tiny_data_root = Path('../data/tiny_data')
animeheads_data_dir = tiny_data_root / 'anime_heads'
birds_data_dir = tiny_data_root / 'birds'
# Code adapted from https://github.com/taoxugit/AttnGAN
# Smoke test: sentence-level DAMSM loss on random image/sentence embeddings.
cnn_code, sent_emb = torch.randn((2, 24)), torch.randn((2, 24))
sent_loss_0, sent_loss_1 = compute_sent_loss(cnn_code, sent_emb)
sent_loss_0, sent_loss_1  # display both loss terms
# Smoke test: word-level DAMSM attention loss on random feature maps and
# word embeddings, with per-sample caption lengths.
word_features = torch.randn((2, 24, 17, 17))
word_emb = torch.randn((2, 20, 24))
cap_len = torch.tensor([4, 6])  # number of valid words per caption
word_loss_0, word_loss_1, attn_maps = compute_word_loss(
    word_features, word_emb, cap_len, gamma2=5.0, gamma3=10.0)
word_loss_0, word_loss_1, attn_maps[0].shape, attn_maps[1].shape  # display
# Sanity-check one training batch from a 2-sample AnimeHeads trainer:
# shapes after the batch transform, then a visual check of the images.
trainer = AnimeHeadsTrainer(animeheads_data_dir, 2)
# Use the builtin next(): the Python-2-style iterator.next() method does not
# exist on Python 3 / modern PyTorch DataLoader iterators.
cap, cap_len, img = next(iter(trainer.dls.train))
cap, cap_len, img = to_device([cap, cap_len, img], trainer.device)
cap, cap_len, img = trainer.after_batch_tfm(cap, cap_len, img)
test_eq(cap.shape, (2, trainer.max_seq_len))
test_eq(cap_len.shape, (2,))
test_eq(img.shape, (2, 3, 229, 229))
# Denormalize and show both images (decode expects a batch dim, hence [None]).
_, axes = plt.subplots(1, 2, figsize=(8, 8))
for i in range(2):
    tmp = trainer.normalizer.decode(img[i][None]).permute(0, 2, 3, 1)[0]
    axes[i].imshow(tmp.cpu())
# Same sanity check for the Birds dataset: one transformed training batch,
# shape assertions, then a visual check of the images.
trainer = BirdsTrainer(birds_data_dir, 2)
# Use the builtin next(): the Python-2-style iterator.next() method does not
# exist on Python 3 / modern PyTorch DataLoader iterators.
cap, cap_len, img = next(iter(trainer.dls.train))
cap, cap_len, img = to_device([cap, cap_len, img], trainer.device)
cap, cap_len, img = trainer.after_batch_tfm(cap, cap_len, img)
test_eq(cap.shape, (2, trainer.max_seq_len))
test_eq(cap_len.shape, (2,))
test_eq(img.shape, (2, 3, 229, 229))
# Denormalize and show both images (decode expects a batch dim, hence [None]).
_, axes = plt.subplots(1, 2, figsize=(8, 8))
for i in range(2):
    tmp = trainer.normalizer.decode(img[i][None]).permute(0, 2, 3, 1)[0]
    axes[i].imshow(tmp.cpu())
# Visual spot-check: build a batch-size-1 trainer for each dataset and call
# its show() helper (presumably renders a sample batch — confirm in trainer class).
trainer = AnimeHeadsTrainer(animeheads_data_dir, 1)
trainer.show()
trainer = BirdsTrainer(birds_data_dir, 1)
trainer.show()
# Full training runs (require a CUDA device); uncomment to execute:
# trainer = AnimeHeadsTrainer(animeheads_data_dir, 4, device='cuda')
# trainer.train(8, step_per_epoch=2)
# trainer = BirdsTrainer(birds_data_dir, 4, device='cuda')
# trainer.train(4, step_per_epoch=2)