update: make saving old model checkpoints optional
+20 -15
@@ -100,8 +100,8 @@ def run(rank, n_gpus, hps):
     # load existing model
     if hps.cont:
         try:
-            _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth"), net_g, None)
-            _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth"), net_d, None)
+            _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_latest.pth"), net_g, None)
+            _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_latest.pth"), net_d, None)
             global_step = (epoch_str - 1) * len(train_loader)
         except:
             print("Failed to find latest checkpoint, loading G_0.pth...")
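With this change, resuming via hps.cont loads the fixed-name G_latest.pth / D_latest.pth instead of globbing for the newest numbered checkpoint, so training can resume even when no numbered snapshots are kept. utils.latest_checkpoint_path itself is not shown in this diff; below is a minimal sketch of a glob-based helper with that signature (the mtime sort key is an assumption, not the repo's exact code):

import glob
import os


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    # Sketch only: return the newest file matching the pattern.
    # A literal name such as "G_latest.pth" matches at most one file,
    # which is then returned as-is.
    f_list = glob.glob(os.path.join(dir_path, regex))
    f_list.sort(key=os.path.getmtime)  # assumption: newest by modification time
    return f_list[-1]  # IndexError when nothing matches, caught by the caller's try/except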
@@ -260,21 +260,26 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
             if global_step % hps.train.eval_interval == 0:
                 evaluate(hps, net_g, eval_loader, writer_eval)
-                utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))

                 utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
                                       os.path.join(hps.model_dir, "G_latest.pth".format(global_step)))
-                utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
-                # utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
-                #                       os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
-                old_g = utils.oldest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth",
-                                                     preserved=hps.preserved)  # Preserve 4 (default) historical checkpoints.
-                old_d = utils.oldest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth", preserved=hps.preserved)
-                if os.path.exists(old_g):
-                    print(f"remove {old_g}")
-                    os.remove(old_g)
-                if os.path.exists(old_d):
-                    print(f"remove {old_d}")
-                    os.remove(old_d)
+                utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
+                                      os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
+                if hps.preserved > 0:
+                    utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
+                                          os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
+                    utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
+                                          os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
+                    old_g = utils.oldest_checkpoint_path(hps.model_dir, "G_[0-9]*.pth",
+                                                         preserved=hps.preserved)  # Preserve 4 (default) historical checkpoints.
+                    old_d = utils.oldest_checkpoint_path(hps.model_dir, "D_[0-9]*.pth", preserved=hps.preserved)
+                    if os.path.exists(old_g):
+                        print(f"remove {old_g}")
+                        os.remove(old_g)
+                    if os.path.exists(old_d):
+                        print(f"remove {old_d}")
+                        os.remove(old_d)
         global_step += 1
     if epoch > hps.max_epochs:
         print("Maximum epoch reached, closing training...")