Checkpoints will be saved to google drive during training
This commit is contained in:
+23
-2
@@ -262,10 +262,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
evaluate(hps, net_g, eval_loader, writer_eval)
|
||||
|
||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||
os.path.join(hps.model_dir, "G_latest.pth".format(global_step)))
|
||||
os.path.join(hps.model_dir, "G_latest.pth"))
|
||||
|
||||
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
||||
os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
|
||||
os.path.join(hps.model_dir, "D_latest.pth"))
|
||||
# save to google drive
|
||||
if os.path.exists("/content/drive/MyDrive/"):
|
||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||
os.path.join("/content/drive/MyDrive/", "G_latest.pth"))
|
||||
|
||||
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
||||
os.path.join("/content/drive/MyDrive/", "D_latest.pth"))
|
||||
if hps.preserved > 0:
|
||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
|
||||
@@ -280,6 +287,20 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
if os.path.exists(old_d):
|
||||
print(f"remove {old_d}")
|
||||
os.remove(old_d)
|
||||
if os.path.exists("/content/drive/MyDrive/"):
|
||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||
os.path.join("/content/drive/MyDrive/", "G_{}.pth".format(global_step)))
|
||||
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
||||
os.path.join("/content/drive/MyDrive/", "D_{}.pth".format(global_step)))
|
||||
old_g = utils.oldest_checkpoint_path("/content/drive/MyDrive/", "G_[0-9]*.pth",
|
||||
preserved=hps.preserved) # Preserve 4 (default) historical checkpoints.
|
||||
old_d = utils.oldest_checkpoint_path("/content/drive/MyDrive/", "D_[0-9]*.pth", preserved=hps.preserved)
|
||||
if os.path.exists(old_g):
|
||||
print(f"remove {old_g}")
|
||||
os.remove(old_g)
|
||||
if os.path.exists(old_d):
|
||||
print(f"remove {old_d}")
|
||||
os.remove(old_d)
|
||||
global_step += 1
|
||||
if epoch > hps.max_epochs:
|
||||
print("Maximum epoch reached, closing training...")
|
||||
|
||||
Reference in New Issue
Block a user