from pathlib import Path
from deep_t2i.trainer_GAN import AnimeHeadsTrainer

from deep_t2i.model import Anime_Export
from deep_t2i.inference_anime_heads import pred_and_show
# Locations used throughout this run: the anime-heads dataset, the GAN
# checkpoint/export directory, the sample-jpg output directory, and the
# pretrained DAMSM (text/image encoder) export the GAN trainer conditions on.
data_dir = Path('/root/data/anime_heads')
model_dir = Path('../models/large_anime_heads/gan')
result_dir = Path('../result_jpgs/large_anime_heads/gan')
pretrained_damsm_path = Path('../models/large_anime_heads/damsm/damsm_export.pt')
# Notebook-style echo of the four paths (a no-op outside a REPL cell).
data_dir, model_dir, result_dir, pretrained_damsm_path
(Path('/root/data/anime_heads'),
 Path('../models/large_anime_heads/gan'),
 Path('../result_jpgs/large_anime_heads/gan'),
 Path('../models/large_anime_heads/damsm/damsm_export.pt'))

Train

# Core hyperparameters: batch size, fraction of the dataset to train on
# (1 = all of it), and equal learning rates for generator and discriminator.
bs = 48
data_pct = 1
g_lr = 2e-4
d_lr = 2e-4

# Weight of the DAMSM-based matching/smoothing loss term —
# NOTE(review): exact semantics live in deep_t2i.trainer_GAN; verify there.
smooth_lambda = 2.0
# Size of the generator's input noise vector.
noise_sz = 512
trainer = AnimeHeadsTrainer(
    data_dir, 
    bs=bs, 
    data_pct=data_pct,
    g_lr=g_lr, 
    d_lr=d_lr,
    device='cuda', 
    pretrained_damsm_path=pretrained_damsm_path,
    smooth_lambda=smooth_lambda,
    noise_sz=noise_sz,
)
# Echo the number of batches per epoch in the train dataloader (765 below).
len(trainer.dls.train)
765
# Decay for the exponential moving average of the generator weights, and the
# number of batches accumulated per optimizer step (1 = step every batch).
ema_decay = 0.999
n_gradient_acc = 1
# One "epoch" = one pass over the train dataloader, expressed in optimizer
# steps: batches // accumulation. (The original divided by `n_gradient_acc*1`;
# the `*1` factor was dead code and has been dropped.)
step_per_epoch = len(trainer.dls.train) // n_gradient_acc
# Notebook echo (765 below).
step_per_epoch
765
# Phase "0": train for 18 epochs. Save a grid of sample jpgs every epoch
# (rendered from the raw generator, not the EMA copy) under result_dir/'0',
# and a checkpoint every 6 epochs under model_dir/'0'.
trainer.train(
    n_step=step_per_epoch*18, 
    step_per_epoch=step_per_epoch, 
    savejpg_every=step_per_epoch, 
    jpg_path=str(result_dir/'0'), 
    is_jpg_ema=False,
    saveck_every=step_per_epoch*6, 
    ck_path=str(model_dir/'0'), 
    n_gradient_acc=n_gradient_acc,
    ema_decay=ema_decay, 
)
1, time: 737.1s, g_loss: 14.8828, d_loss: 1.9382
2, time: 737.3s, g_loss: 12.0587, d_loss: 1.6061
3, time: 737.3s, g_loss: 12.6486, d_loss: 1.3755
4, time: 737.8s, g_loss: 12.4419, d_loss: 1.3225
5, time: 736.8s, g_loss: 12.6360, d_loss: 1.2616
6, time: 739.6s, g_loss: 12.0522, d_loss: 1.3167
7, time: 738.3s, g_loss: 12.1108, d_loss: 1.2779
8, time: 738.7s, g_loss: 12.2648, d_loss: 1.1812
9, time: 738.4s, g_loss: 12.0356, d_loss: 1.2146
10, time: 738.2s, g_loss: 12.4260, d_loss: 1.0909
11, time: 738.1s, g_loss: 12.3649, d_loss: 1.1146
12, time: 741.6s, g_loss: 12.4896, d_loss: 1.0713
13, time: 738.2s, g_loss: 12.6237, d_loss: 1.0304
14, time: 737.8s, g_loss: 12.5373, d_loss: 1.0180
15, time: 737.5s, g_loss: 12.7237, d_loss: 1.0175
16, time: 738.0s, g_loss: 13.1592, d_loss: 0.9190
17, time: 737.3s, g_loss: 12.9565, d_loss: 0.9196
18, time: 740.0s, g_loss: 13.6411, d_loss: 0.8448

