class CLAPWrapper():
    """
    A class for interfacing CLAP model.

    Loads the packaged configuration and a pretrained checkpoint, then exposes
    helpers to preprocess audio/text, extract embeddings and score
    audio-text similarity, optionally in batches.
    """

    def __init__(self, model_fp, use_cuda=False):
        r"""Build the wrapper and eagerly load the model.

        Args:
            model_fp: path to the checkpoint file handed to ``torch.load``.
            use_cuda: move model and inputs to the GPU when one is available.
        """
        # Matches numpy dtype strings for string/object arrays, which cannot be collated.
        self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
        self.file_path = os.path.realpath(__file__)
        self.default_collate_err_msg_format = (
            "default_collate: batch must contain tensors, numpy arrays, numbers, "
            "dicts or lists; found {}")
        self.config_as_str = files('CLAP_API.configs').joinpath('config.yml').read_text()
        self.model_fp = model_fp
        self.use_cuda = use_cuda
        self.clap, self.tokenizer, self.args = self.load_clap()

    def _on_gpu(self):
        r"""Return True when CUDA was requested and is actually available."""
        return bool(self.use_cuda and torch.cuda.is_available())

    def load_clap(self):
        r"""Load CLAP model with args from config file.

        Returns:
            Tuple ``(clap, tokenizer, args)``: the model in eval mode, the
            tokenizer for ``args.text_model`` and the parsed config namespace.
        """
        args = read_config_as_args(self.config_as_str, is_config_str=True)

        clap = CLAP(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=args.text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj
        )

        # Weights are always deserialized on CPU first; the checkpoint keeps
        # the state dict under the 'model' key.
        model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
        clap.load_state_dict(model_state_dict)

        clap.eval()  # inference only: disables dropout / batch-norm updates
        tokenizer = AutoTokenizer.from_pretrained(args.text_model)

        if self._on_gpu():
            clap = clap.cuda()

        return clap, tokenizer, args

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size.

        Recursively handles tensors, numpy arrays/scalars, Python numbers,
        strings, mappings, namedtuples and sequences.

        Raises:
            TypeError: for unsupported element types (including numpy
                string/object arrays).
            RuntimeError: when sequence elements have different lengths.
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum(x.numel() for x in batch)
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))

                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, (str, bytes)):
            # (str, bytes) is what torch._six.string_classes used to provide;
            # that module was removed in PyTorch 2.0, so it is spelled out here.
            return batch
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(elem) == elem_size for elem in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))

    def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
        r"""Loads audio file and returns raw audio of exactly ``audio_duration`` seconds.

        Clips shorter than the target are tiled and cut; longer clips are
        cropped at a random offset.

        Args:
            audio_path: file passed to ``torchaudio.load``.
            audio_duration: target length in seconds.
            resample: resample to ``self.args.sampling_rate`` before cutting.

        Raises:
            ValueError: when the file contains no samples.
        """
        audio_time_series, sample_rate = torchaudio.load(audio_path)
        resample_rate = self.args.sampling_rate
        if resample and sample_rate != resample_rate:
            resampler = T.Resample(sample_rate, resample_rate)
            audio_time_series = resampler(audio_time_series)
            # All length arithmetic below must use the rate the samples are
            # now in, otherwise the returned clip is not audio_duration long.
            sample_rate = resample_rate
        # NOTE(review): this flattens every channel into one 1-D signal, so
        # multi-channel files are concatenated channel after channel rather
        # than mixed down - confirm that is intended.
        audio_time_series = audio_time_series.reshape(-1)

        total_samples = audio_time_series.shape[0]
        if total_samples == 0:
            raise ValueError('audio file {} contains no samples'.format(audio_path))
        target_samples = int(audio_duration * sample_rate)

        if target_samples >= total_samples:
            # Too short: repeat the clip until it covers the target, then trim.
            repeat_factor = int(np.ceil(target_samples / total_samples))
            audio_time_series = audio_time_series.repeat(repeat_factor)
            audio_time_series = audio_time_series[0:target_samples]
        else:
            # Too long: take a random window of the target length.
            start_index = random.randrange(total_samples - target_samples)
            audio_time_series = audio_time_series[start_index:start_index + target_samples]
        return torch.FloatTensor(audio_time_series)

    def preprocess_audio(self, audio_files, resample=False):
        r"""Load list of audio files and return raw audio collated into one batch."""
        on_gpu = self._on_gpu()
        audio_tensors = []
        for audio_file in audio_files:
            audio_tensor = self.load_audio_into_tensor(
                audio_file, self.args.duration, resample).reshape(1, -1)
            if on_gpu:
                # .cuda() must be called; the bare attribute is a bound method.
                audio_tensor = audio_tensor.cuda()
            audio_tensors.append(audio_tensor)
        return self.default_collate(audio_tensors)

    def preprocess_text(self, text_queries):
        r"""Load list of class labels and return tokenized text as a dict of batched tensors."""
        on_gpu = self._on_gpu()
        tokenized_texts = []
        for ttext in text_queries:
            # padding/truncation spell out what pad_to_max_length + max_length
            # used to imply, without the deprecated argument.
            tok = self.tokenizer.encode_plus(
                text=ttext, add_special_tokens=True, max_length=self.args.text_len,
                padding='max_length', truncation=True, return_tensors="pt")
            for key in ('input_ids', 'token_type_ids', 'attention_mask'):
                if key not in tok:
                    continue  # not every tokenizer emits token_type_ids
                flat = tok[key].reshape(-1)
                tok[key] = flat.cuda() if on_gpu else flat
            tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)

    def get_text_embeddings(self, class_labels):
        r"""Load list of class labels and return text embeddings"""
        preprocessed_text = self.preprocess_text(class_labels)
        return self._get_text_embeddings(preprocessed_text)

    def get_audio_embeddings(self, audio_files, resample=False):
        r"""Load list of audio files and return audio embeddings.

        ``resample`` defaults to False so the batched helper, which passes
        only the file list, can call this method.
        """
        preprocessed_audio = self.preprocess_audio(audio_files, resample)
        return self._get_audio_embeddings(preprocessed_audio)

    def _get_text_embeddings(self, preprocessed_text):
        r"""Load preprocessed text and return text embeddings"""
        with torch.no_grad():
            return self.clap.caption_encoder(preprocessed_text)

    def _get_audio_embeddings(self, preprocessed_audio):
        r"""Load preprocessed audio and return audio embeddings"""
        with torch.no_grad():
            # Collated audio carries a singleton middle axis (one row per
            # file from preprocess_audio); drop it before encoding.
            preprocessed_audio = preprocessed_audio.reshape(
                preprocessed_audio.shape[0], preprocessed_audio.shape[2])
            # [0] is the audio embedding, [1] has output class probabilities
            return self.clap.audio_encoder(preprocessed_audio)[0]

    def _generic_batch_inference(self, func, *args):
        r"""Process audio and/or text per batch.

        ``args`` is either ``(items, batch_size)`` or
        ``(audio_files, class_labels, batch_size)``. In the second form the
        text embeddings are computed once and ``func`` is called as
        ``func(audio_batch, class_labels, text_embeddings)``.

        Yields:
            The result of ``func`` for every batch, including the final
            (possibly shorter) one.

        Raises:
            ValueError: when ``batch_size`` is not positive.
        """
        items = args[0]
        batch_size = args[-1]
        if batch_size <= 0:
            raise ValueError('batch_size must be a positive integer')

        extra_inputs = []
        if len(args) == 3:
            # compute text_embeddings once for all the audio_files batches
            class_labels = args[1]
            extra_inputs = [class_labels, self.get_text_embeddings(class_labels)]

        for start in range(0, len(items), batch_size):
            # Every batch is yielded; a `return value` inside a generator
            # would discard the last batch instead of delivering it.
            yield func(items[start:start + batch_size], *extra_inputs)

    def get_audio_embeddings_per_batch(self, audio_files, batch_size):
        r"""Load preprocessed audio and return audio embeddings per batch"""
        return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)

    def get_text_embeddings_per_batch(self, class_labels, batch_size):
        r"""Load preprocessed text and return text embeddings per batch"""
        return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)

    def compute_similarity(self, audio_embeddings, text_embeddings):
        r"""Compute similarity between text and audio embeddings.

        Both inputs are L2-normalized along the last axis; the result has one
        row per audio embedding and one column per text embedding, scaled by
        the model's learned ``logit_scale``.
        """
        audio_embeddings = audio_embeddings / torch.norm(audio_embeddings, dim=-1, keepdim=True)
        text_embeddings = text_embeddings / torch.norm(text_embeddings, dim=-1, keepdim=True)

        logit_scale = self.clap.logit_scale.exp()
        similarity = logit_scale * text_embeddings @ audio_embeddings.T
        return similarity.T

    def classify_audio_files(self, audio_files, class_labels, text_embeddings=None, resample=False):
        r"""Compute classification probabilities for each audio file and each class label.

        NOTE(review): this method was referenced by
        ``classify_audio_files_per_batch`` but not defined; the softmax over
        similarities is the assumed meaning of "classification
        probabilities" - confirm against the intended API.

        Args:
            audio_files: audio paths to classify.
            class_labels: candidate text labels.
            text_embeddings: precomputed embeddings of ``class_labels``;
                computed here when omitted.
            resample: forwarded to audio loading.
        """
        if text_embeddings is None:
            text_embeddings = self.get_text_embeddings(class_labels)
        audio_embeddings = self.get_audio_embeddings(audio_files, resample)
        similarity = self.compute_similarity(audio_embeddings, text_embeddings)
        return torch.softmax(similarity, dim=1)

    def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
        r"""Compute classification probabilities for each audio recording in a batch and each class label"""
        return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
