diff --git a/src/CLAPWrapper.py b/src/CLAPWrapper.py
index 654e775..06b3f67 100644
--- a/src/CLAPWrapper.py
+++ b/src/CLAPWrapper.py
@@ -98,7 +98,7 @@ class CLAPWrapper():
         # We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`:
         # Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
-        clap.load_state_dict(model_state_dict)
+        clap.load_state_dict(model_state_dict, strict=False)
         clap.eval()  # set clap in eval mode
 
         tokenizer = AutoTokenizer.from_pretrained(args.text_model)
 