from pathlib import Path
from deep_t2i.trainer_GAN import BirdsTrainer
from deep_t2i.model import Birds_Export
from deep_t2i.inference_birds import pred_and_show
# Locations used by this run: the birds dataset, GAN checkpoints/exports,
# per-round sample jpgs, and the pretrained DAMSM export handed to the trainer.
# Everything except data_dir is relative to the current working directory.
data_dir, model_dir, result_dir, pretrained_damsm_path = (
    Path('/root/data/birds'),
    Path('../models/large_birds/gan'),
    Path('../result_jpgs/large_birds/gan'),
    Path('../models/large_birds/damsm/damsm_export.pt'),
)
# Notebook-style echo of the paths; a no-op when run as a plain script.
data_dir, model_dir, result_dir, pretrained_damsm_path
# Hyperparameters forwarded verbatim to BirdsTrainer below.
bs, data_pct = 24, 1
g_lr = d_lr = 2e-4        # same learning rate passed as both g_lr and d_lr
smooth_lambda = 2.0
noise_sz = 512

# Build the trainer on the GPU, wiring in the pretrained DAMSM export.
trainer = BirdsTrainer(
    data_dir,
    bs=bs, data_pct=data_pct,
    g_lr=g_lr, d_lr=d_lr,
    smooth_lambda=smooth_lambda, noise_sz=noise_sz,
    pretrained_damsm_path=pretrained_damsm_path,
    device='cuda',
)
# Notebook-style echo: number of batches in the training loader.
len(trainer.dls.train)
# Optimisation settings shared by every training round.
ema_decay = 0.999
n_gradient_acc = 1
# Optimiser steps per epoch: loader batches divided by the number of batches
# accumulated into each step.
step_per_epoch = len(trainer.dls.train) // n_gradient_acc
if step_per_epoch < 1:
    # n_step / savejpg_every / saveck_every would all be 0 below.
    raise ValueError(
        'train loader has {} batches, fewer than n_gradient_acc={}'.format(
            len(trainer.dls.train), n_gradient_acc))
step_per_epoch  # notebook-style echo; no-op as a script

# Training runs as consecutive rounds on the same trainer object.
# Each entry is (n_epochs, checkpoint_every_epochs) for round i (its index).
# Round i writes sample jpgs to result_dir/str(i) every epoch and checkpoints to
# model_dir/str(i) every checkpoint_every_epochs epochs. It then reloads its
# final checkpoint before the next round starts.
train_schedule = [
    (52, 4),  # round 0
    (49, 7),  # round 1
    (52, 4),  # round 2
    (49, 7),  # round 3
    (52, 4),  # round 4
    (50, 5),  # round 5
    (50, 5),  # round 6
    (50, 5),  # round 7
    (52, 4),  # round 8
    (50, 5),  # round 9
    (50, 5),  # round 10
    (52, 4),  # round 11
    (52, 4),  # round 12
    (50, 5),  # round 13
]

# Validate the whole schedule up front. A bad entry then fails immediately,
# not at the checkpoint-reload step after hours of training.
for n_epochs, ck_every_epochs in train_schedule:
    if ck_every_epochs < 1 or n_epochs % ck_every_epochs:
        raise ValueError(
            'n_epochs={} is not a positive multiple of '
            'checkpoint_every_epochs={}'.format(n_epochs, ck_every_epochs))

for round_idx, (n_epochs, ck_every_epochs) in enumerate(train_schedule):
    tag = str(round_idx)
    trainer.train(
        n_step=step_per_epoch * n_epochs,
        step_per_epoch=step_per_epoch,
        savejpg_every=step_per_epoch,
        jpg_path=str(result_dir / tag),
        is_jpg_ema=False,
        saveck_every=step_per_epoch * ck_every_epochs,
        ck_path=str(model_dir / tag),
        n_gradient_acc=n_gradient_acc,
        ema_decay=ema_decay,
    )
    # NOTE(review): assumes trainer.train's last checkpoint for a round is
    # '<round>-<n_epochs // checkpoint_every_epochs>.pt'. This reproduces
    # every name the hand-written rounds loaded (e.g. 52/4 -> '0-13.pt',
    # 49/7 -> '1-7.pt', 50/5 -> '5-10.pt').
    last_ck = n_epochs // ck_every_epochs
    trainer.load_checkpoint(model_dir / '{}-{}.pt'.format(tag, last_ck))
# NOTE(review): the notes below are unexplained. They may be step_per_epoch
# values paired with per-round epoch counts from the schedule above.
# TODO confirm their meaning or delete.
# 662, 52, 49
# 690, 50,
# Show samples from the non-EMA model first, then from the EMA model.
for show_ema in (False, True):
    trainer.show(is_ema=show_ema)
# Run trainer.check_d on the EMA model (presumably a discriminator check;
# confirm). Element 1 of its result is echoed for inspection in a notebook.
check = trainer.check_d(is_ema=True)
check[1]
# Demo prompts, grouped by how much detail each caption specifies.
simple_caps, medium_caps, complex_caps = (
    # size + one colour
    [
        'a small red bird',
        'a small orange bird',
        'a small blue bird',
        'a small yellow bird',
        'a small black bird',
    ],
    # colour plus one or two part-level attributes
    [
        'a small white bird with orange bill',
        'a large red bird with black beak',
        'a small black bird with a yellow head',
        'a large yellow bird with long black beak',
        'this bird has a green crown, black wings and a yellow belly',
    ],
    # several part-level colour attributes
    [
        'this bird has a blue crown green primaries and a red belly',
    ],
)
# Export the EMA model, reload the export, and render one prediction per
# demo caption, from simplest to most detailed.
trainer.export(model_dir/'gan_export.pt', is_ema=True)
model = Birds_Export.from_pretrained(model_dir/'gan_export.pt')
for cap in (
    simple_caps[0],
    medium_caps[4],
    complex_caps[0],
    'this bird is red with white and has a very short beak',
    'the bird has a yellow crown and a black eyering that is round',
):
    pred_and_show(model, cap)