From ea5629de26ee6a9f288f28b1819612ea47e05e86 Mon Sep 17 00:00:00 2001 From: Azalea <22280294+hykilpikonna@users.noreply.github.com> Date: Wed, 11 Oct 2023 18:54:51 -0400 Subject: [PATCH] [O] Allow backward-compatibility for torch state dict --- src/CLAPWrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CLAPWrapper.py b/src/CLAPWrapper.py index 654e775..06b3f67 100644 --- a/src/CLAPWrapper.py +++ b/src/CLAPWrapper.py @@ -98,7 +98,7 @@ class CLAPWrapper(): # We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`: # Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005 - clap.load_state_dict(model_state_dict) + clap.load_state_dict(model_state_dict, strict=False) clap.eval() # set clap in eval mode tokenizer = AutoTokenizer.from_pretrained(args.text_model)