# Smoke-test RnnEncoder: 16 padded sequences of 20 token ids -> one 400-d
# sentence embedding per sequence and one 400-d embedding per token position.
bs, seq_len, emb_dim = 16, 20, 400
encoder = RnnEncoder(vocab_sz=111, emb_sz=emb_dim, pad_id=1)
inp_ids = torch.randint(0, 100, (bs, seq_len))
true_seqlen = torch.randint(1, seq_len, (bs,))
sent_emb, word_emb = encoder(inp_ids, true_seqlen)
test_eq(sent_emb.shape, (bs, emb_dim))
test_eq(word_emb.shape, (bs, seq_len, emb_dim))
# Smoke-test CnnEncoder: an image batch yields a spatial feature map (for
# word-level attention) plus a global 256-d image code.
# NOTE(review): 229x229 is not the canonical Inception input size (299) --
# presumably the encoder rescales internally; confirm against CnnEncoder.
x = torch.randn(2, 3, 229, 229)
cnn_encoder = CnnEncoder(256)
features, cnn_code = cnn_encoder(x)
test_eq(features.shape, (2, 256, 17, 17))
test_eq(cnn_code.shape, (2, 256))
# spectral_conv2d: 3x3 / stride 1 / pad 1 keeps spatial size, maps 64 -> 32 channels.
n_in, n_out = 64, 32
conv = spectral_conv2d(n_in, n_out, 3, 1, 1)
inp = torch.randn(2, n_in, 16, 16)
out = conv(inp)
test_eq(out.shape, (2, n_out, 16, 16))
# conv_block_g (generator conv block): same-size 3x3 conv, 64 -> 32 channels.
n_in, n_out = 64, 32
conv = conv_block_g(n_in, n_out, 3, 1, 1)
inp = torch.randn(2, n_in, 16, 16)
out = conv(inp)
test_eq(out.shape, (2, n_out, 16, 16))
# conv_block_d (discriminator conv block): same-size 3x3 conv, 64 -> 32 channels.
n_in, n_out = 64, 32
conv = conv_block_d(n_in, n_out, 3, 1, 1)
inp = torch.randn(2, n_in, 16, 16)
out = conv(inp)
test_eq(out.shape, (2, n_out, 16, 16))
# up_block doubles spatial resolution and leaves the channel count unchanged.
inp = torch.randn(2, 3, 16, 16)
up = up_block()
out = up(inp)
test_eq(out.shape, (2, 3, 2 * inp.shape[-2], 2 * inp.shape[-1]))
# down_block halves spatial resolution and leaves the channel count unchanged.
inp = torch.randn(2, 3, 32, 32)
down = down_block()
out = down(inp)
test_eq(out.shape, (2, 3, inp.shape[-2] // 2, inp.shape[-1] // 2))
# to_rgb_block maps a 64-channel feature map to a 3-channel RGB image at the
# same resolution. (Fix: local variable was misspelled `to_rbg`.)
to_rgb = to_rgb_block(64)
inp = torch.randn(2, 64, 8, 8)
out = to_rgb(inp)
test_eq(out.shape, (2, 3, 8, 8))
# from_rgb_block maps a 3-channel RGB image to a 64-channel feature map at the
# same resolution. (Fix: local variable was misspelled `from_rbg`.)
from_rgb = from_rgb_block(64)
inp = torch.randn(2, 3, 8, 8)
out = from_rgb(inp)
test_eq(out.shape, (2, 64, 8, 8))
# SelfAttention must preserve its input's shape exactly.
x = torch.randn(2, 16, 4, 4)
self_attn = SelfAttention(16)
out = self_attn(x)
test_eq(out.shape, x.shape)
# # exporti
# NOTE(review): the commented-out `simple_attn2` draft below duplicates the
# `simple_attn` exercised by the tests -- dead code, consider deleting it.
# def simple_attn2(tgt, src, src_mask=None):
# ''' tgt: (bs, tgt_seq_len, emb_sz), src: (bs, src_seq_len, emb_sz), src_mask: (bs, src_seq_len) [True will be masked]
# returns: (bs, tgt_seq_len, emb_sz), (bs, tgt_seq_len, src_seq_len) '''
# bs, tgt_seq_len, emb_sz = tgt.shape
# _, src_seq_len, _ = src.shape
# src = src.permute(0, 2, 1) # (bs, emb_sz, src_seq_len)
# attn_w = torch.bmm(tgt, src) # (bs, tgt_seq_len, src_seq_len)
# attn_w = attn_w.view(bs*tgt_seq_len, src_seq_len) # (bs*tgt_seq_len, src_seq_len)
# if src_mask is not None:
# mask = src_mask.repeat(tgt_seq_len, 1) # (bs*tgt_seq_len, src_seq_len)
# attn_w.masked_fill_(mask, -float('inf'))
# attn_w = nn.functional.softmax(attn_w, dim=1)
# attn_w = attn_w.view(bs, tgt_seq_len, src_seq_len)
# attn_w = torch.transpose(attn_w, 1, 2).contiguous() # (bs, src_seq_len, tgt_seq_len)
# attn_out = torch.bmm(src, attn_w) # (bs, emb_sz, tgt_seq_len)
# attn_out = attn_out.permute(0, 2, 1) # (bs, tgt_seq_len, emb_sz)
# attn_w = attn_w.permute(0, 2, 1) # (bs, tgt_seq_len, src_seq_len)
# return attn_out, attn_w
# simple_attn: 5 target positions attend over 4 source positions (emb dim 3);
# src_mask entries that are True are masked out of the attention.
tgt = torch.randn(2, 5, 3)
src = torch.randn(2, 4, 3)
src_mask = torch.tensor([[1, 0, 0, 0],
                         [1, 1, 1, 0]]).bool()
attn_out, attn_w = simple_attn(tgt, src, src_mask)
test_eq(attn_out.shape, (2, 5, 3))   # attended output per target position
test_eq(attn_w.shape, (2, 5, 4))     # one weight per (target, source) pair
# AttnBlock: 600-d word embeddings attended onto a 256-channel 16x16 feature
# map; output keeps the map's shape and the attention weights are cached on the
# module. NOTE(review): with simple_attn's convention an all-True mask would
# mask every word -- confirm AttnBlock's mask polarity.
word_emb = torch.randn(16, 20, 600)
src_mask = torch.ones(16, 20).bool()
x = torch.randn(16, 256, 16, 16)
attn_block = AttnBlock(600, 256)
out = attn_block(x, word_emb, src_mask)
test_eq(out.shape, x.shape)
test_eq(attn_block.attn_w.shape, (16, 16, 16, 20))
# G_Init: sentence embedding (25-d) + noise (100-d) -> initial 4x4 RGB image
# plus a 4x4 feature code with CH_TABLE[0] channels.
noise = torch.randn(2, 100)
sent_emb = torch.randn(2, 25)
g_init = G_Init(25, 100)
img, code = g_init(sent_emb, noise)
test_eq(img.shape, (2, 3, 4, 4))
test_eq(code.shape, (2, CH_TABLE[0], 4, 4))
# G_General: one generator stage -- upsamples 4x4 -> 8x8 while reducing
# 512 -> 256 channels, emitting an RGB image and the next feature code.
code = torch.randn(2, 512, 4, 4)
g_general = G_General(512, 256, is_self_attn=True)
img, code = g_general(code)
test_eq(img.shape, (2, 3, 8, 8))
test_eq(code.shape, (2, 256, 8, 8))
# G_General_Attn: like G_General but additionally attends over word embeddings
# (emb dim 25, 2 words) before upsampling 4x4 -> 8x8 / 512 -> 256 channels.
word_emb = torch.randn(2, 2, 25)
src_mask = torch.ones(2, 2).bool()
code = torch.randn(2, 512, 4, 4)
g_general_attn = G_General_Attn(25, 512, 256, is_self_attn=True)
img, code = g_general_attn(code, word_emb, src_mask)
test_eq(img.shape, (2, 3, 8, 8))
test_eq(code.shape, (2, 256, 8, 8))
# Anime_G: full generator stack -- returns one RGB image per stage, doubling
# resolution from 4x4 up to 64x64.
sent_emb = torch.randn(2, 25)
noise = torch.randn(2, 100)
word_emb = torch.randn(2, 2, 25)
src_mask = torch.ones(2, 2).bool()
anime_g = Anime_G(25, 100)
imgs = anime_g(sent_emb, noise, word_emb, src_mask)
test_eq([img.shape for img in imgs],
        [torch.Size([2, 3, 4 * 2**i, 4 * 2**i]) for i in range(5)])
# Birds_G: deeper generator stack -- one RGB image per stage, doubling
# resolution from 4x4 up to 256x256 (7 stages).
sent_emb = torch.randn(2, 25)
noise = torch.randn(2, 100)
word_emb = torch.randn(2, 2, 25)
src_mask = torch.ones(2, 2).bool()
birds_g = Birds_G(25, 100)
imgs = birds_g(sent_emb, noise, word_emb, src_mask)
test_eq([img.shape for img in imgs],
        [torch.Size([2, 3, 4 * 2**i, 4 * 2**i]) for i in range(7)])
# UncondCls: unconditional real/fake logit from the (CH_TABLE[0]+1)-channel
# 4x4 sentence code.
uncond_cls = UncondCls()
sent_code = torch.randn(2, CH_TABLE[0] + 1, 4, 4)
uncond_logit = uncond_cls(sent_code)
test_eq(uncond_logit.shape, (2, 1))
# CondCls: real/fake logit conditioned on the 25-d sentence embedding, computed
# from the (CH_TABLE[0]+1)-channel 4x4 sentence code.
sent_emb = torch.randn(2, 25)
sent_code = torch.randn(2, CH_TABLE[0] + 1, 4, 4)
cond_cls = CondCls(25)
cond_logit = cond_cls(sent_code, sent_emb)
test_eq(cond_logit.shape, (2, 1))
CH_TABLE  # notebook-style peek at the shared channel table
# D_General: one discriminator stage -- downsamples 8x8 -> 4x4 while growing
# 16 -> 32 channels; consumes the rendered image plus the running feature code.
img = torch.randn(2, 3, 8, 8)
code = torch.randn(2, 16, 8, 8)
d_general = D_General(16, 32, is_self_attn=True)
code = d_general(img, code)
test_eq(code.shape, (2, 32, 4, 4))
# Anime_D: discriminator over the image pyramid (largest 64x64 first, down to
# 4x4); produces unconditional and sentence-conditional real/fake logits.
imgs = [torch.randn(2, 3, 64 // 2**i, 64 // 2**i) for i in range(5)]
sent_emb = torch.randn(2, 25)
anime_d = Anime_D(25)
sent_code = anime_d.get_sent_code(imgs)
test_eq(sent_code.shape, (2, CH_TABLE[0] + 1, 4, 4))
uncond_logit, cond_logit = anime_d(imgs, sent_emb)
test_eq(uncond_logit.shape, (2, 1))
test_eq(cond_logit.shape, (2, 1))
# Birds_D: discriminator over the 256x256-down-to-4x4 image pyramid (largest
# first); produces unconditional and sentence-conditional real/fake logits.
imgs = [torch.randn(2, 3, 256 // 2**i, 256 // 2**i) for i in range(7)]
sent_emb = torch.randn(2, 25)
birds_d = Birds_D(25)
sent_code = birds_d.get_sent_code(imgs)
test_eq(sent_code.shape, (2, CH_TABLE[0] + 1, 4, 4))
uncond_logit, cond_logit = birds_d(imgs, sent_emb)
test_eq(uncond_logit.shape, (2, 1))
test_eq(cond_logit.shape, (2, 1))
# Anime_Export: end-to-end inference wrapper -- token ids + caption lengths in,
# final 64x64 images out.
inp_ids = torch.tensor([[1, 2], [3, 4]])
cap_len = torch.tensor([2, 2])
anime_export = Anime_Export(5, 24, 0)
img = anime_export(inp_ids, cap_len)
test_eq(img.shape, (2, 3, 64, 64))
# anime_export = Anime_Export(5, 24, 0)
# inp_ids = torch.tensor([[1, 2],
# [3, 4]])
# cap_len = torch.tensor([2, 2])
# img = anime_export.small_forward(inp_ids, cap_len)
# test_eq(img.shape, (anime_export.samples[0].shape[0]+2, 3, 64, 64))
# Birds_Export: end-to-end inference wrapper -- token ids + caption lengths in,
# final 256x256 images out.
cap_len = torch.tensor([2, 2])
inp_ids = torch.randint(0, 4, (2, 25))
birds_export = Birds_Export(5, 24, 0)
img = birds_export(inp_ids, cap_len)
test_eq(img.shape, (2, 3, 256, 256))
# birds_export = Birds_Export(5, 24, 0)
# inp_ids = torch.randint(0, 4, (2, 25))
# cap_len = torch.tensor([2, 2])
# img = birds_export.small_forward(inp_ids, cap_len)
# test_eq(img.shape, (birds_export.samples[0].shape[0]+2, 3, 256, 256))