a/CLAP_API/__init__.py b/CLAP_API/__init__.py new file mode 100644 index 0000000..bd3b45b --- /dev/null +++ b/CLAP_API/__init__.py @@ -0,0 +1 @@ +from .CLAPWrapper import CLAPWrapper as CLAP \ No newline at end of file diff --git a/CLAP_API b/CLAP_API/configs/__init__.py similarity index 100% rename from CLAP_API rename to CLAP_API/configs/__init__.py diff --git a/CLAP_API/configs/config.yml b/CLAP_API/configs/config.yml new file mode 100644 index 0000000..1b01939 --- /dev/null +++ b/CLAP_API/configs/config.yml @@ -0,0 +1,26 @@ +# TEXT ENCODER CONFIG +text_model: 'bert-base-uncased' +text_len: 100 +transformer_embed_dim: 768 +freeze_text_encoder_weights: True + +# AUDIO ENCODER CONFIG +audioenc_name: 'Cnn14' +out_emb: 2048 +sampling_rate: 44100 +duration: 5 +fmin: 50 +fmax: 14000 +n_fft: 1028 +hop_size: 320 +mel_bins: 64 +window_size: 1024 + +# PROJECTION SPACE CONFIG +d_proj: 1024 +temperature: 0.003 + +# TRAINING AND EVALUATION CONFIG +num_classes: 527 +batch_size: 1024 +demo: False diff --git a/CLAP_API/models/__init__.py b/CLAP_API/models/__init__.py new file mode 100644 index 0000000..aadad97 --- /dev/null +++ b/CLAP_API/models/__init__.py @@ -0,0 +1,3 @@ +from . import clap +from . import audio +from . 
def get_audio_encoder(name: str):
    """Return the audio encoder class registered under ``name``.

    Raises:
        ValueError: when ``name`` is not a supported encoder.
    """
    if name == "Cnn14":
        return Cnn14
    raise ValueError('The audio encoder name {} is incorrect or not supported'.format(name))


