# Locations of the tiny demo datasets used throughout this page.
_tiny_root = Path('../data/tiny_data')
animeheads_data_dir = _tiny_root / 'anime_heads'
birds_data_dir = _tiny_root / 'birds'

Loss

compute_sent_loss[source]

compute_sent_loss(cnn_code, sent_emb, gamma3=10.0, eps=1e-08)

cnn_code: (bs, emb_sz), sent_emb: (bs, emb_sz)

# Example: run the sentence-level loss on random embeddings of matching shape.
cnn_code = torch.randn(2, 24)  # stand-in image embedding: (bs=2, emb_sz=24)
sent_emb = torch.randn(2, 24)  # stand-in sentence embedding: (bs=2, emb_sz=24)
sent_loss_0, sent_loss_1 = compute_sent_loss(cnn_code, sent_emb)
# Echo both loss terms (values below depend on the random inputs).
sent_loss_0, sent_loss_1
(tensor(1.7161), tensor(1.2727))

compute_word_loss[source]

compute_word_loss(img_features, words_emb, cap_len, gamma1=4.0, gamma2=5.0, gamma3=10.0)

img_features (context): batch x emb_sz x 17 x 17; words_emb (query): batch x seq_len x emb_sz

# Example: word-level loss on random features; cap_len gives the true caption
# length per sample, so each attention map keeps only that many words.
word_features = torch.randn(2, 24, 17, 17)  # (batch, emb_sz, 17, 17) image context
word_emb = torch.randn(2, 20, 24)           # (batch, seq_len=20, emb_sz) word queries
cap_len = torch.tensor([4, 6])              # actual caption lengths per sample
word_loss_0, word_loss_1, attn_maps = compute_word_loss(word_features, word_emb, cap_len, gamma2=5.0, gamma3=10.0)
# attn_maps[i] is (cap_len[i], 17, 17) — one spatial map per real word.
word_loss_0, word_loss_1, attn_maps[0].shape, attn_maps[1].shape
(tensor(1.9601),
 tensor(1.3081),
 torch.Size([4, 17, 17]),
 torch.Size([6, 17, 17]))

AnimeHeadsTrainer

class AnimeHeadsTrainer[source]

AnimeHeadsTrainer(data_dir, bs, data_pct=1, lr=0.003, lr_decay=0.98, device='cpu', emb_sz=24, rnn_layers=2, rnn_drop_p=0.5, gamma1=4.0, gamma2=5.0, gamma3=10.0)

# Build a trainer on the tiny anime-heads dataset with batch size 2.
trainer = AnimeHeadsTrainer(animeheads_data_dir, 2)

# Grab one training batch. Use the builtin next(): iterator objects have no
# .next() method on Python 3 (the alias was also dropped from modern PyTorch
# dataloader iterators), so iter(...).next() raises AttributeError.
cap, cap_len, img = next(iter(trainer.dls.train))
cap, cap_len, img = to_device([cap, cap_len, img], trainer.device)
cap, cap_len, img = trainer.after_batch_tfm(cap, cap_len, img)
# Sanity-check batch shapes after the batch transform.
test_eq(cap.shape, (2, trainer.max_seq_len))
test_eq(cap_len.shape, (2,))
test_eq(img.shape, (2, 3, 229, 229))

# Undo normalization and display the two images of the batch side by side.
_, axes = plt.subplots(1, 2, figsize=(8, 8))
for ax, im in zip(axes, img[:2]):
    decoded = trainer.normalizer.decode(im[None]).permute(0, 2, 3, 1)[0]
    ax.imshow(decoded.cpu())

class BirdsTrainer[source]

BirdsTrainer(data_dir, bs, data_pct=1, lr=0.003, lr_decay=0.98, device='cpu', emb_sz=256, rnn_layers=2, rnn_drop_p=0.5, gamma1=4.0, gamma2=5.0, gamma3=10.0)

# Build a trainer on the tiny birds dataset with batch size 2.
trainer = BirdsTrainer(birds_data_dir, 2)

# Grab one training batch. Use the builtin next(): iterator objects have no
# .next() method on Python 3 (the alias was also dropped from modern PyTorch
# dataloader iterators), so iter(...).next() raises AttributeError.
cap, cap_len, img = next(iter(trainer.dls.train))
cap, cap_len, img = to_device([cap, cap_len, img], trainer.device)
cap, cap_len, img = trainer.after_batch_tfm(cap, cap_len, img)
# Sanity-check batch shapes after the batch transform.
test_eq(cap.shape, (2, trainer.max_seq_len))
test_eq(cap_len.shape, (2,))
test_eq(img.shape, (2, 3, 229, 229))

# Undo normalization and display the two images of the batch side by side.
_, axes = plt.subplots(1, 2, figsize=(8, 8))
for ax, im in zip(axes, img[:2]):
    decoded = trainer.normalizer.decode(im[None]).permute(0, 2, 3, 1)[0]
    ax.imshow(decoded.cpu())

Patch

# Smoke-test the patched show() on both trainers with batch size 1.
for trainer_cls, data_dir in ((AnimeHeadsTrainer, animeheads_data_dir),
                              (BirdsTrainer, birds_data_dir)):
    trainer = trainer_cls(data_dir, 1)
    trainer.show()
# trainer = AnimeHeadsTrainer(animeheads_data_dir, 4, device='cuda')
# trainer.train(8, step_per_epoch=2)
1, time: 1.1s, loss: 11.2442
2, time: 1.1s, loss: 7.8520
3, time: 1.1s, loss: 10.0829
4, time: 0.7s, loss: 7.6322

total_time: 0.1min
# trainer = BirdsTrainer(birds_data_dir, 4, device='cuda')
# trainer.train(4, step_per_epoch=2)
1, time: 1.7s, loss: 7.8086
2, time: 1.3s, loss: 9.5855

total_time: 0.1min