Checkpoints will be saved to Google Drive during training

This commit is contained in:
Plachta
2023-07-13 20:33:47 +08:00
parent cec3206028
commit a5a0fed4e1
+23 -2
View File
@@ -262,10 +262,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
evaluate(hps, net_g, eval_loader, writer_eval) evaluate(hps, net_g, eval_loader, writer_eval)
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "G_latest.pth".format(global_step))) os.path.join(hps.model_dir, "G_latest.pth"))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch, utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "D_latest.pth".format(global_step))) os.path.join(hps.model_dir, "D_latest.pth"))
# save to google drive
if os.path.exists("/content/drive/MyDrive/"):
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join("/content/drive/MyDrive/", "G_latest.pth"))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
os.path.join("/content/drive/MyDrive/", "D_latest.pth"))
if hps.preserved > 0: if hps.preserved > 0:
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch, utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
@@ -280,6 +287,20 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if os.path.exists(old_d): if os.path.exists(old_d):
print(f"remove {old_d}") print(f"remove {old_d}")
os.remove(old_d) os.remove(old_d)
if os.path.exists("/content/drive/MyDrive/"):
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
os.path.join("/content/drive/MyDrive/", "G_{}.pth".format(global_step)))
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
os.path.join("/content/drive/MyDrive/", "D_{}.pth".format(global_step)))
old_g = utils.oldest_checkpoint_path("/content/drive/MyDrive/", "G_[0-9]*.pth",
preserved=hps.preserved) # Preserve 4 (default) historical checkpoints.
old_d = utils.oldest_checkpoint_path("/content/drive/MyDrive/", "D_[0-9]*.pth", preserved=hps.preserved)
if os.path.exists(old_g):
print(f"remove {old_g}")
os.remove(old_g)
if os.path.exists(old_d):
print(f"remove {old_d}")
os.remove(old_d)
global_step += 1 global_step += 1
if epoch > hps.max_epochs: if epoch > hps.max_epochs:
print("Maximum epoch reached, closing training...") print("Maximum epoch reached, closing training...")