def _pool2d(x, pool_size, pool_type):
    """Apply 'max', 'avg' or 'avg+max' (sum of both) 2-D pooling to ``x``."""
    if pool_type == 'max':
        return F.max_pool2d(x, kernel_size=pool_size)
    if pool_type == 'avg':
        return F.avg_pool2d(x, kernel_size=pool_size)
    if pool_type == 'avg+max':
        return (F.avg_pool2d(x, kernel_size=pool_size)
                + F.max_pool2d(x, kernel_size=pool_size))
    raise ValueError('Incorrect argument!')


class ConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm + ReLU layers followed by pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Run both conv stages, then pool with ``pool_type``.

        Raises:
            ValueError: when ``pool_type`` is not 'max', 'avg' or 'avg+max'.
        """
        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))
        return _pool2d(x, pool_size, pool_type)


class ConvBlock5x5(nn.Module):
    """Single 5x5 conv + batch-norm + ReLU layer followed by pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Run the conv stage, then pool with ``pool_type``.

        Raises:
            ValueError: when ``pool_type`` is not 'max', 'avg' or 'avg+max'.
        """
        x = F.relu_(self.bn1(self.conv1(input)))
        return _pool2d(x, pool_size, pool_type)


class AttBlock(nn.Module):
    """Attention pooling over the time axis using 1x1 convolutions."""

    def __init__(self, n_in, n_out, activation='linear', temperature=1.):
        super().__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
        self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)

        # Not used in forward(); kept so the module's parameter set is unchanged.
        self.bn_att = nn.BatchNorm1d(n_out)

    def forward(self, x):
        """Return ``(pooled, norm_att, cla)`` for ``x`` of shape (n_samples, n_in, n_time)."""
        # Clamp attention logits before the softmax over time.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation ('linear' or 'sigmoid').

        Raises:
            ValueError: for any other activation name, instead of silently
                returning None.
        """
        if self.activation == 'linear':
            return x
        if self.activation == 'sigmoid':
            return torch.sigmoid(x)
        raise ValueError('Unsupported activation: {}'.format(self.activation))


