diff --git a/msclap/CLAPWrapper.py b/msclap/CLAPWrapper.py index c899be3..b75eb37 100644 --- a/msclap/CLAPWrapper.py +++ b/msclap/CLAPWrapper.py @@ -155,7 +155,7 @@ class CLAPWrapper(): args.num_layers, args.normalize_prefix, args.mapping_type, True, True) model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model'] - clapcap.load_state_dict(model_state_dict) + clapcap.load_state_dict(model_state_dict, strict=False) clapcap.eval() # set clap in eval mode tokenizer = AutoTokenizer.from_pretrained(args.text_model)