# Relative path to the anime_heads data directory used by the examples below.
data_dir = Path('../data/tiny_data/anime_heads')

Items

# get_items returns a (train, valid) pair of item collections for data_dir.
train_items, valid_items = get_items(data_dir)
# For this data directory the split is expected to be 240 training / 60 validation items.
test_eq(len(train_items), 240)
test_eq(len(valid_items), 60)
# Last expression of the cell: display the first five training items.
train_items[:5]
id cap
293 293 aqua hair black eyes
40 40 aqua hair aqua eyes
84 84 aqua hair aqua eyes
47 47 aqua hair aqua eyes
288 288 aqua hair black eyes

Datasets

class Tokenizer[source]

Tokenizer()

# Round-trip check: encode a caption into tag ids, then decode the ids back.
tokenizer = Tokenizer()
ori_cap = 'aqua hair aqua eyes'
# encode returns a pair: the tag ids and a length (presumably the number of tags -- confirm).
tags, tag_len = tokenizer.encode(ori_cap)
test_eq(tags, [3, 18])  # ids expected for this particular caption
test_eq(tag_len, 2)
# Decoding the ids must reproduce the original caption exactly.
out_cap = tokenizer.decode(tags)
test_eq(out_cap, ori_cap)
# Build a dataset over the training items; indexing it yields a (tag, tag_len, img64) triple.
ds = AnimeHeadsDataset(train_items, data_dir)
tag, tag_len, img64 = ds[0]
test_eq(tag.shape, (2,))  # two tag ids per sample
test_eq(tag_len.shape, ())  # tag_len is 0-dimensional (a scalar)
test_eq(img64.shape, (64, 64, 3))  # presumably height x width x channels -- confirm

# Inspect the sample unpacked from ds[0]: print its tag ids and length, then render the image.
print(tag, tag_len)
plt.imshow(img64)
tensor([ 3, 13]) tensor(2)
<matplotlib.image.AxesImage at 0x7f9287e94748>

class Datasets[source]

Datasets(data_dir, pct=1, valid_pct=0.2)

# Datasets built with its default arguments exposes .train and .valid;
# for this data directory they are expected to hold 240 and 60 samples.
dsets = Datasets(data_dir)
test_eq(len(dsets.train), 240)
test_eq(len(dsets.valid), 60)

DataLoaders

class DataLoaders[source]

DataLoaders(dsets, bs=64)

# Check the shapes of one training batch: with bs=16 every field gains a leading dimension of 16.
dls = DataLoaders(dsets, bs=16)
for tag, tag_len, img in dls.train:
    test_eq(tag.shape, (16, 2))
    test_eq(tag_len.shape, (16,))
    test_eq(img.shape, (16, 64, 64, 3))
    break  # only the first batch is inspected