class Cnn14(nn.Module):
    """Cnn14 audio encoder: log-mel front end followed by six conv blocks."""

    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num, out_emb):
        super().__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
            win_length=window_size, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
            n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
            freeze_parameters=True)

        # Normalizes per mel bin (see the transposes in forward), so its width
        # must follow mel_bins rather than a fixed 64.
        self.bn0 = nn.BatchNorm2d(mel_bins)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        # out_emb is 2048 for best Cnn14
        self.fc1 = nn.Linear(2048, out_emb, bias=True)
        self.fc_audioset = nn.Linear(out_emb, classes_num, bias=True)

    def forward(self, input, mixup_lambda=None):
        """Encode raw audio.

        Args:
            input: tensor of shape (batch_size, data_length).
            mixup_lambda: accepted for interface compatibility; unused here.

        Returns:
            Dict with 'clipwise_output' (sigmoid class scores) and
            'embedding'.
        """
        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)

        # Move mel bins to the channel axis so bn0 normalizes each bin, then restore.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        conv_stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),  # last block keeps its resolution
        )
        for block, pool_size in conv_stages:
            x = block(x, pool_size=pool_size, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)

        # Combine max and mean pooling over the remaining axis.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

        return output_dict
class Projection(nn.Module):
    """Two-layer residual projection head with dropout and layer norm."""

    def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from ``d_in`` to ``d_out`` features."""
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        # Residual connection around the second linear layer, then normalize.
        embeds = self.layer_norm(embed1 + embed2)
        return embeds


class AudioEncoder(nn.Module):
    """Audio backbone selected by name, followed by a projection head."""

    def __init__(self, audioenc_name: str, d_in: int, d_out: int, sample_rate: int, window_size: int,
                 hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
        super().__init__()

        audio_encoder = get_audio_encoder(audioenc_name)

        # d_in doubles as the backbone's embedding size (its out_emb argument).
        self.base = audio_encoder(
            sample_rate, window_size,
            hop_size, mel_bins, fmin, fmax,
            classes_num, d_in)

        self.projection = Projection(d_in, d_out)

    def forward(self, x):
        """Return ``(projected_embedding, classification_output)`` for audio ``x``."""
        out_dict = self.base(x)
        audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']
        projected_vec = self.projection(audio_features)
        return projected_vec, audio_classification_output


class TextEncoder(nn.Module):
    """Pretrained transformer text model followed by a projection head."""

    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained(text_model)

        self.projection = Projection(transformer_embed_dim, d_out)

    def forward(self, x):
        """Project the first-token output of the text model; ``x`` is unpacked as keyword inputs."""
        out = self.base(**x)[0]
        out = out[:, 0, :]  # get CLS token output
        projected_vec = self.projection(out)
        return projected_vec


class CLAP(nn.Module):
    """Joint audio-text model: two encoders projecting into a shared space."""

    def __init__(self,
                 # audio
                 audioenc_name: str,
                 sample_rate: int,
                 window_size: int,
                 hop_size: int,
                 mel_bins: int,
                 fmin: int,
                 fmax: int,
                 classes_num: int,
                 out_emb: int,
                 # text
                 text_model: str,
                 transformer_embed_dim: int,
                 # common
                 d_proj: int,
                 ):
        super().__init__()

        self.audio_encoder = AudioEncoder(
            audioenc_name, out_emb, d_proj,
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)

        self.caption_encoder = TextEncoder(
            d_proj, text_model, transformer_embed_dim
        )

        # Learnable log-scale for similarity logits, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, audio, text):
        """Return ``(caption_embed, audio_embed, logit_scale.exp())``."""
        audio_embed, _ = self.audio_encoder(audio)
        caption_embed = self.caption_encoder(text)

        return caption_embed, audio_embed, self.logit_scale.exp()


def read_config_as_args(config_path, args=None, is_config_str=False):
    """Read a YAML config and return it as an ``argparse.Namespace``.

    Args:
        config_path: path to a YAML file, or the YAML text itself when
            ``is_config_str`` is true. ``None`` means "no config".
        args: optional existing namespace. Keys it already has are
            overwritten in place from the config; unknown config keys are
            reported on stderr and ignored.
        is_config_str: treat ``config_path`` as YAML content, not a path.

    Returns:
        A namespace built from ``args`` when given, otherwise from the
        config's own keys.
    """
    yml_config = {}

    if config_path is not None:
        # NOTE: FullLoader is kept for compatibility; only load trusted configs.
        if is_config_str:
            yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
        else:
            with open(config_path, "r") as f:
                yml_config = yaml.load(f, Loader=yaml.FullLoader)
        # An empty YAML document parses to None.
        yml_config = yml_config or {}

    if args is None:
        return argparse.Namespace(**yml_config)

    known = vars(args)
    for k, v in yml_config.items():
        if k in known:
            known[k] = v
        else:
            sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))

    # A Namespace is not a mapping, so it must be unpacked through vars().
    return argparse.Namespace(**known)