[O] Allow backward-compatibility for torch state dict

This commit is contained in:
2023-10-11 18:54:51 -04:00
parent d47289931d
commit ea5629de26
+1 -1
View File
@@ -98,7 +98,7 @@ class CLAPWrapper():
# We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to be unwrapped before `load_state_dict`:
# Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
clap.load_state_dict(model_state_dict)
clap.load_state_dict(model_state_dict, strict=False)
clap.eval() # set clap in eval mode
tokenizer = AutoTokenizer.from_pretrained(args.text_model)