44 Commits

Author SHA1 Message Date
azalea bd44e422f5 [U] Bump version and deploy to pypi 2025-05-23 03:57:32 -04:00
Benjamin Elizalde e8a6467b87 Merge pull request #41 from hykilpikonna/main
Remove torchvision dependency and solve pytorch vulnerability
2024-09-18 14:40:17 -07:00
azalea 2614a64a7f [-] Remove requests as it is not part of the library 2024-09-05 19:45:03 -04:00
azalea 8dbe06ee63 [-] Remove unused torchvision and numba 2024-09-05 19:34:12 -04:00
Benjamin Elizalde 59bc8446e3 Merge pull request #30 from amihalik/patch-1
bugfix: _batch_inference needs to always `yield`
2024-04-04 10:58:47 -07:00
Aaron Mihalik d8b816ef43 bugfix: _batch_generator needs to always yield 2024-03-10 16:44:30 -04:00
Benjamin Elizalde 1af3d710ee Merge pull request #28 from microsoft/bmartin1_update_examples
Removed weights_path parameter
2023-10-20 11:33:18 -07:00
Benjamin Elizalde d32b134d27 Remove weights_path parameter 2023-10-20 11:25:32 -07:00
Benjamin Elizalde 8e87029c0e Update README.md 2023-10-17 15:39:45 -07:00
Benjamin Elizalde 936db19552 Merge pull request #26 from hykilpikonna/main
[+] Automatically download model weights
2023-10-17 15:31:50 -07:00
azalea 4a30b77b48 [O] Enable torch backward compatibility 2023-10-14 17:31:25 -04:00
azalea 8e6c8e632c Merge branch 'main' of https://github.com/hykilpikonna/CLAP 2023-10-14 17:28:10 -04:00
azalea f3d8c311d3 [U] Update readme instructions 2023-10-14 17:27:02 -04:00
azalea 286156725f [U] Bump version 2023-10-14 17:25:40 -04:00
azalea ac2f4ba44a [+] Automatically download models 2023-10-14 17:25:34 -04:00
Benjamin Elizalde 29bf721dbb Update README.md
Added HuggingFace link for weights.
2023-10-14 13:05:40 -07:00
Benjamin Elizalde f5fb6c6b6a Update to clap2 image 2023-10-12 13:20:24 -07:00
Benjamin Elizalde e2187633b8 Add files via upload 2023-10-12 13:19:39 -07:00
Benjamin Elizalde 8702caf400 Edited src to examples 2023-10-12 13:16:41 -07:00
Benjamin Elizalde 1ec382a713 Removed deprecated line. 2023-10-12 13:15:44 -07:00
Benjamin Elizalde 0e37184ad1 Merge pull request #25 from hykilpikonna/main
[O] Fix python 3.11 compatibility & optimize ease-of-use
2023-10-12 13:15:08 -07:00
azalea 3f6ef7382f [O] Loosen dependency requirements 2023-10-12 00:14:59 -04:00
azalea 549ef40479 [-] Remove unrelated file in pypi package 2023-10-11 20:54:54 -04:00
azalea 765f9f8864 [U] Publish pypi 2023-10-11 20:06:24 -04:00
azalea 9f902c9029 [O] Recover image 2023-10-11 19:59:32 -04:00
azalea bb3bfb4e5b [F] Fix examples 2023-10-11 19:55:08 -04:00
azalea b41935ff3c [M] Move examples to /examples 2023-10-11 19:26:05 -04:00
azalea 3788d4e225 [M] Symlink isn't handled, rename directory instead 2023-10-11 19:24:31 -04:00
azalea f177c8203e [O] Ignore build files and macOS files 2023-10-11 19:23:08 -04:00
azalea 23519ea1e6 [U] Update usage instructions 2023-10-11 18:56:30 -04:00
azalea eeaa2a3a34 [O] Use relative imports, remove unused imports 2023-10-11 18:56:18 -04:00
azalea ea5629de26 [O] Allow backward-compatibility for torch state dict 2023-10-11 18:54:51 -04:00
azalea d47289931d [-] importlib-resources is not needed 2023-10-11 18:54:27 -04:00
azalea 21c7b38cd4 [+] Importing magic 2023-10-11 18:41:25 -04:00
azalea 81cc62ef32 [M] Migrate to poetry for modern dependency management 2023-10-11 18:39:53 -04:00
azalea 583928827d [-] Remove deprecated torch._six 2023-10-11 18:26:10 -04:00
azalea 83a344004b [-] Remove unnecessary dependencies 2023-10-11 18:25:54 -04:00
Benjamin Elizalde 03703c0e91 Update README.md 2023-10-03 14:58:38 -07:00
Soham cf0e497e03 Resampler update 2023-10-01 10:29:41 -07:00
Benjamin Elizalde 8054343040 Update README.md 2023-09-26 15:04:48 -07:00
Benjamin Elizalde 94b9565724 Update README.md 2023-09-26 15:03:44 -07:00
Benjamin Elizalde 79de824822 Update README.md 2023-09-26 14:40:08 -07:00
Benjamin Elizalde 1ea0905552 Update to 2023 2023-09-26 14:37:20 -07:00
Soham 9d4006b773 Update README.md 2023-09-18 15:10:26 -07:00
28 changed files with 4062 additions and 401 deletions
+6
View File
@@ -348,3 +348,9 @@ MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/
dist/
.DS_Store
._*
venv/
root_path/
+61 -34
View File
@@ -1,71 +1,98 @@
###### [Overview](#CLAP) | [Setup](#Setup) | [CLAP weights](#CLAP-weights) | [Usage](#Usage) | [Examples](#Examples) | [Citation](#Citation)
# CLAP
CLAP (Contrastive Language-Audio Pretraining) is a neural network model that learns acoustic concepts from natural language supervision. It achieves SoTA in “Zero-Shot” classification, Audio-Text & Text-Audio Retrieval, and in some datasets when finetuned.
CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning.
<img width="832" alt="clap_diagram_v3" src="https://user-images.githubusercontent.com/26778834/199842089-39ef6a2e-8abb-4338-bdfe-680abab70f53.png">
<img width="832" alt="clap_diagrams" src="docs/clap2_diagram.png">
## Setup
You are required to just install the dependencies: `pip install -r requirements.txt` using Python 3 to get started.
If you have [conda](https://www.anaconda.com) installed, you can run the following:
First, install python 3.8 or higher (3.11 recommended). Then, install CLAP using either of the following:
```shell
git clone https://github.com/microsoft/CLAP.git && \
cd CLAP && \
conda create -n clap python=3.8 && \
conda activate clap && \
pip install -r requirements.txt
# Install pypi pacakge
pip install msclap
# Or Install latest (unstable) git source
pip install git+https://github.com/microsoft/CLAP.git
```
## CLAP weights
Request CLAP weights: [Pretrained Model \[Zenodo\]](https://zenodo.org/record/7312125#.Y22vecvMIQ9)
CLAP weights are downloaded automatically (choose between versions _2022_, _2023_, and _clapcap_), but are also available at: [Zenodo](https://zenodo.org/record/8378278) or [HuggingFace](https://huggingface.co/microsoft/msclap)
_clapcap_ is the audio captioning model that uses the 2023 encoders.
## Usage
Please take a look at `src/examples` for usage examples.
- Load model
- Zero-Shot Classification and Retrieval
```python
from src import CLAP
from msclap import CLAP
clap_model = CLAP("<PATH TO WEIGHTS>", use_cuda=False)
```
# Load model (Choose between versions '2022' or '2023')
# The model weight will be downloaded automatically if `model_fp` is not specified
clap_model = CLAP(version = '2023', use_cuda=False)
- Extract text embeddings
```python
# Extract text embeddings
text_embeddings = clap_model.get_text_embeddings(class_labels: List[str])
```
- Extract audio embeddings
```python
# Extract audio embeddings
audio_embeddings = clap_model.get_audio_embeddings(file_paths: List[str])
# Compute similarity between audio and text embeddings
similarities = clap_model.compute_similarity(audio_embeddings, text_embeddings)
```
- Compute similarity
- Audio Captioning
```python
sim = clap_model.compute_similarity(audio_embeddings, text_embeddings)
from msclap import CLAP
# Load model (Choose version 'clapcap')
clap_model = CLAP(version = 'clapcap', use_cuda=False)
# Generate audio captions
captions = clap_model.generate_caption(file_paths: List[str])
```
## Examples
To run zero-shot evaluation on the ESC50 dataset or a single audio file from ESC50, check `CLAP\src\`. For zero-shot evaluation on the ESC50 dataset:
Take a look at [examples](./examples/) for usage examples.
To run Zero-Shot Classification on the ESC50 dataset try the following:
```bash
> cd src && python zero_shot_classification.py
> cd examples && python zero_shot_classification.py
```
Output
Output (version 2023)
```bash
ESC50 Accuracy: 82.6%
ESC50 Accuracy: 93.9%
```
## Citation
https://arxiv.org/pdf/2206.04769.pdf
Kindly cite our work if you find it useful.
[CLAP: Learning Audio Concepts from Natural Language Supervision](https://ieeexplore.ieee.org/abstract/document/10095889)
```
@article{elizalde2022clap,
title={Clap: Learning audio concepts from natural language supervision},
author={Elizalde, Benjamin and Deshmukh, Soham and Ismail, Mahmoud Al and Wang, Huaming},
journal={arXiv preprint arXiv:2206.04769},
year={2022}
@inproceedings{CLAP2022,
title={Clap learning audio concepts from natural language supervision},
author={Elizalde, Benjamin and Deshmukh, Soham and Al Ismail, Mahmoud and Wang, Huaming},
booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={1--5},
year={2023},
organization={IEEE}
}
```
[Natural Language Supervision for General-Purpose Audio Representations](https://arxiv.org/abs/2309.05767)
```
@misc{CLAP2023,
title={Natural Language Supervision for General-Purpose Audio Representations},
author={Benjamin Elizalde and Soham Deshmukh and Huaming Wang},
year={2023},
eprint={2309.05767},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2309.05767}
}
```
Binary file not shown.

After

Width:  |  Height:  |  Size: 277 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

+24
View File
@@ -0,0 +1,24 @@
"""
This is an example using CLAPCAP for audio captioning.
"""
from msclap import CLAP
# Load and initialize CLAP
clap_model = CLAP(version = 'clapcap', use_cuda=False)
#Load audio files
audio_files = ['audio_file']
# Generate captions for the recording
captions = clap_model.generate_caption(audio_files, resample=True, beam_size=5, entry_length=67, temperature=0.01)
# Print the result
for i in range(len(audio_files)):
print(f"Audio file: {audio_files[i]} \n")
print(f"Generated caption: {captions[i]} \n")
"""
The output (the exact caption may vary):
The birds are singing in the trees.
"""
@@ -1,6 +1,6 @@
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from tqdm import tqdm
from pathlib import Path
import pandas as pd
import os
import torch.nn as nn
@@ -24,7 +24,7 @@ class AudioDataset(Dataset):
class ESC50(AudioDataset):
base_folder = 'ESC-50-master'
url = "https://github.com/karolpiczak/ESC-50/archive/refs/heads/master.zip"
url = "https://github.com/karoldvl/ESC-50/archive/master.zip"
filename = "ESC-50-master.zip"
num_files_in_dir = 2000
audio_dir = 'audio'
@@ -74,9 +74,29 @@ class ESC50(AudioDataset):
return len(self.audio_paths)
def download(self):
download_url(self.url, self.root, self.filename)
# Download file using requests
import requests
file = Path(self.root) / self.filename
if file.is_file():
return
r = requests.get(self.url, stream=True)
# extract file
# To prevent partial downloads, download to a temp file first
tmp = file.with_suffix('.tmp')
tmp.parent.mkdir(parents=True, exist_ok=True)
with open(tmp, 'wb') as f:
pbar = tqdm(unit=" MB", bar_format=f'{file.name}: {{rate_noinv_fmt}}')
for chunk in r.iter_content(chunk_size=1024):
if chunk:
pbar.update(len(chunk) / 1024 / 1024)
f.write(chunk)
# move temp file to correct location
tmp.rename(file)
# # extract file
from zipfile import ZipFile
with ZipFile(os.path.join(self.root, self.filename), 'r') as zip:
zip.extractall(path=self.root)
zip.extractall(path=self.root)
@@ -1,9 +1,9 @@
"""
This is an example using CLAP to perform zeroshot
This is an example using CLAP to perform zeroshot
classification on ESC50 (https://github.com/karolpiczak/ESC-50).
"""
from CLAPWrapper import CLAPWrapper
from msclap import CLAP
from esc50_dataset import ESC50
import torch.nn.functional as F
import numpy as np
@@ -11,15 +11,13 @@ from tqdm import tqdm
from sklearn.metrics import accuracy_score
# Load dataset
dataset = ESC50(root="data_path", download=False)
prompt = 'this is a sound of '
root_path = "root_path" # Folder with ESC-50-master/
dataset = ESC50(root=root_path, download=True) #If download=False code assumes base_folder='ESC-50-master' in esc50_dataset.py
prompt = 'this is the sound of '
y = [prompt + x for x in dataset.classes]
# Load and initialize CLAP
weights_path = "weights_path"
clap_model = CLAPWrapper(weights_path, use_cuda=False)
clap_model = CLAP(version = '2023', use_cuda=False)
# Computing text embeddings
text_embeddings = clap_model.get_text_embeddings(y)
@@ -34,6 +32,7 @@ for i in tqdm(range(len(dataset))):
y_preds.append(y_pred)
y_labels.append(one_hot_target.detach().cpu().numpy())
y_labels, y_preds = np.concatenate(y_labels, axis=0), np.concatenate(y_preds, axis=0)
acc = accuracy_score(np.argmax(y_labels, axis=1), np.argmax(y_preds, axis=1))
print('ESC50 Accuracy {}'.format(acc))
@@ -41,6 +40,6 @@ print('ESC50 Accuracy {}'.format(acc))
"""
The output:
ESC50 Accuracy: 82.6%
ESC50 Accuracy: 93.9%
"""
+50
View File
@@ -0,0 +1,50 @@
"""
This is an example using CLAP for zero-shot inference.
"""
from msclap import CLAP
import torch.nn.functional as F
# Define classes for zero-shot
# Should be in lower case and can be more than one word
classes = ['coughing','sneezing','drinking sipping', 'breathing', 'brushing teeth']
ground_truth = ['coughing']
# Add prompt
prompt = 'this is a sound of '
class_prompts = [prompt + x for x in classes]
#Load audio files
audio_files = ['audio_file']
# Load and initialize CLAP
# Setting use_cuda = True will load the model on a GPU using CUDA
clap_model = CLAP(version = '2023', use_cuda=False)
# compute text embeddings from natural text
text_embeddings = clap_model.get_text_embeddings(class_prompts)
# compute the audio embeddings from an audio file
audio_embeddings = clap_model.get_audio_embeddings(audio_files, resample=True)
# compute the similarity between audio_embeddings and text_embeddings
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
similarity = F.softmax(similarity, dim=1)
values, indices = similarity[0].topk(5)
# Print the results
print("Ground Truth: {}".format(ground_truth))
print("Top predictions:\n")
for value, index in zip(values, indices):
print(f"{classes[index]:>16s}: {100 * value.item():.2f}%")
"""
The output (the exact numbers may vary):
Ground Truth: coughing
Top predictions:
coughing: 98.55%
sneezing: 1.24%
drinking sipping: 0.15%
breathing: 0.02%
brushing teeth: 0.01%
"""
+419
View File
@@ -0,0 +1,419 @@
from __future__ import annotations
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
import random
import torchaudio
import collections
import re
import numpy as np
from transformers import AutoTokenizer, logging
from .models.clap import CLAP
from .models.mapper import get_clapcap
import math
import torchaudio.transforms as T
import os
import torch
import argparse
import yaml
import sys
from huggingface_hub.file_download import hf_hub_download
logging.set_verbosity_error()
class CLAPWrapper():
"""
A class for interfacing CLAP model.
"""
model_repo = "microsoft/msclap"
model_name = {
'2022': 'CLAP_weights_2022.pth',
'2023': 'CLAP_weights_2023.pth',
'clapcap': 'clapcap_weights_2023.pth'
}
def __init__(self, model_fp: Path | str | None = None, version: str = '2023', use_cuda=False):
# Check if version is supported
self.supported_versions = self.model_name.keys()
if version not in self.supported_versions:
raise ValueError(f"The version {version} is not supported. The supported versions are {str(self.supported_versions)}")
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
self.file_path = os.path.realpath(__file__)
self.default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
self.config_as_str = (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
# Automatically download model if not provided
if not model_fp:
model_fp = hf_hub_download(self.model_repo, self.model_name[version])
self.model_fp = model_fp
self.use_cuda = use_cuda
if 'clapcap' in version:
self.clapcap, self.tokenizer, self.args = self.load_clapcap()
else:
self.clap, self.tokenizer, self.args = self.load_clap()
def read_config_as_args(self,config_path,args=None,is_config_str=False):
return_dict = {}
if config_path is not None:
if is_config_str:
yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
else:
with open(config_path, "r") as f:
yml_config = yaml.load(f, Loader=yaml.FullLoader)
if args != None:
for k, v in yml_config.items():
if k in args.__dict__:
args.__dict__[k] = v
else:
sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
else:
for k, v in yml_config.items():
return_dict[k] = v
args = args if args != None else return_dict
return argparse.Namespace(**args)
def load_clap(self):
r"""Load CLAP model with args from config file"""
args = self.read_config_as_args(self.config_as_str, is_config_str=True)
if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
self.token_keys = ['input_ids', 'attention_mask']
elif 'bert' in args.text_model:
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
clap = CLAP(
audioenc_name=args.audioenc_name,
sample_rate=args.sampling_rate,
window_size=args.window_size,
hop_size=args.hop_size,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
classes_num=args.num_classes,
out_emb=args.out_emb,
text_model=args.text_model,
transformer_embed_dim=args.transformer_embed_dim,
d_proj=args.d_proj
)
# Load pretrained weights for model
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
# We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`:
# Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
clap.load_state_dict(model_state_dict, strict=False)
clap.eval() # set clap in eval mode
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
if 'gpt' in args.text_model:
tokenizer.add_special_tokens({'pad_token': '!'})
if self.use_cuda and torch.cuda.is_available():
clap = clap.cuda()
return clap, tokenizer, args
def load_clapcap(self):
r"""Load CLAP model with args from config file"""
args = self.read_config_as_args(self.config_as_str, is_config_str=True)
args.prefix_dim = args.d_proj
text_model = args.text_model
args.text_model = args.text_decoder
args.cross_attention = True if 'cross' in args.clapcap_model.lower() else False
if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
self.token_keys = ['input_ids', 'attention_mask']
elif 'bert' in args.text_model:
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
clap = CLAP(
audioenc_name=args.audioenc_name,
sample_rate=args.sampling_rate,
window_size=args.window_size,
hop_size=args.hop_size,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
classes_num=args.num_classes,
out_emb=args.out_emb,
text_model=text_model,
transformer_embed_dim=args.transformer_embed_dim,
d_proj=args.d_proj
)
clapcap = get_clapcap(args.clapcap_model)(clap, args.text_decoder, args.prefix_length, args.prefix_length_clip, args.prefix_dim,
args.num_layers, args.normalize_prefix, args.mapping_type, True, True)
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
clapcap.load_state_dict(model_state_dict, strict=False)
clapcap.eval() # set clap in eval mode
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
if 'gpt' in args.text_model:
tokenizer.add_special_tokens({'pad_token': '!'})
if self.use_cuda and torch.cuda.is_available():
clapcap = clapcap.cuda()
return clapcap, tokenizer, args
def default_collate(self, batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
# array of string classes and object
if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(
self.default_collate_err_msg_format.format(elem.dtype))
return self.default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, str):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: self.default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError(
'each element in list of batch should be of equal size')
transposed = zip(*batch)
return [self.default_collate(samples) for samples in transposed]
raise TypeError(self.default_collate_err_msg_format.format(elem_type))
def read_audio(self, audio_path, resample=True):
r"""Loads audio file or array and returns a torch tensor"""
# Randomly sample a segment of audio_duration from the clip or pad to match duration
audio_time_series, sample_rate = torchaudio.load(audio_path)
resample_rate = self.args.sampling_rate
if resample and resample_rate != sample_rate:
resampler = T.Resample(sample_rate, resample_rate)
audio_time_series = resampler(audio_time_series)
return audio_time_series, resample_rate
def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
r"""Loads audio file and returns raw audio."""
# Randomly sample a segment of audio_duration from the clip or pad to match duration
audio_time_series, sample_rate = self.read_audio(audio_path, resample=resample)
audio_time_series = audio_time_series.reshape(-1)
# audio_time_series is shorter than predefined audio duration,
# so audio_time_series is extended
if audio_duration*sample_rate >= audio_time_series.shape[0]:
repeat_factor = int(np.ceil((audio_duration*sample_rate) /
audio_time_series.shape[0]))
# Repeat audio_time_series by repeat_factor to match audio_duration
audio_time_series = audio_time_series.repeat(repeat_factor)
# remove excess part of audio_time_series
audio_time_series = audio_time_series[0:audio_duration*sample_rate]
else:
# audio_time_series is longer than predefined audio duration,
# so audio_time_series is trimmed
start_index = random.randrange(
audio_time_series.shape[0] - audio_duration*sample_rate)
audio_time_series = audio_time_series[start_index:start_index +
audio_duration*sample_rate]
return torch.FloatTensor(audio_time_series)
def preprocess_audio(self, audio_files, resample):
r"""Load list of audio files and return raw audio"""
audio_tensors = []
for audio_file in audio_files:
audio_tensor = self.load_audio_into_tensor(
audio_file, self.args.duration, resample)
audio_tensor = audio_tensor.reshape(
1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
audio_tensors.append(audio_tensor)
return self.default_collate(audio_tensors)
def preprocess_text(self, text_queries):
r"""Load list of class labels and return tokenized text"""
tokenized_texts = []
for ttext in text_queries:
if 'gpt' in self.args.text_model:
ttext = ttext + ' <|endoftext|>'
tok = self.tokenizer.encode_plus(
text=ttext, add_special_tokens=True, max_length=self.args.text_len, padding='max_length', return_tensors="pt")
for key in self.token_keys:
tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
tokenized_texts.append(tok)
return self.default_collate(tokenized_texts)
def get_text_embeddings(self, class_labels):
r"""Load list of class labels and return text embeddings"""
preprocessed_text = self.preprocess_text(class_labels)
return self._get_text_embeddings(preprocessed_text)
def get_audio_embeddings(self, audio_files, resample=True):
r"""Load list of audio files and return a audio embeddings"""
preprocessed_audio = self.preprocess_audio(audio_files, resample)
return self._get_audio_embeddings(preprocessed_audio)
def _get_text_embeddings(self, preprocessed_text):
r"""Load preprocessed text and return text embeddings"""
with torch.no_grad():
return self.clap.caption_encoder(preprocessed_text)
def _get_audio_embeddings(self, preprocessed_audio):
r"""Load preprocessed audio and return a audio embeddings"""
with torch.no_grad():
preprocessed_audio = preprocessed_audio.reshape(
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
#Append [0] the audio emebdding, [1] has output class probabilities
return self.clap.audio_encoder(preprocessed_audio)[0]
def _generic_batch_inference(self, func, *args):
r"""Process audio and/or text per batch"""
input_tmp = args[0]
batch_size = args[-1]
# args[0] has audio_files, args[1] has class_labels
inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
args0_len = len(args[0])
# compute text_embeddings once for all the audio_files batches
if len(inputs) == 2:
text_embeddings = self.get_text_embeddings(args[1])
inputs = [args[0], args[1], text_embeddings]
dataset_idx = 0
for _ in range(math.ceil(args0_len/batch_size)):
next_batch_idx = dataset_idx + batch_size
# batch size is bigger than available audio/text items
if next_batch_idx >= args0_len:
inputs[0] = input_tmp[dataset_idx:]
yield func(*tuple(inputs))
else:
inputs[0] = input_tmp[dataset_idx:next_batch_idx]
yield func(*tuple(inputs))
dataset_idx = next_batch_idx
def get_audio_embeddings_per_batch(self, audio_files, batch_size):
r"""Load preprocessed audio and return a audio embeddings per batch"""
return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)
def get_text_embeddings_per_batch(self, class_labels, batch_size):
r"""Load preprocessed text and return text embeddings per batch"""
return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
def compute_similarity(self, audio_embeddings, text_embeddings):
r"""Compute similarity between text and audio embeddings"""
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
logit_scale = self.clap.logit_scale.exp()
similarity = logit_scale*text_embeddings @ audio_embeddings.T
return similarity.T
def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
r"""Compute classification probabilities for each audio recording in a batch and each class label"""
return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
r"""Generate audio captions for each audio recording in a batch"""
captions = []
audio_tensors = self.preprocess_audio(audio_files, resample)
with torch.no_grad():
prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
if self.args.normalize_prefix:
prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
prefix_embed = self.clapcap.clap_project(prefix).view(-1, self.args.prefix_length, self.clapcap.gpt.transformer.wte.weight.shape[1])
for i in range(len(audio_tensors)):
gen_caption = self._generate_beam(embed=prefix_embed[i].unsqueeze(0),\
beam_size=beam_size,\
entry_length=entry_length,\
temperature=temperature)[0]
captions.append(gen_caption.capitalize())
return captions
def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
r"""Generate captions by beam search decoding"""
self.clapcap.eval()
stop_token_index = self.tokenizer.encode(stop_token)[0]
tokens = None
scores = None
device = next(self.clapcap.parameters()).device
seq_lengths = torch.ones(beam_size, device=device)
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
with torch.no_grad():
if embed is not None:
generated = embed
else:
if tokens is None:
tokens = torch.tensor(self.tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = self.clapcap.gpt.transformer.wte(tokens)
for i in range(entry_length):
outputs = self.clapcap.gpt(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
logits = logits.softmax(-1).log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
generated = generated.expand(beam_size, *generated.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if tokens is None:
tokens = next_tokens
else:
tokens = tokens.expand(beam_size, *tokens.shape[1:])
tokens = torch.cat((tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
tokens = tokens[next_tokens_source]
tokens = torch.cat((tokens, next_tokens), dim=1)
generated = generated[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
generated = torch.cat((generated, next_token_embed), dim=1)
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return output_texts
+1
View File
@@ -0,0 +1 @@
from .CLAPWrapper import CLAPWrapper as CLAP
+26
View File
@@ -0,0 +1,26 @@
# TEXT ENCODER CONFIG
text_model: 'gpt2'
text_len: 77
transformer_embed_dim: 768
freeze_text_encoder_weights: True
# AUDIO ENCODER CONFIG
audioenc_name: 'HTSAT'
out_emb: 768
sampling_rate: 44100
duration: 7
fmin: 50
fmax: 8000 #14000
n_fft: 1024 # 1028
hop_size: 320
mel_bins: 64
window_size: 1024
# PROJECTION SPACE CONFIG
d_proj: 1024
temperature: 0.003
# TRAINING AND EVALUATION CONFIG
num_classes: 527
batch_size: 1024
demo: False
+34
View File
@@ -0,0 +1,34 @@
# TEXT ENCODER CONFIG
text_model: 'gpt2'
transformer_embed_dim: 768
freeze_text_encoder_weights: True
# AUDIO ENCODER CONFIG
audioenc_name: 'HTSAT'
out_emb: 768
sampling_rate: 44100
duration: 7
fmin: 50
fmax: 8000
n_fft: 1024
hop_size: 320
mel_bins: 64
window_size: 1024
# PROJECTION SPACE CONFIG
d_proj: 1024
temperature: 0.003
# TRAINING AND EVALUATION CONFIG
batch_size: 128
num_classes: 527
# CLAPCAP CONFIG
clapcap_model: 'ClapCaption'
text_decoder: 'gpt2'
prefix_length: 40
prefix_length_clip: 40
mapping_type: 'transformer'
num_layers: 8
normalize_prefix: True
freeze_gpt_weights: True
+6
View File
@@ -0,0 +1,6 @@
from . import clap
from . import audio
from . import htsat
from . import config
from . import pytorch_utils
from . import htsat
@@ -2,10 +2,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from .htsat import HTSATWrapper
def get_audio_encoder(name: str):
if name == "Cnn14":
return Cnn14
elif name == "HTSAT":
return HTSATWrapper
else:
raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
+21 -2
View File
@@ -42,14 +42,33 @@ class AudioEncoder(nn.Module):
class TextEncoder(nn.Module):
def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
super().__init__()
self.text_model = text_model
self.base = AutoModel.from_pretrained(text_model)
if 'clip' in text_model:
self.clip_text_projection = self.base.text_projection
self.base = self.base.text_model
if 'base' in text_model:
transformer_embed_dim = 512
self.projection = Projection(transformer_embed_dim, d_out)
def forward(self, x):
out = self.base(**x)[0]
out = out[:, 0, :] # get CLS token output
if 'clip' in self.text_model:
pooled_output = self.base(**x)[1] # get pooled output
out = self.clip_text_projection(pooled_output) # get CLS token output
elif 'gpt' in self.text_model:
batch_size = x['input_ids'].shape[0]
hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
else:
out = self.base(**x)[0]
out = out[:, 0, :] # get CLS token output
projected_vec = self.projection(out)
return projected_vec
class CLAP(nn.Module):
+128
View File
@@ -0,0 +1,128 @@
# Ke Chen
# knutchen@ucsd.edu
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# The configuration for training the model
exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
workspace = "/home/kechen/Research/HTSAT" # the folder of your code
dataset_path = "/home/Research/audioset" # the dataset path
desed_folder = "/home/Research/DESED" # the desed file
dataset_type = "audioset" # "audioset" "esc-50" "scv2"
index_type = "full_train" # only works for audioset
balanced_data = True # only works for audioset
loss_type = "clip_bce" #
# AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
# trained from a checkpoint, or evaluate a single model
resume_checkpoint = None
# "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
debug = False
random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
learning_rate = 1e-3 # 1e-4 also workable
max_epoch = 100
num_workers = 3
lr_scheduler_epoch = [10,20,30]
lr_rate = [0.02, 0.05, 0.1]
# these data preparation optimizations do not bring many improvements, so deprecated
enable_token_label = False # token label
class_map_path = "class_hier_map.npy"
class_filter = None
retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
token_label_range = [0.2,0.6]
enable_time_shift = False # shift time
enable_label_enhance = False # enhance hierarchical label
enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
# for model's design
enable_tscam = True # enbale the token-semantic layer
# for signal processing
sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
clip_samples = sample_rate * 10 # audio_set 10-sec clip
window_size = 1024
hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
mel_bins = 64
fmin = 50
fmax = 14000
shift_max = int(clip_samples * 0.5)
# for data collection
classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
patch_size = (25, 4) # deprecated
crop_size = None # int(clip_samples * 0.5) deprecated
# for htsat hyperparamater
htsat_window_size = 8
htsat_spec_size = 256
htsat_patch_size = 4
htsat_stride = (4, 4)
htsat_num_head = [4,8,16,32]
htsat_dim = 96
htsat_depth = [2,2,6,2]
swin_pretrain_path = None
# "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
# Some Deprecated Optimization in the model design, check the model code for details
htsat_attn_heatmap = False
htsat_hier_output = False
htsat_use_max = False
# for ensemble test
ensemble_checkpoints = []
ensemble_strides = []
# weight average folder
wa_folder = "/home/version_0/checkpoints/"
# weight average output filename
wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
esm_model_pathes = [
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
]
# for framewise localization
heatmap_dir = "/home/Research/heatmap_output"
test_file = "htsat-test-ensemble"
fl_local = False # indicate if we need to use this dataset for the framewise detection
fl_dataset = "/home/Research/desed/desed_eval.npy"
fl_class_num = [
"Speech", "Frying", "Dishes", "Running_water",
"Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
"Cat", "Dog", "Vacuum_cleaner"
]
# map 527 classes into 10 classes
fl_audioset_mapping = [
[0,1,2,3,4,5,6,7],
[366, 367, 368],
[364],
[288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
[369],
[382],
[310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
[81, 82, 83, 84, 85],
[74, 75, 76, 77, 78, 79],
[377]
]
+942
View File
@@ -0,0 +1,942 @@
# Ke Chen
# knutchen@ucsd.edu
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# Model Core
# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
import math
import random
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from itertools import repeat
from .pytorch_utils import do_mixup, interpolate
from . import config
import collections.abc
import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patch_stride = to_2tuple(patch_stride)
self.img_size = img_size
self.patch_size = patch_size
self.patch_stride = patch_stride
self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in':
denom = fan_in
elif mode == 'fan_out':
denom = fan_out
elif mode == 'fan_avg':
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
def extra_repr(self):
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm_before_mlp = norm_before_mlp
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if self.norm_before_mlp == 'ln':
self.norm2 = nn.LayerNorm(dim)
elif self.norm_before_mlp == 'bn':
self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
else:
raise NotImplementedError
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
# pdb.set_trace()
H, W = self.input_resolution
# print("H: ", H)
# print("W: ", W)
# pdb.set_trace()
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn
def extra_repr(self):
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self):
return f"input_resolution={self.input_resolution}, dim={self.dim}"
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
norm_before_mlp='ln'):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
attns = []
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x, attn = blk(x)
if not self.training:
attns.append(attn.unsqueeze(0))
if self.downsample is not None:
x = self.downsample(x)
if not self.training:
attn = torch.cat(attns, dim = 0)
attn = torch.mean(attn, dim = 0)
return x, attn
def extra_repr(self):
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
# The Core of HTSAT
class HTSAT_Swin_Transformer(nn.Module):
r"""HTSAT based on the Swin Transformer
Args:
spec_size (int | tuple(int)): Input Spectrogram size. Default 256
patch_size (int | tuple(int)): Patch size. Default: 4
path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
in_chans (int): Number of input image channels. Default: 1 (mono)
num_classes (int): Number of classes for classification head. Default: 527
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 8
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
config (module): The configuration Module from config.py
"""
def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
in_chans=1, num_classes=527,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False, patch_norm=True,
use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
super(HTSAT_Swin_Transformer, self).__init__()
self.config = config
self.spec_size = spec_size
self.patch_stride = patch_stride
self.patch_size = patch_size
self.window_size = window_size
self.embed_dim = embed_dim
self.depths = depths
self.ape = ape
self.in_chans = in_chans
self.num_classes = num_classes
self.num_heads = num_heads
self.num_layers = len(self.depths)
self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.qkv_bias = qkv_bias
self.qk_scale = None
self.patch_norm = patch_norm
self.norm_layer = norm_layer if self.patch_norm else None
self.norm_before_mlp = norm_before_mlp
self.mlp_ratio = mlp_ratio
self.use_checkpoint = use_checkpoint
# process mel-spec ; used only once
self.freq_ratio = self.spec_size // self.config.mel_bins
window = 'hann'
center = True
pad_mode = 'reflect'
ref = 1.0
amin = 1e-10
top_db = None
self.interpolate_ratio = 32 # Downsampled ratio
# Spectrogram extractor
self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
freeze_parameters=True)
# Logmel feature extractor
self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
freeze_parameters=True)
# Spec augmenter
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
freq_drop_width=8, freq_stripes_num=2) # 2 2
self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
# split spctrogram into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.grid_size
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=self.drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=self.depths[i_layer],
num_heads=self.num_heads[i_layer],
window_size=self.window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate,
drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
norm_layer=self.norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
norm_before_mlp=self.norm_before_mlp)
self.layers.append(layer)
self.norm = self.norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.maxpool = nn.AdaptiveMaxPool1d(1)
if self.config.enable_tscam:
SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
self.tscam_conv = nn.Conv2d(
in_channels = self.num_features,
out_channels = self.num_classes,
kernel_size = (SF,3),
padding = (0,1)
)
self.head = nn.Linear(num_classes, num_classes)
else:
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
frames_num = x.shape[2]
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for i, layer in enumerate(self.layers):
x, attn = layer(x)
if self.config.enable_tscam:
# for x
x = self.norm(x)
B, N, C = x.shape
SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
B, C, F, T = x.shape
# group 2D CNN
c_freq_bin = F // self.freq_ratio
x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
# get latent_output
latent_output = self.avgpool(torch.flatten(x,2))
latent_output = torch.flatten(latent_output, 1)
# display the attention map, if needed
if self.config.htsat_attn_heatmap:
# for attn
attn = torch.mean(attn, dim = 1)
attn = torch.mean(attn, dim = 1)
attn = attn.reshape(B, SF, ST)
c_freq_bin = SF // self.freq_ratio
attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
attn = attn.mean(dim = 1)
attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
attn = attn.unsqueeze(dim = 2)
x = self.tscam_conv(x)
x = torch.flatten(x, 2) # B, C, T
if self.config.htsat_attn_heatmap:
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
else:
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
x = self.avgpool(x)
x = torch.flatten(x, 1)
if self.config.loss_type == "clip_ce":
output_dict = {
'framewise_output': fpx, # already sigmoided
'clipwise_output': x,
'latent_output': latent_output
}
else:
output_dict = {
'framewise_output': fpx, # already sigmoided
'clipwise_output': torch.sigmoid(x),
'latent_output': latent_output
}
else:
x = self.norm(x) # B N C
B, N, C = x.shape
fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
B, C, F, T = fpx.shape
c_freq_bin = F // self.freq_ratio
fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
fpx = torch.sum(fpx, dim = 2)
fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
if self.num_classes > 0:
x = self.head(x)
fpx = self.head(fpx)
output_dict = {'framewise_output': torch.sigmoid(fpx),
'clipwise_output': torch.sigmoid(x)}
return output_dict
def crop_wav(self, x, crop_size, spe_pos = None):
time_steps = x.shape[2]
tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
for i in range(len(x)):
if spe_pos is None:
crop_pos = random.randint(0, time_steps - crop_size - 1)
else:
crop_pos = spe_pos
tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
return tx
# Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
def reshape_wav2img(self, x):
B, C, T, F = x.shape
target_T = int(self.spec_size * self.freq_ratio)
target_F = self.spec_size // self.freq_ratio
assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
# to avoid bicubic zero error
if T < target_T:
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
if F < target_F:
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
x = x.permute(0,1,3,2).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
# print(x.shape)
x = x.permute(0,1,3,2,4).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
return x
# Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
def repeat_wat2img(self, x, cur_pos):
B, C, T, F = x.shape
target_T = int(self.spec_size * self.freq_ratio)
target_F = self.spec_size // self.freq_ratio
assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size"
# to avoid bicubic zero error
if T < target_T:
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
if F < target_F:
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
x = x.permute(0,1,3,2).contiguous() # B C F T
x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
x = x.repeat(repeats = (1,1,4,1))
return x
def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
x = x.transpose(1, 3)
x = self.bn0(x)
x = x.transpose(1, 3)
if self.training:
x = self.spec_augmenter(x)
if self.training and mixup_lambda is not None:
x = do_mixup(x, mixup_lambda)
if infer_mode:
# in infer mode. we need to handle different length audio input
frame_num = x.shape[2]
target_T = int(self.spec_size * self.freq_ratio)
repeat_ratio = math.floor(target_T / frame_num)
x = x.repeat(repeats=(1,1,repeat_ratio,1))
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
elif self.config.enable_repeat_mode:
if self.training:
cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
x = self.repeat_wat2img(x, cur_pos)
output_dict = self.forward_features(x)
else:
output_dicts = []
for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
tx = x.clone()
tx = self.repeat_wat2img(tx, cur_pos)
output_dicts.append(self.forward_features(tx))
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
for d in output_dicts:
clipwise_output += d["clipwise_output"]
framewise_output += d["framewise_output"]
clipwise_output = clipwise_output / len(output_dicts)
framewise_output = framewise_output / len(output_dicts)
output_dict = {
'framewise_output': framewise_output,
'clipwise_output': clipwise_output
}
else:
if x.shape[2] > self.freq_ratio * self.spec_size:
if self.training:
x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
else:
# Change: Hard code here
overlap_size = 344 #(x.shape[2] - 1) // 4
output_dicts = []
crop_size = 689 #(x.shape[2] - 1) // 2
for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
tx = self.reshape_wav2img(tx)
output_dicts.append(self.forward_features(tx))
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
for d in output_dicts:
clipwise_output += d["clipwise_output"]
framewise_output += d["framewise_output"]
latent_output += d["latent_output"]
clipwise_output = clipwise_output / len(output_dicts)
framewise_output = framewise_output / len(output_dicts)
latent_output = latent_output / len(output_dicts)
output_dict = {
'framewise_output': framewise_output,
'clipwise_output': clipwise_output,
'latent_output': latent_output,
}
else: # this part is typically used, and most easy one
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
# x = self.head(x)
return output_dict
class HTSATWrapper(nn.Module):
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
fmax, classes_num, out_emb):
super().__init__()
# print("parameters are being overidden when using HTSAT")
# print("HTSAT only support loading a pretrained model on AudioSet")
# @TODO later look at what parameters are same and can be merged
self.htsat = HTSAT_Swin_Transformer(config=config)
def forward(self, x):
out_dict = self.htsat(x)
out_dict['embedding'] = out_dict['latent_output']
return out_dict
+199
View File
@@ -0,0 +1,199 @@
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from enum import Enum
from transformers import GPT2LMHeadModel
from typing import Tuple, Optional
def get_clapcap(name: str):
if name == "ClapCaption":
return ClapCaptionModel
else:
raise Exception('The ClapCap model {} is incorrect or not supported'.format(name))
class MappingType(Enum):
MLP = 'mlp'
Transformer = 'transformer'
class MLP(nn.Module):
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
super(MLP, self).__init__()
layers = []
for i in range(len(sizes) - 1):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
if i < len(sizes) - 2:
layers.append(act())
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
class MlpTransformer(nn.Module):
def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
super().__init__()
out_d = out_d if out_d is not None else in_dim
self.fc1 = nn.Linear(in_dim, h_dim)
self.act = act
self.fc2 = nn.Linear(h_dim, out_d)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class MultiHeadAttention(nn.Module):
def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim_self // num_heads
self.scale = head_dim ** -0.5
self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
self.project = nn.Linear(dim_self, dim_self)
self.dropout = nn.Dropout(dropout)
def forward(self, x, y=None, mask=None):
y = y if y is not None else x
b, n, c = x.shape
_, m, d = y.shape
# b n h dh
queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
# b m 2 h dh
keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(1)
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
attention = attention.softmax(dim=2)
out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
out = self.project(out)
return out, attention
class TransformerLayer(nn.Module):
def forward_with_attention(self, x, y=None, mask=None):
x_, attention = self.attn(self.norm1(x), y, mask)
x = x + x_
x = x + self.mlp(self.norm2(x))
return x, attention
def forward(self, x, y=None, mask=None):
x = x + self.attn(self.norm1(x), y, mask)[0]
x = x + self.mlp(self.norm2(x))
return x
def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
norm_layer: nn.Module = nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim_self)
self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
self.norm2 = norm_layer(dim_self)
self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
class Transformer(nn.Module):
def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
super(Transformer, self).__init__()
dim_ref = dim_ref if dim_ref is not None else dim_self
self.enc_dec = enc_dec
if enc_dec:
num_layers = num_layers * 2
layers = []
for i in range(num_layers):
if i % 2 == 0 and enc_dec: # cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
elif enc_dec: # self
layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
else: # self or cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
self.layers = nn.ModuleList(layers)
def forward_with_attention(self, x, y=None, mask=None):
attentions = []
for layer in self.layers:
x, att = layer.forward_with_attention(x, y, mask)
attentions.append(att)
return x, attentions
def forward(self, x, y=None, mask=None):
for i, layer in enumerate(self.layers):
if i % 2 == 0 and self.enc_dec: # cross
x = layer(x, y)
elif self.enc_dec: # self
x = layer(x, x, mask)
else: # self or cross
x = layer(x, y, mask)
return x
class TransformerMapper(nn.Module):
def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
super(TransformerMapper, self).__init__()
self.clip_length = clip_length
self.transformer = Transformer(dim_embedding, 8, num_layers)
self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
def forward(self, x):
x = self.linear(x).view(x.shape[0], self.clip_length, -1)
prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
prefix = torch.cat((x, prefix), dim=1)
out = self.transformer(prefix)[:, self.clip_length:]
return out
class ClapCaptionModel(nn.Module):
def __init__(self, clap, text_decoder: str, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
num_layers: int = 8, normalize_prefix: bool = True, mapping_type: str = None,\
freeze_audio_encoder_weights: bool = True, freeze_gpt_weights: bool = True):
super(ClapCaptionModel, self).__init__()
self.clap = clap.audio_encoder
self.prefix_length = prefix_length
self.normalize_prefix = normalize_prefix
self.gpt = GPT2LMHeadModel.from_pretrained(text_decoder)
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
if mapping_type == 'mlp':
self.clap_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
self.gpt_embedding_size * prefix_length))
else:
self.clap_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
clip_length, num_layers)
# Freeze all CLAP parameters
if freeze_audio_encoder_weights:
for p in self.clap.parameters():
p.requires_grad = False
if freeze_gpt_weights:
for p in self.gpt.parameters():
p.requires_grad = False
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
def forward(self, audios: torch.Tensor, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None):
# get audio embeddings
prefix, _ = self.clap(audios)
# normalize prefix (audio embedding)
if self.normalize_prefix:
prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
embedding_text = self.gpt.transformer.wte(tokens['input_ids'])
prefix_projections = self.clap_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
if labels is not None:
dummy_token = self.get_dummy_token(tokens['input_ids'].shape[0], tokens['input_ids'].device)
labels = torch.cat((dummy_token, tokens), dim=1)
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
return out
+182
View File
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
def move_data_to_device(x, device):
if 'float' in str(x.dtype):
x = torch.Tensor(x)
elif 'int' in str(x.dtype):
x = torch.LongTensor(x)
else:
return x
return x.to(device)
def do_mixup(x, mixup_lambda):
"""Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
(1, 3, 5, ...).
Args:
x: (batch_size * 2, ...)
mixup_lambda: (batch_size * 2,)
Returns:
out: (batch_size, ...)
"""
out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
return out
def append_to_dict(dict, key, value):
if key in dict.keys():
dict[key].append(value)
else:
dict[key] = [value]
def interpolate(x, ratio):
"""Interpolate data in time domain. This is used to compensate the
resolution reduction in downsampling of a CNN.
Args:
x: (batch_size, time_steps, classes_num)
ratio: int, ratio to interpolate
Returns:
upsampled: (batch_size, time_steps * ratio, classes_num)
"""
(batch_size, time_steps, classes_num) = x.shape
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
return upsampled
def pad_framewise_output(framewise_output, frames_num):
"""Pad framewise_output to the same length as input frames. The pad value
is the same as the value of the last frame.
Args:
framewise_output: (batch_size, frames_num, classes_num)
frames_num: int, number of frames to pad
Outputs:
output: (batch_size, frames_num, classes_num)
"""
pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
"""tensor for padding"""
output = torch.cat((framewise_output, pad), dim=1)
"""(batch_size, frames_num, classes_num)"""
return output
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def count_flops(model, audio_length):
"""Count flops. Code modified from others' implementation.
"""
multiply_adds = True
list_conv2d=[]
def conv2d_hook(self, input, output):
batch_size, input_channels, input_height, input_width = input[0].size()
output_channels, output_height, output_width = output[0].size()
kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
bias_ops = 1 if self.bias is not None else 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_height * output_width
list_conv2d.append(flops)
list_conv1d=[]
def conv1d_hook(self, input, output):
batch_size, input_channels, input_length = input[0].size()
output_channels, output_length = output[0].size()
kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
bias_ops = 1 if self.bias is not None else 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_length
list_conv1d.append(flops)
list_linear=[]
def linear_hook(self, input, output):
batch_size = input[0].size(0) if input[0].dim() == 2 else 1
weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
bias_ops = self.bias.nelement()
flops = batch_size * (weight_ops + bias_ops)
list_linear.append(flops)
list_bn=[]
def bn_hook(self, input, output):
list_bn.append(input[0].nelement() * 2)
list_relu=[]
def relu_hook(self, input, output):
list_relu.append(input[0].nelement() * 2)
list_pooling2d=[]
def pooling2d_hook(self, input, output):
batch_size, input_channels, input_height, input_width = input[0].size()
output_channels, output_height, output_width = output[0].size()
kernel_ops = self.kernel_size * self.kernel_size
bias_ops = 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_height * output_width
list_pooling2d.append(flops)
list_pooling1d=[]
def pooling1d_hook(self, input, output):
batch_size, input_channels, input_length = input[0].size()
output_channels, output_length = output[0].size()
kernel_ops = self.kernel_size[0]
bias_ops = 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_length
list_pooling2d.append(flops)
def foo(net):
childrens = list(net.children())
if not childrens:
if isinstance(net, nn.Conv2d):
net.register_forward_hook(conv2d_hook)
elif isinstance(net, nn.Conv1d):
net.register_forward_hook(conv1d_hook)
elif isinstance(net, nn.Linear):
net.register_forward_hook(linear_hook)
elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
net.register_forward_hook(bn_hook)
elif isinstance(net, nn.ReLU):
net.register_forward_hook(relu_hook)
elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
net.register_forward_hook(pooling2d_hook)
elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
net.register_forward_hook(pooling1d_hook)
else:
print('Warning: flop of module {} is not counted!'.format(net))
return
for c in childrens:
foo(c)
# Register hook
foo(model)
device = device = next(model.parameters()).device
input = torch.rand(1, audio_length).to(device)
out = model(input)
total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \
sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d)
return total_flops
Generated
+1899
View File
File diff suppressed because it is too large Load Diff
+28
View File
@@ -0,0 +1,28 @@
[tool.poetry]
name = "msclap"
version = "1.3.4"
description = "CLAP (Contrastive Language-Audio Pretraining) is a model that learns acoustic concepts from natural language supervision and enables “Zero-Shot” inference. The model has been extensively evaluated in 26 audio downstream tasks achieving SoTA in several of them including classification, retrieval, and captioning."
authors = ["Benjamin Elizalde", "Soham Deshmukh", "Huaming Wang"]
license = "MIT"
readme = "README.md"
packages = [
{ include = "msclap" },
]
[tool.poetry.dependencies]
python = "^3.8"
librosa = "^0.10.1"
numpy = "^1.23.0"
pandas = "^2.0.0"
torch = "^2.1.0"
torchaudio = "^2.1.0"
torchlibrosa = "^0.1.0"
tqdm = "^4.66.1"
transformers = "^4.34.0"
pyyaml = "^6.0.1"
scikit-learn = "^1.3.1"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
-59
View File
@@ -1,59 +0,0 @@
appdirs==1.4.4
audioread==2.1.9
certifi==2020.12.5
cffi==1.14.5
chardet==4.0.0
click==7.1.2
configparser==5.0.2
cycler==0.10.0
decorator==5.0.7
docker-pycreds==0.4.0
filelock==3.0.12
gitdb==4.0.7
GitPython==3.1.14
h5py==3.2.1
idna==2.10
joblib==1.2.0
kiwisolver==1.3.1
librosa==0.8.0
llvmlite==0.36.0
matplotlib==3.4.1
numba==0.53.1
numpy==1.22.0
packaging==20.9
pandas==1.2.4
pathtools==0.1.2
Pillow==9.0.1
pooch==1.3.0
promise==2.3
protobuf==3.18.3
psutil==5.8.0
pycparser==2.20
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2021.1
PyYAML==5.4.1
regex==2021.4.4
requests==2.25.1
resampy==0.2.2
sacremoses==0.0.45
scikit-learn==0.24.2
scipy==1.6.3
sentry-sdk==1.0.0
shortuuid==1.0.1
six==1.15.0
smmap==4.0.0
SoundFile==0.10.3.post1
subprocess32==3.5.4
threadpoolctl==2.1.0
tokenizers==0.10.2
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.8.1+cu111
torchaudio==0.8.1
torchlibrosa==0.0.9
torchvision==0.9.1+cu111
tqdm==4.60.0
transformers==4.5.1
typing-extensions==3.10.0.0
urllib3==1.26.5
importlib-resources==5.10.0
-237
View File
@@ -1,237 +0,0 @@
import random
import torchaudio
from torch._six import string_classes
import collections
import re
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer
from models.utils import read_config_as_args
from models.clap import CLAP
import math
import torchaudio.transforms as T
import os
import torch
from importlib_resources import files
class CLAPWrapper():
"""
A class for interfacing CLAP model.
"""
def __init__(self, model_fp, use_cuda=False):
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
self.file_path = os.path.realpath(__file__)
self.default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
self.config_as_str = files('configs').joinpath('config.yml').read_text()
self.model_fp = model_fp
self.use_cuda = use_cuda
self.clap, self.tokenizer, self.args = self.load_clap()
def load_clap(self):
r"""Load CLAP model with args from config file"""
args = read_config_as_args(self.config_as_str, is_config_str=True)
if 'bert' in args.text_model:
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
else:
self.token_keys = ['input_ids', 'attention_mask']
clap = CLAP(
audioenc_name=args.audioenc_name,
sample_rate=args.sampling_rate,
window_size=args.window_size,
hop_size=args.hop_size,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
classes_num=args.num_classes,
out_emb=args.out_emb,
text_model=args.text_model,
transformer_embed_dim=args.transformer_embed_dim,
d_proj=args.d_proj
)
# Load pretrained weights for model
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
clap.load_state_dict(model_state_dict)
clap.eval() # set clap in eval mode
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
if self.use_cuda and torch.cuda.is_available():
clap = clap.cuda()
return clap, tokenizer, args
def default_collate(self, batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
# array of string classes and object
if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(
self.default_collate_err_msg_format.format(elem.dtype))
return self.default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: self.default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError(
'each element in list of batch should be of equal size')
transposed = zip(*batch)
return [self.default_collate(samples) for samples in transposed]
raise TypeError(self.default_collate_err_msg_format.format(elem_type))
def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
r"""Loads audio file and returns raw audio."""
# Randomly sample a segment of audio_duration from the clip or pad to match duration
audio_time_series, sample_rate = torchaudio.load(audio_path)
resample_rate = self.args.sampling_rate
if resample:
resampler = T.Resample(sample_rate, resample_rate)
audio_time_series = resampler(audio_time_series)
audio_time_series = audio_time_series.reshape(-1)
# audio_time_series is shorter than predefined audio duration,
# so audio_time_series is extended
if audio_duration*sample_rate >= audio_time_series.shape[0]:
repeat_factor = int(np.ceil((audio_duration*sample_rate) /
audio_time_series.shape[0]))
# Repeat audio_time_series by repeat_factor to match audio_duration
audio_time_series = audio_time_series.repeat(repeat_factor)
# remove excess part of audio_time_series
audio_time_series = audio_time_series[0:audio_duration*sample_rate]
else:
# audio_time_series is longer than predefined audio duration,
# so audio_time_series is trimmed
start_index = random.randrange(
audio_time_series.shape[0] - audio_duration*sample_rate)
audio_time_series = audio_time_series[start_index:start_index +
audio_duration*sample_rate]
return torch.FloatTensor(audio_time_series)
def preprocess_audio(self, audio_files, resample):
r"""Load list of audio files and return raw audio"""
audio_tensors = []
for audio_file in audio_files:
audio_tensor = self.load_audio_into_tensor(
audio_file, self.args.duration, resample)
audio_tensor = audio_tensor.reshape(
1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
audio_tensors.append(audio_tensor)
return self.default_collate(audio_tensors)
def preprocess_text(self, text_queries):
r"""Load list of class labels and return tokenized text"""
tokenized_texts = []
for ttext in text_queries:
tok = self.tokenizer.encode_plus(
text=ttext, add_special_tokens=True, max_length=self.args.text_len, pad_to_max_length=True, return_tensors="pt")
for key in self.token_keys:
tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
tokenized_texts.append(tok)
return self.default_collate(tokenized_texts)
def get_text_embeddings(self, class_labels):
r"""Load list of class labels and return text embeddings"""
preprocessed_text = self.preprocess_text(class_labels)
text_embeddings = self._get_text_embeddings(preprocessed_text)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
return text_embeddings
def get_audio_embeddings(self, audio_files, resample):
r"""Load list of audio files and return a audio embeddings"""
preprocessed_audio = self.preprocess_audio(audio_files, resample)
audio_embeddings = self._get_audio_embeddings(preprocessed_audio)
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
return audio_embeddings
def _get_text_embeddings(self, preprocessed_text):
r"""Load preprocessed text and return text embeddings"""
with torch.no_grad():
text_embeddings = self.clap.caption_encoder(preprocessed_text)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
return text_embeddings
def _get_audio_embeddings(self, preprocessed_audio):
r"""Load preprocessed audio and return a audio embeddings"""
with torch.no_grad():
preprocessed_audio = preprocessed_audio.reshape(
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
#Append [0] the audio emebdding, [1] has output class probabilities
audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
return audio_embeddings
def compute_similarity(self, audio_embeddings, text_embeddings):
r"""Compute similarity between text and audio embeddings"""
logit_scale = self.clap.logit_scale.exp()
similarity = logit_scale*text_embeddings @ audio_embeddings.T
return similarity.T
def _generic_batch_inference(self, func, *args):
r"""Process audio and/or text per batch"""
input_tmp = args[0]
batch_size = args[-1]
# args[0] has audio_files, args[1] has class_labels
inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
args0_len = len(args[0])
# compute text_embeddings once for all the audio_files batches
if len(inputs) == 2:
text_embeddings = self.get_text_embeddings(args[1])
inputs = [args[0], args[1], text_embeddings]
dataset_idx = 0
for _ in range(math.ceil(args0_len/batch_size)):
next_batch_idx = dataset_idx + batch_size
# batch size is bigger than available audio/text items
if next_batch_idx >= args0_len:
inputs[0] = input_tmp[dataset_idx:]
return func(*tuple(inputs))
else:
inputs[0] = input_tmp[dataset_idx:next_batch_idx]
yield func(*tuple(inputs))
dataset_idx = next_batch_idx
def get_audio_embeddings_per_batch(self, audio_files, batch_size):
r"""Load preprocessed audio and return a audio embeddings per batch"""
return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)
def get_text_embeddings_per_batch(self, class_labels, batch_size):
r"""Load preprocessed text and return text embeddings per batch"""
return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
r"""Compute classification probabilities for each audio recording in a batch and each class label"""
return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
View File
-3
View File
@@ -1,3 +0,0 @@
from . import clap
from . import audio
from . import utils
-52
View File
@@ -1,52 +0,0 @@
"""
This is an example using CLAP for zero-shot
inference using ESC50 (https://github.com/karolpiczak/ESC-50).
"""
from CLAPWrapper import CLAPWrapper
from esc50_dataset import ESC50
import torch.nn.functional as F
# Load ESC50 dataset
dataset = ESC50(root="data_path", download=True) # set download=True when dataset is not downloaded
audio_file, target, one_hot_target = dataset[1000]
audio_file = [audio_file]
prompt = 'this is a sound of '
y = [prompt + x for x in dataset.classes]
# Load and initialize CLAP
weights_path = "weights_path"
# Setting use_cuda = True will load the model on a GPU using CUDA
clap_model = CLAPWrapper(weights_path, use_cuda=False)
# compute text embeddings from natural text
text_embeddings = clap_model.get_text_embeddings(y)
# compute the audio embeddings from an audio file
audio_embeddings = clap_model.get_audio_embeddings(audio_file, resample=True)
# compute the similarity between audio_embeddings and text_embeddings
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
similarity = F.softmax(similarity, dim=1)
values, indices = similarity[0].topk(5)
# view the results
print("Ground Truth: {}".format(target))
print("Top predictions:\n")
for value, index in zip(values, indices):
print(f"{dataset.classes[index]:>16s}: {100 * value.item():.2f}%")
"""
The output (the exact numbers may vary):
Ground Truth: coughing
Top predictions:
coughing: 86.34%
sneezing: 9.30%
drinking sipping: 1.31%
laughing: 1.20%
glass breaking: 0.81%
"""