diff --git a/finetune_speaker_v2.py b/finetune_speaker_v2.py index 75baf7e..5fc4fd8 100644 --- a/finetune_speaker_v2.py +++ b/finetune_speaker_v2.py @@ -100,8 +100,8 @@ def run(rank, n_gpus, hps): # load existing model if hps.cont: try: - _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth"), net_g, None) - _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth"), net_d, None) + _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_latest.pth"), net_g, None) + _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_latest.pth"), net_d, None) global_step = (epoch_str - 1) * len(train_loader) except: print("Failed to find latest checkpoint, loading G_0.pth...") @@ -260,21 +260,26 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade if global_step % hps.train.eval_interval == 0: evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) + utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_latest.pth".format(global_step))) - utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) - # utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, - # os.path.join(hps.model_dir, "D_latest.pth".format(global_step))) - old_g = utils.oldest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth", - preserved=hps.preserved) # Preserve 4 (default) historical checkpoints. - old_d = utils.oldest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth", preserved=hps.preserved) - if os.path.exists(old_g): - print(f"remove {old_g}") - os.remove(old_g) - if os.path.exists(old_d): - print(f"remove {old_d}") - os.remove(old_d) + + utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "D_latest.pth".format(global_step))) + if hps.preserved > 0: + utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) + utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) + old_g = utils.oldest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth", + preserved=hps.preserved) # Preserve 4 (default) historical checkpoints. + old_d = utils.oldest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth", preserved=hps.preserved) + if os.path.exists(old_g): + print(f"remove {old_g}") + os.remove(old_g) + if os.path.exists(old_d): + print(f"remove {old_d}") + os.remove(old_d) global_step += 1 if epoch > hps.max_epochs: print("Maximum epoch reached, closing training...")