update: 可选是否保存旧模型

This commit is contained in:
Emberstar
2023-06-14 15:33:00 +08:00
parent 6d65db1f76
commit a9b43a8afc
+20 -15
View File
@@ -100,8 +100,8 @@ def run(rank, n_gpus, hps):
# load existing model
if hps.cont:
try:
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth"), net_g, None)
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth"), net_d, None)
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_latest.pth"), net_g, None)
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_latest.pth"), net_d, None)
global_step = (epoch_str - 1) * len(train_loader)
except:
print("Failed to find latest checkpoint, loading G_0.pth...")
@@ -260,21 +260,26 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if global_step % hps.train.eval_interval == 0:
evaluate(hps, net_g, eval_loader, writer_eval)
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "G_latest.pth".format(global_step)))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
# utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
# os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
old_g = utils.oldest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth",
preserved=hps.preserved) # Preserve 4 (default) historical checkpoints.
old_d = utils.oldest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth", preserved=hps.preserved)
if os.path.exists(old_g):
print(f"remove {old_g}")
os.remove(old_g)
if os.path.exists(old_d):
print(f"remove {old_d}")
os.remove(old_d)
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
if hps.preserved > 0:
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
old_g = utils.oldest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth",
preserved=hps.preserved) # Preserve 4 (default) historical checkpoints.
old_d = utils.oldest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth", preserved=hps.preserved)
if os.path.exists(old_g):
print(f"remove {old_g}")
os.remove(old_g)
if os.path.exists(old_d):
print(f"remove {old_d}")
os.remove(old_d)
global_step += 1
if epoch > hps.max_epochs:
print("Maximum epoch reached, closing training...")