Added capability of continue training from previous checkpoints
This commit is contained in:
+17
-6
@@ -98,8 +98,17 @@ def run(rank, n_gpus, hps):
|
||||
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
||||
|
||||
# load existing model
|
||||
_, _, _, _ = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None, drop_speaker_emb=hps.drop_speaker_embed)
|
||||
_, _, _, _ = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
|
||||
G_ckpt = "./pretrained_models/G_latest.pth" if hps.cont else "./pretrained_models/G_0.pth"
|
||||
D_ckpt = "./pretrained_models/D_latest.pth" if hps.cont else "./pretrained_models/D_0.pth"
|
||||
try:
|
||||
_, _, _, _ = utils.load_checkpoint(G_ckpt, net_g, None,
|
||||
drop_speaker_emb=hps.drop_speaker_embed)
|
||||
except Exception:
|
||||
_, _, _, _ = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None, drop_speaker_emb=hps.drop_speaker_embed)
|
||||
try:
|
||||
_, _, _, _ = utils.load_checkpoint(D_ckpt, net_d, None)
|
||||
except Exception:
|
||||
_, _, _, _ = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
|
||||
epoch_str = 1
|
||||
global_step = 0
|
||||
# freeze all other layers except speaker embedding
|
||||
@@ -243,13 +252,15 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
|
||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||
os.path.join(hps.model_dir, "G_latest.pth".format(global_step)))
|
||||
# utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
|
||||
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
|
||||
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
||||
os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
|
||||
old_g=os.path.join(hps.model_dir, "G_{}.pth".format(global_step-4000))
|
||||
# old_d=os.path.join(hps.model_dir, "D_{}.pth".format(global_step-400))
|
||||
old_d=os.path.join(hps.model_dir, "D_{}.pth".format(global_step-4000))
|
||||
if os.path.exists(old_g):
|
||||
os.remove(old_g)
|
||||
# if os.path.exists(old_d):
|
||||
# os.remove(old_d)
|
||||
if os.path.exists(old_d):
|
||||
os.remove(old_d)
|
||||
global_step += 1
|
||||
if epoch > hps.max_epochs:
|
||||
print("Maximum epoch reached, closing training...")
|
||||
|
||||
Reference in New Issue
Block a user