Change how the latest model checkpoint is loaded, fix the global_step calculation, add the `preserved` parameter (number of historical checkpoints to keep), and add the `train_with_pretrained_model` parameter

This commit is contained in:
Emberstar
2023-06-14 10:18:19 +08:00
parent e97b185188
commit 576424fe58
2 changed files with 46 additions and 14 deletions
+22 -11
View File
@@ -100,18 +100,26 @@ def run(rank, n_gpus, hps):
# load existing model
if hps.cont:
try:
_, _, _, epoch_str = utils.load_checkpoint("./OUTPUT_MODEL/G_latest.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./OUTPUT_MODEL/D_latest.pth", net_d, None)
global_step = epoch_str * hps.train.batch_size
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth"), net_g, None)
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth"), net_d, None)
global_step = (epoch_str - 1) * len(train_loader)
except:
print("Failed to find latest checkpoint, loading G_0.pth...")
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
if hps.train_with_pretrained_model:
print("Train with pretrained model...")
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
else:
print("Train without pretrained model...")
epoch_str = 1
global_step = 0
else:
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
if hps.train_with_pretrained_model:
print("Train with pretrained model...")
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
else:
print("Train without pretrained model...")
epoch_str = 1
global_step = 0
# freeze all other layers except speaker embedding
@@ -256,13 +264,16 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "G_latest.pth".format(global_step)))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
old_g=os.path.join(hps.model_dir, "G_{}.pth".format(global_step-4000))
old_d=os.path.join(hps.model_dir, "D_{}.pth".format(global_step-4000))
# utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
# os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
old_g = utils.oldest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth",
preserved=hps.preserved) # Preserve 4 (default) historical checkpoints.
old_d = utils.oldest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth", preserved=hps.preserved)
if os.path.exists(old_g):
print(f"remove {old_g}")
os.remove(old_g)
if os.path.exists(old_d):
print(f"remove {old_d}")
os.remove(old_d)
global_step += 1
if epoch > hps.max_epochs: