diff --git a/README.md b/README.md index 527da6d..1ad133e 100644 --- a/README.md +++ b/README.md @@ -18,11 +18,6 @@ pip install msclap pip install git+https://github.com/microsoft/CLAP.git ``` -## NEW CLAP weights -Download CLAP weights: versions _2022_, _2023_, and _clapcap_: [Zenodo](https://zenodo.org/record/8378278) or [HuggingFace](https://huggingface.co/microsoft/msclap) - -_clapcap_ is the audio captioning model that uses the 2023 encoders. - ## Usage - Zero-Shot Classification and Retrieval @@ -30,7 +25,8 @@ _clapcap_ is the audio captioning model that uses the 2023 encoders. from msclap import CLAP # Load model (Choose between versions '2022' or '2023') -clap_model = CLAP("", version = '2023', use_cuda=False) +# The model weights will be downloaded automatically if `model_fp` is not specified +clap_model = CLAP(version = '2023', use_cuda=False) # Extract text embeddings text_embeddings = clap_model.get_text_embeddings(class_labels: List[str]) @@ -47,7 +43,7 @@ similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings) from msclap import CLAP # Load model (Choose version 'clapcap') -clap_model = CLAP("", version = 'clapcap', use_cuda=False) +clap_model = CLAP(version = 'clapcap', use_cuda=False) # Generate audio captions captions = clap_model.generate_caption(file_paths: List[str]) diff --git a/msclap/CLAPWrapper.py b/msclap/CLAPWrapper.py index 542ab6f..b75eb37 100644 --- a/msclap/CLAPWrapper.py +++ b/msclap/CLAPWrapper.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import warnings warnings.filterwarnings("ignore") @@ -16,6 +18,7 @@ import torch import argparse import yaml import sys +from huggingface_hub.file_download import hf_hub_download logging.set_verbosity_error() @@ -23,15 +26,30 @@ class CLAPWrapper(): """ A class for interfacing CLAP model. 
""" + model_repo = "microsoft/msclap" + model_name = { + '2022': 'CLAP_weights_2022.pth', + '2023': 'CLAP_weights_2023.pth', + 'clapcap': 'clapcap_weights_2023.pth' + } + + def __init__(self, model_fp: Path | str | None = None, version: str = '2023', use_cuda=False): + # Check if version is supported + self.supported_versions = self.model_name.keys() + if version not in self.supported_versions: + raise ValueError(f"The version {version} is not supported. The supported versions are {str(self.supported_versions)}") - def __init__(self, model_fp, version, use_cuda=False): - self.supported_versions = ['2022', '2023', 'clapcap'] self.np_str_obj_array_pattern = re.compile(r'[SaUO]') self.file_path = os.path.realpath(__file__) self.default_collate_err_msg_format = ( "default_collate: batch must contain tensors, numpy arrays, numbers, " "dicts or lists; found {}") - self.config_as_str = self.get_config_path(version) + self.config_as_str = (Path(__file__).parent / f"configs/config_{version}.yml").read_text() + + # Automatically download model if not provided + if not model_fp: + model_fp = hf_hub_download(self.model_repo, self.model_name[version]) + self.model_fp = model_fp self.use_cuda = use_cuda if 'clapcap' in version: @@ -39,12 +57,6 @@ class CLAPWrapper(): else: self.clap, self.tokenizer, self.args = self.load_clap() - def get_config_path(self, version): - if version in self.supported_versions: - return (Path(__file__).parent / f"configs/config_{version}.yml").read_text() - else: - raise ValueError(f"The specific version is not supported. 
The supported versions are {str(self.supported_versions)}") - def read_config_as_args(self,config_path,args=None,is_config_str=False): return_dict = {} @@ -143,7 +155,7 @@ class CLAPWrapper(): args.num_layers, args.normalize_prefix, args.mapping_type, True, True) model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model'] - clapcap.load_state_dict(model_state_dict) + clapcap.load_state_dict(model_state_dict, strict=False) clapcap.eval() # set clap in eval mode tokenizer = AutoTokenizer.from_pretrained(args.text_model) diff --git a/pyproject.toml b/pyproject.toml index 4efa066..5cfa2ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,8 @@ [tool.poetry] name = "msclap" -version = "1.3.2" +version = "1.3.3" description = "CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning." -authors = ["Benjamin Elizalde and Soham Deshmukh and Huaming Wang"] +authors = ["Benjamin Elizalde", "Soham Deshmukh", "Huaming Wang"] license = "MIT" readme = "README.md" packages = [