Compare commits
38 Commits
| SHA1 |
|---|
| bd44e422f5 |
| e8a6467b87 |
| 2614a64a7f |
| 8dbe06ee63 |
| 59bc8446e3 |
| d8b816ef43 |
| 1af3d710ee |
| d32b134d27 |
| 8e87029c0e |
| 936db19552 |
| 4a30b77b48 |
| 8e6c8e632c |
| f3d8c311d3 |
| 286156725f |
| ac2f4ba44a |
| 29bf721dbb |
| f5fb6c6b6a |
| e2187633b8 |
| 8702caf400 |
| 1ec382a713 |
| 0e37184ad1 |
| 3f6ef7382f |
| 549ef40479 |
| 765f9f8864 |
| 9f902c9029 |
| bb3bfb4e5b |
| b41935ff3c |
| 3788d4e225 |
| f177c8203e |
| 23519ea1e6 |
| eeaa2a3a34 |
| ea5629de26 |
| d47289931d |
| 21c7b38cd4 |
| 81cc62ef32 |
| 583928827d |
| 83a344004b |
| 03703c0e91 |
@@ -348,3 +348,9 @@ MigrationBackup/
 
 # Ionide (cross platform F# VS Code tools) working folder
 .ionide/
+dist/
+.DS_Store
+._*
+
+venv/
+root_path/
@@ -4,24 +4,22 @@
 
 CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning.
 
-<img width="832" alt="clap_diagrams" src="https://github.com/bmartin1/CLAP/assets/26778834/c5340a09-cc0c-4e41-ad5a-61546eaa824c">
+<img width="832" alt="clap_diagrams" src="docs/clap2_diagram.png">
 
 ## Setup
 
-Install the dependencies: `pip install -r requirements.txt` using Python 3 to get started.
 
-If you have [conda](https://www.anaconda.com) installed, you can run the following:
+First, install python 3.8 or higher (3.11 recommended). Then, install CLAP using either of the following:
 
 ```shell
-git clone https://github.com/microsoft/CLAP.git && \
-cd CLAP && \
-conda create -n clap python=3.10 && \
-conda activate clap && \
-pip install -r requirements.txt
+# Install pypi pacakge
+pip install msclap
+
+# Or Install latest (unstable) git source
+pip install git+https://github.com/microsoft/CLAP.git
 ```
 
-## NEW CLAP weights
-Download CLAP weights: versions _2022_, _2023_, and _clapcap_: [Pretrained Model \[Zenodo\]](https://zenodo.org/record/8378278)
+## CLAP weights
+CLAP weights are downloaded automatically (choose between versions _2022_, _2023_, and _clapcap_), but are also available at: [Zenodo](https://zenodo.org/record/8378278) or [HuggingFace](https://huggingface.co/microsoft/msclap)
 
 _clapcap_ is the audio captioning model that uses the 2023 encoders.
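Since the new constructor takes an optional `model_fp` (see the `CLAPWrapper.__init__` hunk further below), weights fetched manually from Zenodo or HuggingFace can still be used by pointing the wrapper at a local file. A minimal sketch, assuming the 2023 checkpoint has been saved to a hypothetical local path:

```python
from msclap import CLAP

# "CLAP_weights_2023.pth" is a hypothetical local path; when model_fp is omitted,
# the weights are fetched from the HuggingFace hub automatically.
clap_model = CLAP(model_fp="CLAP_weights_2023.pth", version='2023', use_cuda=False)
```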
@@ -29,10 +27,11 @@ _clapcap_ is the audio captioning model that uses the 2023 encoders.
 
 - Zero-Shot Classification and Retrieval
 ```python
-# Load model (Choose between versions '2022' or '2023')
-from src import CLAP
+from msclap import CLAP
 
-clap_model = CLAP("<PATH TO WEIGHTS>", version = '2023', use_cuda=False)
+# Load model (Choose between versions '2022' or '2023')
+# The model weight will be downloaded automatically if `model_fp` is not specified
+clap_model = CLAP(version = '2023', use_cuda=False)
 
 # Extract text embeddings
 text_embeddings = clap_model.get_text_embeddings(class_labels: List[str])
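The next hunk picks up this snippet at the similarity computation. For orientation, a minimal end-to-end sketch of the zero-shot flow is below; it assumes the usual `get_audio_embeddings` counterpart to `get_text_embeddings` (not shown in this diff) and uses illustrative labels and file paths:

```python
import torch.nn.functional as F
from msclap import CLAP

clap_model = CLAP(version='2023', use_cuda=False)

class_labels = ['dog barking', 'rain', 'siren']  # illustrative labels
audio_files = ['example.wav']                    # hypothetical audio path

text_embeddings = clap_model.get_text_embeddings(class_labels)
audio_embeddings = clap_model.get_audio_embeddings(audio_files)  # assumed counterpart to get_text_embeddings

# Rank the labels for each clip by audio-text similarity
similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
probs = F.softmax(similarities, dim=1)
print(class_labels[probs.argmax(dim=1)[0].item()])
```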
@@ -46,22 +45,22 @@ similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
 
 - Audio Captioning
 ```python
-# Load model (Choose version 'clapcap')
-from src import CLAP
+from msclap import CLAP
 
-clap_model = CLAP("<PATH TO WEIGHTS>", version = 'clapcap', use_cuda=False)
+# Load model (Choose version 'clapcap')
+clap_model = CLAP(version = 'clapcap', use_cuda=False)
 
 # Generate audio captions
 captions = clap_model.generate_caption(file_paths: List[str])
 ```
 
 ## Examples
-Take a look at `CLAP\src\` for usage examples.
+Take a look at [examples](./examples/) for usage examples.
 
 To run Zero-Shot Classification on the ESC50 dataset try the following:
 
 ```bash
-> cd src && python zero_shot_classification.py
+> cd examples && python zero_shot_classification.py
 ```
 Output (version 2023)
 ```bash
Binary file not shown. (image added; size: 277 KiB)
Binary file not shown. (image added; size: 81 KiB)
@@ -1,11 +1,10 @@
 """
 This is an example using CLAPCAP for audio captioning.
 """
-from CLAPWrapper import CLAPWrapper
+from msclap import CLAP
 
 # Load and initialize CLAP
-weights_path = "weights_path"
-clap_model = CLAPWrapper(weights_path, version = 'clapcap', use_cuda=False)
+clap_model = CLAP(version = 'clapcap', use_cuda=False)
 
 #Load audio files
 audio_files = ['audio_file']
@@ -1,6 +1,6 @@
 from torch.utils.data import Dataset
-from torchvision.datasets.utils import download_url
 from tqdm import tqdm
+from pathlib import Path
 import pandas as pd
 import os
 import torch.nn as nn
@@ -74,9 +74,29 @@ class ESC50(AudioDataset):
         return len(self.audio_paths)
 
     def download(self):
-        download_url(self.url, self.root, self.filename)
+        # Download file using requests
+        import requests
+        file = Path(self.root) / self.filename
+        if file.is_file():
+            return
+
+        r = requests.get(self.url, stream=True)
 
-        # extract file
+        # To prevent partial downloads, download to a temp file first
+        tmp = file.with_suffix('.tmp')
+        tmp.parent.mkdir(parents=True, exist_ok=True)
+        with open(tmp, 'wb') as f:
+            pbar = tqdm(unit=" MB", bar_format=f'{file.name}: {{rate_noinv_fmt}}')
+
+            for chunk in r.iter_content(chunk_size=1024):
+                if chunk:
+                    pbar.update(len(chunk) / 1024 / 1024)
+                    f.write(chunk)
+
+        # move temp file to correct location
+        tmp.rename(file)
+
+        # # extract file
         from zipfile import ZipFile
         with ZipFile(os.path.join(self.root, self.filename), 'r') as zip:
             zip.extractall(path=self.root)
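The new `download()` streams the archive with `requests`, writes it to a `.tmp` file, renames it into place only once the transfer finishes, and then extracts it, so an interrupted run never leaves a partially downloaded file under the final name. A standalone sketch of the same pattern, with hypothetical names (`fetch_zip`, `url`, `dest`):

```python
from pathlib import Path
from zipfile import ZipFile

import requests
from tqdm import tqdm

def fetch_zip(url: str, dest: Path) -> None:
    """Download url to dest via a temp file, then extract it next to dest."""
    if dest.is_file():
        return  # already downloaded
    tmp = dest.with_suffix('.tmp')
    tmp.parent.mkdir(parents=True, exist_ok=True)
    r = requests.get(url, stream=True)
    with open(tmp, 'wb') as f:
        pbar = tqdm(unit=" MB", bar_format=f'{dest.name}: {{rate_noinv_fmt}}')
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                pbar.update(len(chunk) / 1024 / 1024)
                f.write(chunk)
    tmp.rename(dest)  # the final filename appears only after the download completed
    with ZipFile(dest, 'r') as zf:
        zf.extractall(path=dest.parent)
```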
@@ -3,7 +3,7 @@ This is an example using CLAP to perform zeroshot
 classification on ESC50 (https://github.com/karolpiczak/ESC-50).
 """
 
-from CLAPWrapper import CLAPWrapper
+from msclap import CLAP
 from esc50_dataset import ESC50
 import torch.nn.functional as F
 import numpy as np
@@ -17,8 +17,7 @@ prompt = 'this is the sound of '
 y = [prompt + x for x in dataset.classes]
 
 # Load and initialize CLAP
-weights_path = "weights_path"
-clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
+clap_model = CLAP(version = '2023', use_cuda=False)
 
 # Computing text embeddings
 text_embeddings = clap_model.get_text_embeddings(y)
@@ -1,7 +1,7 @@
 """
 This is an example using CLAP for zero-shot inference.
 """
-from CLAPWrapper import CLAPWrapper
+from msclap import CLAP
 import torch.nn.functional as F
 
 # Define classes for zero-shot
@@ -15,9 +15,8 @@ class_prompts = [prompt + x for x in classes]
 audio_files = ['audio_file']
 
 # Load and initialize CLAP
-weights_path = "weights_path"
 # Setting use_cuda = True will load the model on a GPU using CUDA
-clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
+clap_model = CLAP(version = '2023', use_cuda=False)
 
 # compute text embeddings from natural text
 text_embeddings = clap_model.get_text_embeddings(class_prompts)
@@ -1,22 +1,24 @@
+from __future__ import annotations
+
+from pathlib import Path
 import warnings
 warnings.filterwarnings("ignore")
 import random
 import torchaudio
-from torch._six import string_classes
 import collections
 import re
 import numpy as np
 from transformers import AutoTokenizer, logging
-from models.clap import CLAP
-from models.mapper import get_clapcap
+from .models.clap import CLAP
+from .models.mapper import get_clapcap
 import math
 import torchaudio.transforms as T
 import os
 import torch
-from importlib_resources import files
 import argparse
 import yaml
 import sys
+from huggingface_hub.file_download import hf_hub_download
 logging.set_verbosity_error()
 
 
@@ -24,15 +26,30 @@ class CLAPWrapper():
     """
     A class for interfacing CLAP model.
     """
+    model_repo = "microsoft/msclap"
+    model_name = {
+        '2022': 'CLAP_weights_2022.pth',
+        '2023': 'CLAP_weights_2023.pth',
+        'clapcap': 'clapcap_weights_2023.pth'
+    }
 
+    def __init__(self, model_fp: Path | str | None = None, version: str = '2023', use_cuda=False):
+        # Check if version is supported
+        self.supported_versions = self.model_name.keys()
+        if version not in self.supported_versions:
+            raise ValueError(f"The version {version} is not supported. The supported versions are {str(self.supported_versions)}")
+
-    def __init__(self, model_fp, version, use_cuda=False):
-        self.supported_versions = ['2022', '2023', 'clapcap']
         self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
         self.file_path = os.path.realpath(__file__)
         self.default_collate_err_msg_format = (
             "default_collate: batch must contain tensors, numpy arrays, numbers, "
             "dicts or lists; found {}")
-        self.config_as_str = self.get_config_path(version)
+        self.config_as_str = (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
+
+        # Automatically download model if not provided
+        if not model_fp:
+            model_fp = hf_hub_download(self.model_repo, self.model_name[version])
+
         self.model_fp = model_fp
         self.use_cuda = use_cuda
         if 'clapcap' in version:
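With `model_repo` and `model_name` defined as class attributes, the fallback when `model_fp` is not given is a single `hf_hub_download` call, which downloads the checkpoint once and then serves it from the local HuggingFace cache. A minimal sketch of what that call resolves to:

```python
from huggingface_hub.file_download import hf_hub_download

# First call downloads to the HuggingFace cache; later calls reuse the cached file.
weights_path = hf_hub_download("microsoft/msclap", "CLAP_weights_2023.pth")
print(weights_path)  # local path to the cached checkpoint
```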
@@ -40,12 +57,6 @@ class CLAPWrapper():
         else:
             self.clap, self.tokenizer, self.args = self.load_clap()
 
-    def get_config_path(self, version):
-        if version in self.supported_versions:
-            return files('configs').joinpath(f"config_{version}.yml").read_text()
-        else:
-            raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
-
     def read_config_as_args(self,config_path,args=None,is_config_str=False):
         return_dict = {}
 
@@ -99,7 +110,7 @@ class CLAPWrapper():
 
         # We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`:
         # Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
-        clap.load_state_dict(model_state_dict)
+        clap.load_state_dict(model_state_dict, strict=False)
 
         clap.eval() # set clap in eval mode
         tokenizer = AutoTokenizer.from_pretrained(args.text_model)
@@ -144,7 +155,7 @@ class CLAPWrapper():
                                 args.num_layers, args.normalize_prefix, args.mapping_type, True, True)
 
         model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
-        clapcap.load_state_dict(model_state_dict)
+        clapcap.load_state_dict(model_state_dict, strict=False)
 
         clapcap.eval() # set clap in eval mode
         tokenizer = AutoTokenizer.from_pretrained(args.text_model)
@@ -184,7 +195,7 @@ class CLAPWrapper():
             return torch.tensor(batch, dtype=torch.float64)
         elif isinstance(elem, int):
             return torch.tensor(batch)
-        elif isinstance(elem, string_classes):
+        elif isinstance(elem, str):
             return batch
         elif isinstance(elem, collections.abc.Mapping):
             return {key: self.default_collate([d[key] for d in batch]) for key in elem}
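`torch._six` (which provided `string_classes`) is gone from recent PyTorch releases, and on Python 3 the only string type the collate function needs to special-case is the built-in `str`, so the isinstance check can use it directly. A simplified stand-in for the branch above, under that assumption (this is not the full `default_collate` method):

```python
import collections.abc
import torch

def collate_sketch(batch):
    elem = batch[0]
    if isinstance(elem, str):  # replaces isinstance(elem, string_classes)
        return batch           # strings pass through unchanged
    elif isinstance(elem, collections.abc.Mapping):
        return {key: collate_sketch([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, (int, float)):
        return torch.tensor(batch)
    raise TypeError(f"unsupported element type: {type(elem)}")
```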
@@ -301,7 +312,7 @@ class CLAPWrapper():
             # batch size is bigger than available audio/text items
             if next_batch_idx >= args0_len:
                 inputs[0] = input_tmp[dataset_idx:]
-                return func(*tuple(inputs))
+                yield func(*tuple(inputs))
             else:
                 inputs[0] = input_tmp[dataset_idx:next_batch_idx]
                 yield func(*tuple(inputs))
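The function this hunk touches is a generator (the `else` branch already uses `yield`), and a `return value` inside a generator does not emit an item; it only ends iteration, so the old code silently dropped the final, possibly short batch. A toy sketch of the difference, with illustrative names:

```python
def batches_buggy(items, size):
    for i in range(0, len(items), size):
        chunk = items[i:i + size]
        if i + size >= len(items):
            return chunk  # generator stops here; the last chunk is never emitted
        yield chunk

def batches_fixed(items, size):
    for i in range(0, len(items), size):
        chunk = items[i:i + size]
        if i + size >= len(items):
            yield chunk   # the last (possibly short) chunk is emitted
            return
        yield chunk

print(list(batches_buggy([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4]] -- [5] is lost
print(list(batches_fixed([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4], [5]]
```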
@@ -0,0 +1 @@
+from .CLAPWrapper import CLAPWrapper as CLAP
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchlibrosa.stft import Spectrogram, LogmelFilterBank
-from models.htsat import HTSATWrapper
+from .htsat import HTSATWrapper
 
 def get_audio_encoder(name: str):
     if name == "Cnn14":
@@ -6,11 +6,8 @@
 # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
 
 
-import logging
-import pdb
 import math
 import random
-from numpy.core.fromnumeric import clip, reshape
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
@@ -19,15 +16,10 @@ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
 from torchlibrosa.augmentation import SpecAugmentation
 
 from itertools import repeat
 from typing import List
-try:
-    from models.pytorch_utils import do_mixup, interpolate
-    import models.config as config
-except:
-    from CLAP_API.models.pytorch_utils import do_mixup, interpolate
-    from CLAP_API.models import config
 
-import torch.nn.functional as F
+from .pytorch_utils import do_mixup, interpolate
+from . import config
 
 import collections.abc
 import warnings
@@ -2,10 +2,9 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as nnf
-from torch.utils.data import Dataset, DataLoader
 from enum import Enum
 from transformers import GPT2LMHeadModel
-from typing import Tuple, Optional, Union
+from typing import Tuple, Optional
 
 def get_clapcap(name: str):
     if name == "ClapCaption":
@@ -1,5 +1,3 @@
-import numpy as np
-import time
 import torch
 import torch.nn as nn
 
Generated file (+1899 lines): diff suppressed because it is too large.
@@ -0,0 +1,28 @@
+[tool.poetry]
+name = "msclap"
+version = "1.3.4"
+description = "CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning."
+authors = ["Benjamin Elizalde", "Soham Deshmukh", "Huaming Wang"]
+license = "MIT"
+readme = "README.md"
+packages = [
+    { include = "msclap" },
+]
+
+[tool.poetry.dependencies]
+python = "^3.8"
+librosa = "^0.10.1"
+numpy = "^1.23.0"
+pandas = "^2.0.0"
+torch = "^2.1.0"
+torchaudio = "^2.1.0"
+torchlibrosa = "^0.1.0"
+tqdm = "^4.66.1"
+transformers = "^4.34.0"
+pyyaml = "^6.0.1"
+scikit-learn = "^1.3.1"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
@@ -1,50 +0,0 @@
-appdirs==1.4.4
-audioread==3.0.0
-certifi==2022.12.7
-cffi==1.15.1
-charset-normalizer==3.0.1
-colorama==0.4.6
-decorator==5.1.1
-filelock==3.9.0
-flit_core==3.6.0
-huggingface-hub==0.12.1
-idna==3.4
-importlib-metadata==6.0.0
-importlib-resources==5.12.0
-jaraco.classes==3.2.3
-joblib==1.2.0
-lazy_loader==0.1
-librosa==0.10.0
-llvmlite==0.39.1
-mkl-service==2.4.0
-more-itertools==9.0.0
-msgpack==1.0.4
-numba==0.56.4
-numpy==1.23.5
-packaging==23.0
-pandas==1.4.2
-pooch==1.6.0
-pycparser==2.21
-pywin32-ctypes==0.2.0
-PyYAML==6.0
-regex==2022.10.31
-requests==2.28.2
-scikit-learn==1.2.1
-scipy==1.10.1
-setuptools==65.6.3
-six==1.16.0
-soundfile==0.12.1
-soxr==0.3.3
-threadpoolctl==3.1.0
-tokenizers==0.13.2
-torch==1.13.1
-torchaudio==0.13.1
-torchlibrosa==0.1.0
-torchvision==0.14.1
-tqdm==4.64.1
-transformers==4.26.1
-typing_extensions==4.4.0
-urllib3==1.26.14
-wheel==0.38.4
-wincertstore==0.2
-zipp==3.14.0