Merge pull request #26 from hykilpikonna/main
[+] Automatically download model weights
This commit is contained in:
@@ -18,11 +18,6 @@ pip install msclap
|
||||
pip install git+https://github.com/microsoft/CLAP.git
|
||||
```
|
||||
|
||||
## NEW CLAP weights
|
||||
Download CLAP weights: versions _2022_, _2023_, and _clapcap_: [Zenodo](https://zenodo.org/record/8378278) or [HuggingFace](https://huggingface.co/microsoft/msclap)
|
||||
|
||||
_clapcap_ is the audio captioning model that uses the 2023 encoders.
|
||||
|
||||
## Usage
|
||||
|
||||
- Zero-Shot Classification and Retrieval
|
||||
@@ -30,7 +25,8 @@ _clapcap_ is the audio captioning model that uses the 2023 encoders.
|
||||
from msclap import CLAP
|
||||
|
||||
# Load model (Choose between versions '2022' and '2023')
|
||||
clap_model = CLAP("<PATH TO WEIGHTS>", version = '2023', use_cuda=False)
|
||||
# The model weights will be downloaded automatically if `model_fp` is not specified
|
||||
clap_model = CLAP(version = '2023', use_cuda=False)
|
||||
|
||||
# Extract text embeddings
|
||||
text_embeddings = clap_model.get_text_embeddings(class_labels: List[str])
|
||||
@@ -47,7 +43,7 @@ similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
||||
from msclap import CLAP
|
||||
|
||||
# Load model (Choose version 'clapcap')
|
||||
clap_model = CLAP("<PATH TO WEIGHTS>", version = 'clapcap', use_cuda=False)
|
||||
clap_model = CLAP(version = 'clapcap', use_cuda=False)
|
||||
|
||||
# Generate audio captions
|
||||
captions = clap_model.generate_caption(file_paths: List[str])
|
||||
|
||||
+22
-10
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
@@ -16,6 +18,7 @@ import torch
|
||||
import argparse
|
||||
import yaml
|
||||
import sys
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
logging.set_verbosity_error()
|
||||
|
||||
|
||||
@@ -23,15 +26,30 @@ class CLAPWrapper():
|
||||
"""
|
||||
A class for interfacing CLAP model.
|
||||
"""
|
||||
model_repo = "microsoft/msclap"
|
||||
model_name = {
|
||||
'2022': 'CLAP_weights_2022.pth',
|
||||
'2023': 'CLAP_weights_2023.pth',
|
||||
'clapcap': 'clapcap_weights_2023.pth'
|
||||
}
|
||||
|
||||
def __init__(self, model_fp: Path | str | None = None, version: str = '2023', use_cuda=False):
|
||||
# Check if version is supported
|
||||
self.supported_versions = self.model_name.keys()
|
||||
if version not in self.supported_versions:
|
||||
raise ValueError(f"The version {version} is not supported. The supported versions are {str(self.supported_versions)}")
|
||||
|
||||
def __init__(self, model_fp, version, use_cuda=False):
|
||||
self.supported_versions = ['2022', '2023', 'clapcap']
|
||||
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
||||
self.file_path = os.path.realpath(__file__)
|
||||
self.default_collate_err_msg_format = (
|
||||
"default_collate: batch must contain tensors, numpy arrays, numbers, "
|
||||
"dicts or lists; found {}")
|
||||
self.config_as_str = self.get_config_path(version)
|
||||
self.config_as_str = (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
|
||||
|
||||
# Automatically download model if not provided
|
||||
if not model_fp:
|
||||
model_fp = hf_hub_download(self.model_repo, self.model_name[version])
|
||||
|
||||
self.model_fp = model_fp
|
||||
self.use_cuda = use_cuda
|
||||
if 'clapcap' in version:
|
||||
@@ -39,12 +57,6 @@ class CLAPWrapper():
|
||||
else:
|
||||
self.clap, self.tokenizer, self.args = self.load_clap()
|
||||
|
||||
def get_config_path(self, version):
|
||||
if version in self.supported_versions:
|
||||
return (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
|
||||
else:
|
||||
raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
|
||||
|
||||
def read_config_as_args(self,config_path,args=None,is_config_str=False):
|
||||
return_dict = {}
|
||||
|
||||
@@ -143,7 +155,7 @@ class CLAPWrapper():
|
||||
args.num_layers, args.normalize_prefix, args.mapping_type, True, True)
|
||||
|
||||
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
|
||||
clapcap.load_state_dict(model_state_dict)
|
||||
clapcap.load_state_dict(model_state_dict, strict=False)
|
||||
|
||||
clapcap.eval() # set clap in eval mode
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
|
||||
|
||||
+2
-2
@@ -1,8 +1,8 @@
|
||||
[tool.poetry]
|
||||
name = "msclap"
|
||||
version = "1.3.2"
|
||||
version = "1.3.3"
|
||||
description = "CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning."
|
||||
authors = ["Benjamin Elizalde and Soham Deshmukh and Huaming Wang"]
|
||||
authors = ["Benjamin Elizalde", "Soham Deshmukh", "Huaming Wang"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
packages = [
|
||||
|
||||
Reference in New Issue
Block a user