total_time: 221.5min
# Resume from a phase-0 checkpoint ('0-3' — presumably the one saved at
# epoch 18; NOTE(review): confirm the `{prefix}-{i}.pt` naming in the trainer),
# then phase "1": 30 more epochs, jpgs every epoch, checkpoints every 5 epochs.
trainer.load_checkpoint(model_dir/'0-3.pt')
trainer.train(
    n_step=step_per_epoch*30, 
    step_per_epoch=step_per_epoch, 
    savejpg_every=step_per_epoch, 
    jpg_path=str(result_dir/'1'), 
    is_jpg_ema=False,
    saveck_every=step_per_epoch*5, 
    ck_path=str(model_dir/'1'), 
    n_gradient_acc=n_gradient_acc,
    ema_decay=ema_decay, 
)
1, time: 736.1s, g_loss: 13.1634, d_loss: 0.9130
2, time: 736.2s, g_loss: 13.7612, d_loss: 0.7991
3, time: 736.0s, g_loss: 13.2819, d_loss: 0.8901
4, time: 737.3s, g_loss: 13.3199, d_loss: 0.8378
5, time: 742.5s, g_loss: 13.4411, d_loss: 0.8202
6, time: 744.2s, g_loss: 13.3130, d_loss: 0.8468
7, time: 744.1s, g_loss: 14.3841, d_loss: 0.7207
8, time: 738.3s, g_loss: 13.3168, d_loss: 0.8155
9, time: 737.5s, g_loss: 13.0568, d_loss: 0.8374
10, time: 739.4s, g_loss: 13.4755, d_loss: 0.7901
11, time: 737.9s, g_loss: 13.3714, d_loss: 0.7728
12, time: 737.2s, g_loss: 13.8357, d_loss: 0.7173
13, time: 738.1s, g_loss: 13.4536, d_loss: 0.7742
14, time: 737.5s, g_loss: 13.7362, d_loss: 0.7470
15, time: 739.8s, g_loss: 13.6774, d_loss: 0.7381
16, time: 738.5s, g_loss: 14.2344, d_loss: 0.7163
17, time: 739.5s, g_loss: 13.5687, d_loss: 0.7314
18, time: 737.3s, g_loss: 13.9669, d_loss: 0.6793
19, time: 738.0s, g_loss: 14.0334, d_loss: 0.6503
20, time: 742.4s, g_loss: 13.6093, d_loss: 0.7161
21, time: 740.5s, g_loss: 14.0966, d_loss: 0.6538
22, time: 740.1s, g_loss: 14.2859, d_loss: 0.6436
23, time: 741.0s, g_loss: 13.8886, d_loss: 0.6423
24, time: 740.2s, g_loss: 13.8942, d_loss: 0.6810
25, time: 741.8s, g_loss: 14.3715, d_loss: 0.5932
26, time: 739.1s, g_loss: 14.4019, d_loss: 0.6128
27, time: 738.8s, g_loss: 13.9150, d_loss: 0.6136
28, time: 739.1s, g_loss: 13.8992, d_loss: 0.6169
29, time: 739.8s, g_loss: 14.7981, d_loss: 0.5290
30, time: 741.0s, g_loss: 13.8405, d_loss: 0.5969

