from pathlib import Path
from deep_t2i.trainer_DAMSM import AnimeHeadsTrainer
# Input dataset location and the directory used for DAMSM checkpoints/exports.
data_dir, model_dir = (
    Path('/root/data/anime_heads'),
    Path('../models/large_anime_heads/damsm'),
)
# Bare expression: echoes both paths when run as a notebook cell.
data_dir, model_dir
(Path('/root/data/anime_heads'), Path('../models/large_anime_heads/damsm'))

Train

# Hyper-parameters for this run. Every name below is passed to
# AnimeHeadsTrainer as the keyword argument of the same name.

# Data loading -- presumably batch size and the fraction of the dataset to
# use; confirm against AnimeHeadsTrainer.
bs, data_pct = 48, 1

# Optimisation -- learning rate and its decay factor (decay schedule is
# defined by the trainer; not visible here).
lr, lr_decay = 2e-3, 0.98

# Text-encoder settings -- names suggest embedding size, number of RNN layers
# and RNN dropout probability; TODO confirm against the trainer.
emb_sz, rnn_layers, rnn_drop_p = 10, 2, 0.5

# gamma1..gamma3 -- NOTE(review): look like the DAMSM loss scaling factors;
# verify against the trainer's loss implementation.
gamma1, gamma2, gamma3 = 4.0, 5.0, 10.0
# Gather the keyword configuration in one place (same keywords, same order)
# so the constructor call itself stays short.
trainer_kwargs = dict(
    bs=bs,
    data_pct=data_pct,
    lr=lr,
    lr_decay=lr_decay,
    device='cuda',
    emb_sz=emb_sz,
    rnn_layers=rnn_layers,
    rnn_drop_p=rnn_drop_p,
    gamma1=gamma1,
    gamma2=gamma2,
    gamma3=gamma3,
)
trainer = AnimeHeadsTrainer(data_dir, **trainer_kwargs)

# Length of the training loader, used below to express step counts in
# multiples of it (presumably batches per epoch -- confirm against trainer.dls).
step_per_epoch = len(trainer.dls.train)
# Bare expression: echoes the value when run as a notebook cell.
step_per_epoch
765
# Express the schedule in multiples of step_per_epoch rather than raw steps.
n_epochs = 30
ckpt_interval_epochs = 10

# First positional argument is the step budget (30 * step_per_epoch);
# checkpoints are requested every 10 * step_per_epoch steps under the path
# prefix '<model_dir>/0' (exact file naming is decided by the trainer).
trainer.train(
    step_per_epoch * n_epochs,
    step_per_epoch=step_per_epoch,
    saveck_every=step_per_epoch * ckpt_interval_epochs,
    ck_path=str(model_dir / '0'),
)
1, time: 165.1s, loss: 12.0607
2, time: 165.2s, loss: 9.8466
3, time: 165.2s, loss: 9.2070
4, time: 165.1s, loss: 8.9492
5, time: 165.3s, loss: 8.7227
6, time: 165.2s, loss: 8.5763
7, time: 165.1s, loss: 8.4551
8, time: 165.3s, loss: 8.3147
9, time: 165.3s, loss: 8.2661
10, time: 165.6s, loss: 8.1435
11, time: 165.6s, loss: 8.1110
12, time: 165.2s, loss: 8.0559
13, time: 166.0s, loss: 8.0147
14, time: 165.5s, loss: 7.9487
15, time: 165.3s, loss: 7.8928
16, time: 165.5s, loss: 7.8572
17, time: 165.5s, loss: 7.8176
18, time: 165.7s, loss: 7.8143
19, time: 165.8s, loss: 7.7685
20, time: 166.2s, loss: 7.7427
21, time: 165.2s, loss: 7.7462
22, time: 165.4s, loss: 7.6875
23, time: 165.2s, loss: 7.6961
24, time: 165.1s, loss: 7.6605
25, time: 165.3s, loss: 7.6083
26, time: 165.2s, loss: 7.5887
27, time: 165.4s, loss: 7.5522
28, time: 165.1s, loss: 7.5350
29, time: 165.3s, loss: 7.5247
30, time: 165.1s, loss: 7.5056

total_time: 82.7min
# Checkpoint to restore and destination of the exported model.
# NOTE(review): '0-3.pt' is presumably the third checkpoint written under the
# '0' prefix by the training call -- confirm the trainer's naming scheme.
final_ckpt = model_dir / '0-3.pt'
export_path = model_dir / 'damsm_export.pt'

# Restore the checkpoint, call the trainer's show(), then write the export.
trainer.load_checkpoint(final_ckpt)
trainer.show()
trainer.export(export_path)
# Disabled example: reload the exported DAMSM encoders and attach them to the trainer.
# from deep_t2i.model import get_pretrained_DAMSM
# rnn_encoder, cnn_encoder, gamma1, gamma2, gamma3 = get_pretrained_DAMSM(model_dir/'damsm_export.pt', device='cuda')
# trainer.rnn_encoder = rnn_encoder
# trainer.cnn_encoder = cnn_encoder