# Load the tiny anime-heads sample dataset and sanity-check the sizes of the
# train/validation split returned by `get_items` (240 / 60 items expected).
data_dir = Path('../data/tiny_data/anime_heads')
train_items, valid_items = get_items(data_dir)
test_eq(len(train_items), 240)
test_eq(len(valid_items), 60)
# Peek at the first few training items (notebook display expression).
train_items[:5]
# Round-trip a caption through the tokenizer: encode to tag ids plus a count,
# then decode back and confirm the original caption string is recovered.
tokenizer = Tokenizer()
caption = 'aqua hair aqua eyes'
tag_ids, n_tags = tokenizer.encode(caption)
test_eq(tag_ids, [3, 18])
test_eq(n_tags, 2)
decoded = tokenizer.decode(tag_ids)
test_eq(decoded, caption)
# Build the training dataset and inspect one sample: a length-2 tag-id
# vector, a scalar tag count, and a 64x64 RGB image.
dataset = AnimeHeadsDataset(train_items, data_dir)
sample_tag, sample_len, sample_img = dataset[0]
test_eq(sample_tag.shape, (2,))
test_eq(sample_len.shape, ())
test_eq(sample_img.shape, (64, 64, 3))
print(sample_tag, sample_len)
plt.imshow(sample_img)
# Wrap the splits in DataLoaders and verify the shapes of a single training
# batch (batch size 16).
# Fix: the loop body below had lost its indentation in the original, which
# is a SyntaxError in Python — restored to sit inside the `for` block.
dsets = Datasets(data_dir)
test_eq(len(dsets.train), 240)
test_eq(len(dsets.valid), 60)
dls = DataLoaders(dsets, bs=16)
for tag, tag_len, img in dls.train:
    # Only the first batch is needed for the shape check.
    test_eq(tag.shape, (16, 2))
    test_eq(tag_len.shape, (16,))
    test_eq(img.shape, (16, 64, 64, 3))
    break