# Imports assumed from earlier in the notebook, repeated here for completeness;
# Anime_D, compute_gradient_penalty, compute_d_loss, compute_g_loss,
# update_average, is_models_equal, AnimeHeadsTrainer and BirdsTrainer come
# from the surrounding library.
from pathlib import Path
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from fastcore.test import test_eq
from fastai.torch_core import to_device

animeheads_data_dir = Path('../data/tiny_data/anime_heads')
birds_data_dir = Path('../data/tiny_data/birds')
# Test for compute_gradient_penalty
d = Anime_D(24)
sizes = (64, 32, 16, 8, 4)
true_imgs = [torch.randn(2, 3, s, s) for s in sizes]
fake_imgs = [torch.randn(2, 3, s, s) for s in sizes]
sent_emb = torch.randn(2, 24)
uncond_gp, cond_gp = compute_gradient_penalty(d, true_imgs, fake_imgs, sent_emb)
test_eq(uncond_gp.shape, ())
test_eq(cond_gp.shape, ())
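# Hedged sketch of the WGAN-GP-style penalty compute_gradient_penalty is
# assumed to implement: interpolate between real and fake samples, then
# penalise the gradient norm of D at the interpolates. `_toy_d` is a toy
# single-scale, unconditional stand-in for Anime_D; all names here are
# illustrative, not the project's API.
def _toy_gradient_penalty(d, real, fake):
    eps = torch.rand(real.size(0), 1, 1, 1)                     # per-sample mix
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    grads, = torch.autograd.grad(d(interp).sum(), interp, create_graph=True)
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()  # scalar

_toy_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))
test_eq(_toy_gradient_penalty(_toy_d, torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)).shape, ())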
# Test for compute_d_loss
d = Anime_D(24)
sizes = (64, 32, 16, 8, 4)
true_imgs = [torch.randn(2, 3, s, s) for s in sizes]
fake_imgs = [torch.randn(2, 3, s, s) for s in sizes]
sent_emb = torch.randn(2, 24)
loss = compute_d_loss(d, true_imgs, fake_imgs, sent_emb, gp_lambda=5)
test_eq(loss.shape, ())
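# Hedged sketch of what compute_d_loss is assumed to combine: a WGAN-style
# critic term plus the weighted gradient penalty, reusing `_toy_d` and
# `_toy_gradient_penalty` from the sketch above. The library's exact
# adversarial formulation may differ.
def _toy_d_loss(d, real, fake, gp_lambda=5):
    adv = d(fake).mean() - d(real).mean()   # critic: score reals above fakes
    return adv + gp_lambda * _toy_gradient_penalty(d, real, fake)

test_eq(_toy_d_loss(_toy_d, torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)).shape, ())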
# Test for compute_g_loss
d_net = Anime_D(24)
sizes = (64, 32, 16, 8, 4)
fake_imgs = [torch.randn(2, 3, s, s) for s in sizes]
sent_emb = torch.randn(2, 24)
cnn_code = torch.randn(2, 24)
word_features = torch.randn(2, 24, 17, 17)
word_emb = torch.randn(2, 20, 24)
cap_len = torch.tensor([4, 6])
loss = compute_g_loss(d_net, fake_imgs, sent_emb, cnn_code, word_features, word_emb, cap_len)
test_eq(loss.shape, ())
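# Hedged sketch of the adversarial part of compute_g_loss, reusing `_toy_d`:
# the generator is rewarded when the critic scores its fakes highly. The
# sentence/word inputs (cnn_code, word_features, word_emb, cap_len) suggest
# additional DAMSM-style text-image matching terms, which are omitted here.
def _toy_g_adv_loss(d, fake):
    return -d(fake).mean()                  # scalar adversarial reward

test_eq(_toy_g_adv_loss(_toy_d, torch.randn(2, 3, 64, 64)).shape, ())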
# Test for update_average: with decay=0 the target becomes an exact copy of the source
src = nn.Sequential(nn.Linear(2, 2, bias=False), nn.BatchNorm1d(2))
tgt = nn.Sequential(nn.Linear(2, 2, bias=False), nn.BatchNorm1d(2))
update_average(tgt, src, decay=0)
test_eq(is_models_equal(tgt, src), True)
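# Hedged sketch of the EMA update update_average is assumed to perform:
# tgt <- decay * tgt + (1 - decay) * src, parameter by parameter; with
# decay=0 the target becomes an exact copy, which is what the test checks.
@torch.no_grad()
def _toy_update_average(tgt, src, decay=0.999):
    for p_t, p_s in zip(tgt.parameters(), src.parameters()):
        p_t.copy_(decay * p_t + (1 - decay) * p_s)
    for b_t, b_s in zip(tgt.buffers(), src.buffers()):
        b_t.copy_(b_s)                      # e.g. BatchNorm running stats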
# Test for after_batch_tfm
trainer = AnimeHeadsTrainer(animeheads_data_dir, 1)
cap, cap_len, img = next(iter(trainer.dls.train))
cap, cap_len, img = to_device([cap, cap_len, img], trainer.device)
cap, cap_len, (img64, img32, img16, img8, img4) = trainer.after_batch_tfm(cap, cap_len, img)
imgs = [img64[0], img32[0], img16[0], img8[0], img4[0]]
test_eq(cap.shape, (1, 2))
test_eq(img4.shape, (1, 3, 4, 4))
test_eq(img8.shape, (1, 3, 8, 8))
test_eq(img16.shape, (1, 3, 16, 16))
test_eq(img32.shape, (1, 3, 32, 32))
test_eq(img64.shape, (1, 3, 64, 64))
# Decode each scale back to image space and plot them side by side
_, axes = plt.subplots(1, 5, figsize=(8, 8))
for i in range(5):
    img = trainer.normalizer.decode(imgs[i][None]).permute(0, 2, 3, 1)[0]
    axes[i].imshow(img.cpu())
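# Hedged sketch of the multi-scale step after_batch_tfm is assumed to perform
# on the image: downsample the batch to every resolution the networks operate
# at. The sizes and interpolation mode here are assumptions.
import torch.nn.functional as F

def _toy_image_pyramid(img, sizes=(64, 32, 16, 8, 4)):
    return [F.interpolate(img, size=s, mode='bilinear', align_corners=False)
            for s in sizes]

test_eq([t.shape[-1] for t in _toy_image_pyramid(torch.randn(1, 3, 64, 64))],
        [64, 32, 16, 8, 4])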
# Test for after_batch_tfm
trainer = BirdsTrainer(birds_data_dir, 1)
cap, cap_len, img = next(iter(trainer.dls.train))
cap, cap_len, img = to_device([cap, cap_len, img], trainer.device)
cap, cap_len, (img256, img128, img64, img32, img16, img8, img4) = trainer.after_batch_tfm(cap, cap_len, img)
imgs = [img256[0], img128[0], img64[0], img32[0], img16[0], img8[0], img4[0]]
test_eq(cap.shape, (1, 25))
test_eq(img4.shape, (1, 3, 4, 4))
test_eq(img8.shape, (1, 3, 8, 8))
test_eq(img16.shape, (1, 3, 16, 16))
test_eq(img32.shape, (1, 3, 32, 32))
test_eq(img64.shape, (1, 3, 64, 64))
test_eq(img128.shape, (1, 3, 128, 128))
test_eq(img256.shape, (1, 3, 256, 256))
# Decode each scale back to image space and plot them side by side
_, axes = plt.subplots(1, 7, figsize=(8, 8))
for i in range(7):
    img = trainer.normalizer.decode(imgs[i][None]).permute(0, 2, 3, 1)[0]
    axes[i].imshow(img.cpu())
# Show a batch of training data
trainer = AnimeHeadsTrainer(animeheads_data_dir, 4)
trainer.show()
# Example training runs (require a CUDA device); uncomment to train
# trainer = AnimeHeadsTrainer(animeheads_data_dir, 4, device='cuda')
# trainer.train(50, step_per_epoch=10, n_gradient_acc=2)
# trainer = BirdsTrainer(birds_data_dir, 4, device='cuda')
# trainer.train(4, step_per_epoch=2, n_gradient_acc=2)