from pathlib import Path
from deep_t2i.trainer_GAN import AnimeHeadsTrainer
from deep_t2i.model import Anime_Export
from deep_t2i.inference_anime_heads import pred_and_show
# ---------------------------------------------------------------------------
# Paths and hyper-parameters for training the large anime-heads GAN.
# ---------------------------------------------------------------------------
data_dir = Path('/root/data/anime_heads')  # root of the training data
model_dir = Path('../models/large_anime_heads/gan')  # checkpoints and the final export
result_dir = Path('../result_jpgs/large_anime_heads/gan')  # JPEG samples saved during training
# Exported DAMSM model handed to the trainer via `pretrained_damsm_path`.
pretrained_damsm_path = Path('../models/large_anime_heads/damsm/damsm_export.pt')

bs = 48  # batch size
data_pct = 1  # portion of the dataset to use; presumably 1 == everything — TODO confirm in AnimeHeadsTrainer
g_lr = 2e-4  # generator learning rate
d_lr = 2e-4  # discriminator learning rate
smooth_lambda = 2.0  # forwarded as `smooth_lambda`; its exact role is defined by AnimeHeadsTrainer
noise_sz = 512  # size of the latent noise vector forwarded to the trainer
# Build the trainer from the configuration above (runs on a CUDA device).
trainer = AnimeHeadsTrainer(
    data_dir,
    bs=bs,
    data_pct=data_pct,
    g_lr=g_lr,
    d_lr=d_lr,
    device='cuda',
    pretrained_damsm_path=pretrained_damsm_path,
    smooth_lambda=smooth_lambda,
    noise_sz=noise_sz,
)

ema_decay = 0.999  # forwarded to trainer.train as `ema_decay`
n_gradient_acc = 1  # gradient-accumulation steps, forwarded to trainer.train

# Optimizer steps per epoch. This assumes each step consumes n_gradient_acc
# training batches (TODO confirm in AnimeHeadsTrainer.train).
n_train_batches = len(trainer.dls.train)
step_per_epoch = n_train_batches // n_gradient_acc
if step_per_epoch <= 0:
    # Every schedule below (n_step, savejpg_every, saveck_every) is a multiple
    # of step_per_epoch, so a zero value would silently disable training and
    # produce zero save intervals. Fail fast with the numbers involved.
    raise ValueError(
        f'step_per_epoch is {step_per_epoch}: the training loader has '
        f'{n_train_batches} batches, fewer than n_gradient_acc={n_gradient_acc}'
    )
def _train_stage(stage, n_epochs, ckpt_every_epochs):
    """Run one training stage of ``n_epochs`` epochs on the module-level trainer.

    Non-EMA JPEG samples are requested every epoch under ``result_dir/<stage>``
    and checkpoints every ``ckpt_every_epochs`` epochs under ``model_dir/<stage>``.
    Every other argument is shared by all stages.

    Args:
        stage: Stage tag (a string) used as the sub-path for samples and checkpoints.
        n_epochs: Number of epochs to train, converted to steps via step_per_epoch.
        ckpt_every_epochs: Checkpoint interval, in epochs.
    """
    trainer.train(
        n_step=step_per_epoch * n_epochs,
        step_per_epoch=step_per_epoch,
        savejpg_every=step_per_epoch,
        jpg_path=str(result_dir / stage),
        is_jpg_ema=False,
        saveck_every=step_per_epoch * ckpt_every_epochs,
        ck_path=str(model_dir / stage),
        n_gradient_acc=n_gradient_acc,
        ema_decay=ema_decay,
    )


# Stage 0: 18 epochs, checkpoint every 6 epochs.
_train_stage('0', n_epochs=18, ckpt_every_epochs=6)
# Resume from a stage-0 checkpoint. NOTE(review): this assumes the trainer
# names checkpoints '<ck_path>-<k>.pt'; confirm which save '0-3' refers to.
trainer.load_checkpoint(model_dir / '0-3.pt')

# Stage 1: 30 epochs, checkpoint every 5 epochs.
_train_stage('1', n_epochs=30, ckpt_every_epochs=5)
trainer.load_checkpoint(model_dir / '1-4.pt')
# NOTE(review): the notebook left a bare note "740" here with no explanation
# (possibly a step or epoch count for this checkpoint). Confirm or remove it.
# Visually inspect the loaded model, first with is_ema=False, then with is_ema=True.
trainer.show(is_ema=False)
trainer.show(is_ema=True)
# check_d result (presumably discriminator diagnostics) is kept in `check` for
# interactive inspection; the notebook displayed check[1].
check = trainer.check_d(is_ema=True)
# Hair- and eye-colour phrases that can be combined into captions such as
# 'white hair yellow eyes'. Each list is built from its colour names in a fixed order.
hairs = [
    f'{color} hair'
    for color in ('orange', 'white', 'aqua', 'gray', 'green', 'red',
                  'purple', 'pink', 'blue', 'black', 'brown', 'blonde')
]
eyes = [
    f'{color} eyes'
    for color in ('black', 'orange', 'purple', 'pink', 'yellow', 'aqua',
                  'green', 'brown', 'red', 'blue')
]
# Export the EMA weights (is_ema=True) to a single file, then reload that file
# as a standalone Anime_Export model. Using one path for both guarantees the
# load reads exactly what was just exported.
export_path = model_dir / 'gan_export.pt'
trainer.export(export_path, is_ema=True)
model = Anime_Export.from_pretrained(export_path)

# Render samples for several hair/eye caption combinations. The notebook
# showed 'white hair yellow eyes' twice; it is kept twice here, presumably to
# get a second sample if pred_and_show draws fresh noise (TODO confirm).
for cap in (
    'white hair yellow eyes',
    'white hair yellow eyes',
    'aqua hair green eyes',
    'pink hair black eyes',
    'green hair orange eyes',
    'blonde hair purple eyes',
):
    pred_and_show(model, cap)