from pathlib import Path
import warnings
# Must run before the heavy third-party imports below so their import-time
# warnings are silenced as well.
warnings.filterwarnings("ignore")
import random
import torchaudio
import collections
import re
import numpy as np
from transformers import AutoTokenizer, logging
from .models.clap import CLAP
from .models.mapper import get_clapcap
import math
import torchaudio.transforms as T
import os
import torch
import argparse
import yaml
import sys
logging.set_verbosity_error()


class CLAPWrapper():
    """
    A class for interfacing CLAP model.

    Depending on ``version`` the wrapper loads either a CLAP model
    (stored on ``self.clap``) or a captioning model (stored on
    ``self.clapcap``), together with the matching tokenizer and the parsed
    configuration (``self.args``).
    """

    def __init__(self, model_fp, version, use_cuda=False):
        r"""Read the bundled config for `version` and load the model.

        Args:
            model_fp: path of the checkpoint passed to `torch.load`.
            version: one of `self.supported_versions`.
            use_cuda: move model and inputs to GPU when CUDA is available.

        Raises:
            ValueError: if `version` is not supported.
        """
        self.supported_versions = ['2022', '2023', 'clapcap']
        self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
        self.file_path = os.path.realpath(__file__)
        self.default_collate_err_msg_format = (
            "default_collate: batch must contain tensors, numpy arrays, numbers, "
            "dicts or lists; found {}")
        self.config_as_str = self.get_config_path(version)
        self.model_fp = model_fp
        self.use_cuda = use_cuda
        if 'clapcap' in version:
            self.clapcap, self.tokenizer, self.args = self.load_clapcap()
        else:
            self.clap, self.tokenizer, self.args = self.load_clap()

    def get_config_path(self, version):
        r"""Return the text of ``configs/config_{version}.yml`` next to this file.

        Despite the name, the returned value is the file *content*, not a path.

        Raises:
            ValueError: if `version` is not in `self.supported_versions`.
        """
        if version in self.supported_versions:
            return (Path(__file__).parent / f"configs/config_{version}.yml").read_text()
        else:
            raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")

    def read_config_as_args(self, config_path, args=None, is_config_str=False):
        r"""Parse a YAML config and return it as an `argparse.Namespace`.

        Args:
            config_path: path to a YAML file, or the YAML text itself when
                `is_config_str` is True. `None` skips YAML parsing.
            args: optional existing namespace (or dict). Only keys it already
                has are overwritten from the YAML; unknown keys are reported
                on stderr and ignored. It is updated in place.
            is_config_str: treat `config_path` as YAML text instead of a path.

        Returns:
            A new `argparse.Namespace` holding the resulting key/values.
        """
        return_dict = {}
        # A Namespace is not a mapping, so work on its attribute dict; a plain
        # dict is accepted as-is.
        if args is None:
            known = None
        elif isinstance(args, dict):
            known = args
        else:
            known = vars(args)

        if config_path is not None:
            # NOTE(review): FullLoader can build more than plain data types;
            # fine for the bundled configs, prefer yaml.safe_load if the
            # config could ever come from an untrusted source.
            if is_config_str:
                yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
            else:
                with open(config_path, "r") as f:
                    yml_config = yaml.load(f, Loader=yaml.FullLoader)
            # An empty YAML document parses to None.
            yml_config = yml_config or {}

            if known is not None:
                for k, v in yml_config.items():
                    if k in known:
                        known[k] = v
                    else:
                        sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
            else:
                for k, v in yml_config.items():
                    return_dict[k] = v

        values = known if known is not None else return_dict
        return argparse.Namespace(**values)

    def _set_token_keys(self, text_model):
        r"""Set `self.token_keys` to the tokenizer outputs used for `text_model`."""
        # NOTE(review): for a name matching none of these, token_keys is left
        # unset and preprocess_text will fail later.
        if 'roberta' in text_model or 'clip' in text_model or 'gpt' in text_model:
            self.token_keys = ['input_ids', 'attention_mask']
        elif 'bert' in text_model:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']

    def _build_clap(self, args, text_model):
        r"""Instantiate CLAP from config `args`, using `text_model` as text encoder."""
        return CLAP(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj
        )

    def _load_tokenizer(self, text_model):
        r"""Load the tokenizer for `text_model`; GPT models get '!' as pad token."""
        tokenizer = AutoTokenizer.from_pretrained(text_model)
        if 'gpt' in text_model:
            tokenizer.add_special_tokens({'pad_token': '!'})
        return tokenizer

    def load_clap(self):
        r"""Load CLAP model with args from config file

        Returns:
            Tuple `(clap, tokenizer, args)` with the model in eval mode.
        """
        args = self.read_config_as_args(self.config_as_str, is_config_str=True)
        self._set_token_keys(args.text_model)

        clap = self._build_clap(args, args.text_model)

        # Load pretrained weights for model
        model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']

        # We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to be unwrapped before `load_state_dict`:
        # Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
        # strict=False: missing or unexpected checkpoint keys are tolerated.
        clap.load_state_dict(model_state_dict, strict=False)

        clap.eval()  # set clap in eval mode
        tokenizer = self._load_tokenizer(args.text_model)

        if self.use_cuda and torch.cuda.is_available():
            clap = clap.cuda()

        return clap, tokenizer, args

    def load_clapcap(self):
        r"""Load the captioning model (CLAP + text decoder) with args from config file

        Returns:
            Tuple `(clapcap, tokenizer, args)` with the model in eval mode.
            `args.text_model` is overwritten with `args.text_decoder`, so the
            tokenizer and token keys follow the decoder, while CLAP itself is
            built with the original text encoder name.
        """
        args = self.read_config_as_args(self.config_as_str, is_config_str=True)
        args.prefix_dim = args.d_proj
        text_model = args.text_model
        args.text_model = args.text_decoder
        args.cross_attention = 'cross' in args.clapcap_model.lower()
        self._set_token_keys(args.text_model)

        clap = self._build_clap(args, text_model)

        clapcap = get_clapcap(args.clapcap_model)(
            clap, args.text_decoder, args.prefix_length, args.prefix_length_clip,
            args.prefix_dim, args.num_layers, args.normalize_prefix,
            args.mapping_type, True, True)

        model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
        clapcap.load_state_dict(model_state_dict)

        clapcap.eval()  # set clapcap in eval mode
        tokenizer = self._load_tokenizer(args.text_model)

        if self.use_cuda and torch.cuda.is_available():
            clapcap = clapcap.cuda()

        return clapcap, tokenizer, args

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size

        Dispatches on the type of the first element of `batch` and recurses
        into mappings, namedtuples and sequences.

        Raises:
            TypeError: for unsupported element types (including numpy string
                or object arrays).
            RuntimeError: if sequence elements have differing lengths.
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))

                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(elem) == elem_size for elem in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))

    def read_audio(self, audio_path, resample=True):
        r"""Load an audio file with torchaudio, optionally resampling it.

        Returns:
            Tuple `(audio_time_series, rate)` where `rate` is
            `self.args.sampling_rate`.
        """
        audio_time_series, sample_rate = torchaudio.load(audio_path)

        resample_rate = self.args.sampling_rate
        if resample and resample_rate != sample_rate:
            resampler = T.Resample(sample_rate, resample_rate)
            audio_time_series = resampler(audio_time_series)
        # NOTE(review): the configured rate is returned even when
        # resample=False and the file has a different native rate; callers
        # then compute durations with the configured rate. Confirm intended.
        return audio_time_series, resample_rate

    def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
        r"""Loads audio file and returns raw audio of a fixed length.

        The (flattened) signal is repeated to reach
        `audio_duration * sample_rate` samples when too short, or a random
        window of that length is cut out when too long.
        """
        audio_time_series, sample_rate = self.read_audio(audio_path, resample=resample)
        audio_time_series = audio_time_series.reshape(-1)
        target_len = audio_duration*sample_rate

        # audio_time_series is shorter than predefined audio duration,
        # so audio_time_series is extended
        if target_len >= audio_time_series.shape[0]:
            repeat_factor = int(np.ceil(target_len / audio_time_series.shape[0]))
            # Repeat audio_time_series by repeat_factor to match audio_duration
            audio_time_series = audio_time_series.repeat(repeat_factor)
            # remove excess part of audio_time_series
            audio_time_series = audio_time_series[0:target_len]
        else:
            # audio_time_series is longer than predefined audio duration,
            # so audio_time_series is trimmed at a random offset
            start_index = random.randrange(
                audio_time_series.shape[0] - target_len)
            audio_time_series = audio_time_series[start_index:start_index + target_len]
        return torch.FloatTensor(audio_time_series)

    def preprocess_audio(self, audio_files, resample):
        r"""Load list of audio files and return raw audio as one collated batch"""
        use_gpu = self.use_cuda and torch.cuda.is_available()
        audio_tensors = []
        for audio_file in audio_files:
            audio_tensor = self.load_audio_into_tensor(
                audio_file, self.args.duration, resample)
            audio_tensor = audio_tensor.reshape(1, -1)
            if use_gpu:
                audio_tensor = audio_tensor.cuda()
            audio_tensors.append(audio_tensor)
        return self.default_collate(audio_tensors)

    def preprocess_text(self, text_queries):
        r"""Load list of class labels and return tokenized text as one collated batch"""
        use_gpu = self.use_cuda and torch.cuda.is_available()
        tokenized_texts = []
        for ttext in text_queries:
            if 'gpt' in self.args.text_model:
                ttext = ttext + ' <|endoftext|>'
            tok = self.tokenizer.encode_plus(
                text=ttext, add_special_tokens=True, max_length=self.args.text_len,
                padding='max_length', return_tensors="pt")
            for key in self.token_keys:
                # Flatten each field so default_collate can stack them.
                flat = tok[key].reshape(-1)
                tok[key] = flat.cuda() if use_gpu else flat
            tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)

    def get_text_embeddings(self, class_labels):
        r"""Load list of class labels and return text embeddings"""
        preprocessed_text = self.preprocess_text(class_labels)
        return self._get_text_embeddings(preprocessed_text)

    def get_audio_embeddings(self, audio_files, resample=True):
        r"""Load list of audio files and return audio embeddings"""
        preprocessed_audio = self.preprocess_audio(audio_files, resample)
        return self._get_audio_embeddings(preprocessed_audio)

    def _get_text_embeddings(self, preprocessed_text):
        r"""Load preprocessed text and return text embeddings"""
        with torch.no_grad():
            return self.clap.caption_encoder(preprocessed_text)

    def _get_audio_embeddings(self, preprocessed_audio):
        r"""Load preprocessed audio and return audio embeddings"""
        with torch.no_grad():
            # Drop the middle axis added in preprocess_audio (reshape(1, -1)).
            preprocessed_audio = preprocessed_audio.reshape(
                preprocessed_audio.shape[0], preprocessed_audio.shape[2])
            # [0] is the audio embedding, [1] has output class probabilities
            return self.clap.audio_encoder(preprocessed_audio)[0]

    def _generic_batch_inference(self, func, *args):
        r"""Process audio and/or text per batch

        Generator. `args` is either `(items, batch_size)` or
        `(audio_files, class_labels, batch_size)`. In the second form the text
        embeddings are computed once and `func` is called as
        `func(batch, class_labels, text_embeddings)`.

        Yields:
            The result of `func` for every consecutive slice of `batch_size`
            items, including the final (possibly shorter) one.
        """
        input_tmp = args[0]
        batch_size = args[-1]
        # args[0] has audio_files, args[1] has class_labels
        inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
        args0_len = len(args[0])
        # compute text_embeddings once for all the audio_files batches
        if len(inputs) == 2:
            text_embeddings = self.get_text_embeddings(args[1])
            inputs = [args[0], args[1], text_embeddings]
        # The last batch must be *yielded*: a `return value` inside a
        # generator only ends iteration and the value never reaches a
        # `for` loop consumer.
        for batch_idx in range(math.ceil(args0_len/batch_size)):
            dataset_idx = batch_idx * batch_size
            inputs[0] = input_tmp[dataset_idx:dataset_idx + batch_size]
            yield func(*tuple(inputs))

    def get_audio_embeddings_per_batch(self, audio_files, batch_size):
        r"""Return a generator of audio embeddings, one item per batch"""
        return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)

    def get_text_embeddings_per_batch(self, class_labels, batch_size):
        r"""Return a generator of text embeddings, one item per batch"""
        return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)

    def compute_similarity(self, audio_embeddings, text_embeddings):
        r"""Compute similarity between text and audio embeddings

        Both inputs are L2-normalised along the last axis; the result is the
        scaled dot product, returned with audio along the first axis.
        """
        audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
        text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)

        logit_scale = self.clap.logit_scale.exp()
        similarity = logit_scale*text_embeddings @ audio_embeddings.T
        return similarity.T

    def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
        r"""Compute classification probabilities for each audio recording in a batch and each class label"""
        # NOTE(review): `classify_audio_files` is not defined in this class as
        # visible here; unless it is provided elsewhere this raises
        # AttributeError. It would be called with
        # (audio_batch, class_labels, text_embeddings).
        return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)

    def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
        r"""Generate audio captions for each audio recording in a batch

        Returns:
            List with the best beam-search caption per audio file, capitalized.
        """
        captions = []
        audio_tensors = self.preprocess_audio(audio_files, resample)

        with torch.no_grad():
            prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
            if self.args.normalize_prefix:
                prefix = prefix / prefix.norm(2, -1).reshape(-1, 1)
            prefix_embed = self.clapcap.clap_project(prefix).view(
                -1, self.args.prefix_length,
                self.clapcap.gpt.transformer.wte.weight.shape[1])

            for i in range(len(audio_tensors)):
                # _generate_beam returns captions sorted best-first; keep the top one.
                gen_caption = self._generate_beam(embed=prefix_embed[i].unsqueeze(0),
                                                  beam_size=beam_size,
                                                  entry_length=entry_length,
                                                  temperature=temperature)[0]
                captions.append(gen_caption.capitalize())
        return captions

    def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
                       entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
        r"""Generate captions by beam search decoding

        Args:
            beam_size: number of beams kept per step.
            prompt: text to start from; used only when `embed` is None.
            embed: prefix embeddings fed to the decoder as `inputs_embeds`.
            entry_length: maximum number of decoding steps.
            temperature: logits divisor; values <= 0 are treated as 1.0.
            stop_token: text whose first token id marks a finished beam.

        Returns:
            List of `beam_size` decoded strings, ordered by descending
            length-normalised score.
        """
        self.clapcap.eval()
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.clapcap.parameters()).device
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(self.tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)
                    generated = self.clapcap.gpt.transformer.wte(tokens)
            for i in range(entry_length):
                outputs = self.clapcap.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                logits = logits.softmax(-1).log()
                if scores is None:
                    # First step: seed the beams with the top-k tokens.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    if tokens is None:
                        tokens = next_tokens
                    else:
                        tokens = tokens.expand(beam_size, *tokens.shape[1:])
                        tokens = torch.cat((tokens, next_tokens), dim=1)
                else:
                    # Finished beams may only continue with token id 0 at zero
                    # cost, so their score stays unchanged.
                    logits[is_stopped] = -float(np.inf)
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Flat index -> (source beam, token id).
                    next_tokens_source = next_tokens // scores_sum.shape[1]
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]
                next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                if is_stopped.all():
                    break
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        # Cut each beam at its own recorded length before decoding.
        output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]
        return output_texts