total_time: 369.7min
# Reload an intermediate phase-1 checkpoint ('1-4') for the evaluation and
# export cells below — chosen by inspecting the saved sample jpgs, presumably.
trainer.load_checkpoint(model_dir/'1-4.pt')
# 740  -- NOTE(review): meaning unclear; possibly seconds/epoch or a step count

Check

# Visual sanity check: render samples from the raw generator and from the EMA
# copy, then inspect discriminator outputs. From the echoed tensors below,
# check[1] appears to be a (high-score, low-score) pair of per-image sigmoid
# outputs — presumably (real, fake); NOTE(review): verify against check_d().
trainer.show(is_ema=False)
trainer.show(is_ema=True)
check = trainer.check_d(is_ema=True)
check[1]
(tensor([[0.9831],
         [0.9946],
         [0.9785],
         [0.9790],
         [0.9949],
         [0.9846],
         [0.9546],
         [0.9926],
         [0.9618],
         [0.9224],
         [0.9843],
         [0.9511],
         [0.9236],
         [0.9681],
         [0.9716],
         [0.9479],
         [0.4937],
         [0.9786],
         [0.9805],
         [0.8621],
         [0.9743],
         [0.9611],
         [0.9965],
         [0.7368],
         [0.9877],
         [0.9964],
         [0.9656],
         [0.9902],
         [0.9941],
         [0.9782],
         [0.9018],
         [0.9418],
         [0.9683],
         [0.9871],
         [0.9881],
         [0.9712],
         [0.9734],
         [0.9680],
         [0.9938],
         [0.9654],
         [0.9699],
         [0.8825],
         [0.9835],
         [0.9792],
         [0.9990],
         [0.9798],
         [0.9460],
         [0.9368]], device='cuda:0'), tensor([[0.0974],
         [0.0408],
         [0.0308],
         [0.0107],
         [0.0056],
         [0.0281],
         [0.0262],
         [0.0485],
         [0.1032],
         [0.0140],
         [0.0170],
         [0.0352],
         [0.0077],
         [0.3231],
         [0.0984],
         [0.0035],
         [0.0110],
         [0.0371],
         [0.0157],
         [0.1257],
         [0.1245],
         [0.0289],
         [0.2470],
         [0.0115],
         [0.1673],
         [0.0036],
         [0.0391],
         [0.0097],
         [0.0059],
         [0.0484],
         [0.0352],
         [0.1146],
         [0.0714],
         [0.0359],
         [0.0192],
         [0.2873],
         [0.0196],
         [0.0375],
         [0.0039],
         [0.0715],
         [0.0519],
         [0.0372],
         [0.0949],
         [0.1770],
         [0.0023],
         [0.1026],
         [0.0283],
         [0.0694]], device='cuda:0'))

Export and Inference

# Caption vocabulary: every valid caption is one hair-colour token plus one
# eye-colour token (kept here as a reference for the inference captions below).
hairs = ['orange hair', 'white hair', 'aqua hair', 'gray hair', 'green hair',
         'red hair', 'purple hair', 'pink hair', 'blue hair', 'black hair',
         'brown hair', 'blonde hair']
eyes = ['black eyes', 'orange eyes', 'purple eyes', 'pink eyes', 'yellow eyes',
        'aqua eyes', 'green eyes', 'brown eyes', 'red eyes', 'blue eyes']

# Export the EMA generator as a standalone module, reload it, and render a
# handful of captions. 'white hair yellow eyes' is run twice — presumably to
# show sample variety for one caption, since generation draws fresh noise
# (NOTE(review): confirm pred_and_show samples new noise per call).
export_path = model_dir/'gan_export.pt'
trainer.export(export_path, is_ema=True)
model = Anime_Export.from_pretrained(export_path)

for cap in ('white hair yellow eyes',
            'white hair yellow eyes',
            'aqua hair green eyes',
            'pink hair black eyes',
            'green hair orange eyes',
            'blonde hair purple eyes'):
    pred_and_show(model, cap)