Merge pull request #26 from hykilpikonna/main

[+] Automatically download model weights
This commit is contained in:
Benjamin Elizalde
2023-10-17 15:31:50 -07:00
committed by GitHub
3 changed files with 27 additions and 19 deletions
+3 -7
View File
@@ -18,11 +18,6 @@ pip install msclap
pip install git+https://github.com/microsoft/CLAP.git
```
## NEW CLAP weights
Download the CLAP weights for versions _2022_, _2023_, and _clapcap_ from [Zenodo](https://zenodo.org/record/8378278) or [HuggingFace](https://huggingface.co/microsoft/msclap)
_clapcap_ is the audio captioning model that uses the 2023 encoders.
## Usage
- Zero-Shot Classification and Retrieval
@@ -30,7 +25,8 @@ _clapcap_ is the audio captioning model that uses the 2023 encoders.
from msclap import CLAP
# Load model (Choose between versions '2022' or '2023')
clap_model = CLAP("<PATH TO WEIGHTS>", version = '2023', use_cuda=False)
# The model weight will be downloaded automatically if `model_fp` is not specified
clap_model = CLAP(version = '2023', use_cuda=False)
# Extract text embeddings
text_embeddings = clap_model.get_text_embeddings(class_labels: List[str])
@@ -47,7 +43,7 @@ similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
from msclap import CLAP
# Load model (Choose version 'clapcap')
clap_model = CLAP("<PATH TO WEIGHTS>", version = 'clapcap', use_cuda=False)
clap_model = CLAP(version = 'clapcap', use_cuda=False)
# Generate audio captions
captions = clap_model.generate_caption(file_paths: List[str])
+22 -10
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
@@ -16,6 +18,7 @@ import torch
import argparse
import yaml
import sys
from huggingface_hub.file_download import hf_hub_download
logging.set_verbosity_error()
@@ -23,15 +26,30 @@ class CLAPWrapper():
"""
    A class for interfacing with the CLAP model.
"""
model_repo = "microsoft/msclap"
model_name = {
'2022': 'CLAP_weights_2022.pth',
'2023': 'CLAP_weights_2023.pth',
'clapcap': 'clapcap_weights_2023.pth'
}
def __init__(self, model_fp: Path | str | None = None, version: str = '2023', use_cuda=False):
# Check if version is supported
self.supported_versions = self.model_name.keys()
if version not in self.supported_versions:
raise ValueError(f"The version {version} is not supported. The supported versions are {str(self.supported_versions)}")
def __init__(self, model_fp, version, use_cuda=False):
self.supported_versions = ['2022', '2023', 'clapcap']
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
self.file_path = os.path.realpath(__file__)
self.default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
self.config_as_str = self.get_config_path(version)
self.config_as_str = (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
# Automatically download model if not provided
if not model_fp:
model_fp = hf_hub_download(self.model_repo, self.model_name[version])
self.model_fp = model_fp
self.use_cuda = use_cuda
if 'clapcap' in version:
@@ -39,12 +57,6 @@ class CLAPWrapper():
else:
self.clap, self.tokenizer, self.args = self.load_clap()
def get_config_path(self, version):
if version in self.supported_versions:
return (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
else:
raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
def read_config_as_args(self,config_path,args=None,is_config_str=False):
return_dict = {}
@@ -143,7 +155,7 @@ class CLAPWrapper():
args.num_layers, args.normalize_prefix, args.mapping_type, True, True)
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
clapcap.load_state_dict(model_state_dict)
clapcap.load_state_dict(model_state_dict, strict=False)
clapcap.eval() # set clap in eval mode
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
+2 -2
View File
@@ -1,8 +1,8 @@
[tool.poetry]
name = "msclap"
version = "1.3.2"
version = "1.3.3"
description = "CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning."
authors = ["Benjamin Elizalde and Soham Deshmukh and Huaming Wang"]
authors = ["Benjamin Elizalde", "Soham Deshmukh", "Huaming Wang"]
license = "MIT"
readme = "README.md"
packages = [