Some useful utilities that extend PyTorch functionality
# `isin` marks every element of `t` whose value is one of the listed values
# (here 0 and 1); all other positions are False.
t = torch.tensor([
    [12, 11, 0, 0],
    [9, 1, 5, 0],
])
mask = isin(t, [0, 1])
test_eq(mask, torch.tensor([
    [False, False, True, True],
    [False, True, False, True],
]))
# `get_src_mask` builds a (len(cap_len), max_seq_len) boolean mask: in row i,
# the first cap_len[i] positions are False and the remaining positions are True.
max_seq_len = 5
cap_len = torch.tensor([2, 1, 3])
src_mask = get_src_mask(cap_len, max_seq_len)
test_eq(src_mask, torch.tensor([
    [0, 0, 1, 1, 1],
    [0, 1, 1, 1, 1],
    [0, 0, 0, 1, 1],
]).bool())
# Round-trip check for `Normalizer`: encoding and then decoding a random
# image batch should recover the original pixel values within eps=2.
normalizer = Normalizer()
# `torch.randint`'s upper bound is exclusive, so 256 is needed for the
# boundary pixel value 255 to appear; the range checks below cover it.
img = torch.randint(0, 256, (2, 3, 16, 16))
img_encoded = normalizer.encode(img)
img_decoded = normalizer.decode(img_encoded)
test_close(img, img_decoded, eps=2)
# test encoded img keeps its shape and is in range -1~1
test_eq(img_encoded.shape, img.shape)
test_eq(((img_encoded >= -1) & (img_encoded <= 1)).all().item(), True)
# test decoded img keeps its shape and is in range 0~255
test_eq(img_decoded.shape, img.shape)
test_eq(((img_decoded >= 0) & (img_decoded <= 255)).all().item(), True)
# Sampling from `noise_gen` with sample shape (2, 100) should yield a result of exactly that shape.
noise = noise_gen.sample((2, 100))
test_eq(tuple(noise.shape), (2, 100))