Added capability to continue training from previous checkpoints

This commit is contained in:
Plachta
2023-06-12 18:42:05 +08:00
parent 1d7e8fc637
commit 291d8ddf5e
3 changed files with 19 additions and 6 deletions
+17 -6
View File
@@ -98,8 +98,17 @@ def run(rank, n_gpus, hps):
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
# load existing model
_, _, _, _ = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None, drop_speaker_emb=hps.drop_speaker_embed)
_, _, _, _ = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
G_ckpt = "./pretrained_models/G_latest.pth" if hps.cont else "./pretrained_models/G_0.pth"
D_ckpt = "./pretrained_models/D_latest.pth" if hps.cont else "./pretrained_models/D_0.pth"
try:
_, _, _, _ = utils.load_checkpoint(G_ckpt, net_g, None,
drop_speaker_emb=hps.drop_speaker_embed)
except Exception:
_, _, _, _ = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None, drop_speaker_emb=hps.drop_speaker_embed)
try:
_, _, _, _ = utils.load_checkpoint(D_ckpt, net_d, None)
except Exception:
_, _, _, _ = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
epoch_str = 1
global_step = 0
# freeze all other layers except speaker embedding
@@ -243,13 +252,15 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "G_latest.pth".format(global_step)))
# utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
old_g=os.path.join(hps.model_dir, "G_{}.pth".format(global_step-4000))
# old_d=os.path.join(hps.model_dir, "D_{}.pth".format(global_step-400))
old_d=os.path.join(hps.model_dir, "D_{}.pth".format(global_step-4000))
if os.path.exists(old_g):
os.remove(old_g)
# if os.path.exists(old_d):
# os.remove(old_d)
if os.path.exists(old_d):
os.remove(old_d)
global_step += 1
if epoch > hps.max_epochs:
print("Maximum epoch reached, closing training...")