Checkpoints will be saved to Google Drive during training
This commit is contained in:
+23
-2
@@ -262,10 +262,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
evaluate(hps, net_g, eval_loader, writer_eval)
|
evaluate(hps, net_g, eval_loader, writer_eval)
|
||||||
|
|
||||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||||
os.path.join(hps.model_dir, "G_latest.pth".format(global_step)))
|
os.path.join(hps.model_dir, "G_latest.pth"))
|
||||||
|
|
||||||
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
||||||
os.path.join(hps.model_dir, "D_latest.pth".format(global_step)))
|
os.path.join(hps.model_dir, "D_latest.pth"))
|
||||||
|
# save to google drive
|
||||||
|
if os.path.exists("/content/drive/MyDrive/"):
|
||||||
|
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||||
|
os.path.join("/content/drive/MyDrive/", "G_latest.pth"))
|
||||||
|
|
||||||
|
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
||||||
|
os.path.join("/content/drive/MyDrive/", "D_latest.pth"))
|
||||||
if hps.preserved > 0:
|
if hps.preserved > 0:
|
||||||
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||||
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
|
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
|
||||||
@@ -280,6 +287,20 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
if os.path.exists(old_d):
|
if os.path.exists(old_d):
|
||||||
print(f"remove {old_d}")
|
print(f"remove {old_d}")
|
||||||
os.remove(old_d)
|
os.remove(old_d)
|
||||||
|
if os.path.exists("/content/drive/MyDrive/"):
|
||||||
|
utils.save_checkpoint(net_g, None, hps.train.learning_rate, epoch,
|
||||||
|
os.path.join("/content/drive/MyDrive/", "G_{}.pth".format(global_step)))
|
||||||
|
utils.save_checkpoint(net_d, None, hps.train.learning_rate, epoch,
|
||||||
|
os.path.join("/content/drive/MyDrive/", "D_{}.pth".format(global_step)))
|
||||||
|
old_g = utils.oldest_checkpoint_path("/content/drive/MyDrive/", "G_[0-9]*.pth",
|
||||||
|
preserved=hps.preserved) # Preserve 4 (default) historical checkpoints.
|
||||||
|
old_d = utils.oldest_checkpoint_path("/content/drive/MyDrive/", "D_[0-9]*.pth", preserved=hps.preserved)
|
||||||
|
if os.path.exists(old_g):
|
||||||
|
print(f"remove {old_g}")
|
||||||
|
os.remove(old_g)
|
||||||
|
if os.path.exists(old_d):
|
||||||
|
print(f"remove {old_d}")
|
||||||
|
os.remove(old_d)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if epoch > hps.max_epochs:
|
if epoch > hps.max_epochs:
|
||||||
print("Maximum epoch reached, closing training...")
|
print("Maximum epoch reached, closing training...")
|
||||||
|
|||||||
Reference in New Issue
Block a user