Added capability to continue training from previous checkpoints

This commit is contained in:
Plachta
2023-06-12 19:25:30 +08:00
parent 9bbc9e9246
commit cb1f29d1ed
+3 -2
View File
@@ -100,10 +100,11 @@ def run(rank, n_gpus, hps):
# load existing model
if hps.cont:
try:
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/G_latest.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/D_latest.pth", net_d, None)
_, _, _, epoch_str = utils.load_checkpoint("./OUTPUT_MODEL/G_latest.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./OUTPUT_MODEL/D_latest.pth", net_d, None)
global_step = epoch_str * hps.train.batch_size
except:
print("Failed to find latest checkpoint, loading G_0.pth...")
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/G_0.pth", net_g, None)
_, _, _, epoch_str = utils.load_checkpoint("./pretrained_models/D_0.pth", net_d, None)
epoch_str = 1