adding examples and minor fixes

This commit is contained in:
Soham Deshmukh
2022-11-08 10:55:31 -08:00
parent c69a553665
commit 0c99d4d4cd
5 changed files with 119 additions and 40 deletions
+25 -21
View File
@@ -36,6 +36,11 @@ class CLAPWrapper():
args = read_config_as_args(self.config_as_str, is_config_str=True)
if 'bert' in args.text_model:
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
else:
self.token_keys = ['input_ids', 'attention_mask']
clap = CLAP(
audioenc_name=args.audioenc_name,
sample_rate=args.sampling_rate,
@@ -154,31 +159,31 @@ class CLAPWrapper():
for ttext in text_queries:
tok = self.tokenizer.encode_plus(
text=ttext, add_special_tokens=True, max_length=self.args.text_len, pad_to_max_length=True, return_tensors="pt")
tok['input_ids'] = tok['input_ids'].reshape(-1).cuda(
) if self.use_cuda and torch.cuda.is_available() else tok['input_ids'].reshape(-1)
tok['token_type_ids'] = tok['token_type_ids'].reshape(-1).cuda(
) if self.use_cuda and torch.cuda.is_available() else tok['token_type_ids'].reshape(-1)
tok['token_type_ids'] = tok['token_type_ids'].reshape(-1).cuda(
) if self.use_cuda and torch.cuda.is_available() else tok['token_type_ids'].reshape(-1)
tok['attention_mask'] = tok['attention_mask'].reshape(-1).cuda(
) if self.use_cuda and torch.cuda.is_available() else tok['attention_mask'].reshape(-1)
for key in self.token_keys:
tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
tokenized_texts.append(tok)
return self.default_collate(tokenized_texts)
def get_text_embeddings(self, class_labels):
r"""Load list of class labels and return text embeddings"""
preprocessed_text = self.preprocess_text(class_labels)
return self._get_text_embeddings(preprocessed_text)
text_embeddings = self._get_text_embeddings(preprocessed_text)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
return text_embeddings
def get_audio_embeddings(self, audio_files, resample):
r"""Load list of audio files and return a audio embeddings"""
preprocessed_audio = self.preprocess_audio(audio_files, resample)
return self._get_audio_embeddings(preprocessed_audio)
audio_embeddings = self._get_audio_embeddings(preprocessed_audio)
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
return audio_embeddings
def _get_text_embeddings(self, preprocessed_text):
r"""Load preprocessed text and return text embeddings"""
with torch.no_grad():
return self.clap.caption_encoder(preprocessed_text)
text_embeddings = self.clap.caption_encoder(preprocessed_text)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
return text_embeddings
def _get_audio_embeddings(self, preprocessed_audio):
r"""Load preprocessed audio and return a audio embeddings"""
@@ -186,7 +191,15 @@ class CLAPWrapper():
preprocessed_audio = preprocessed_audio.reshape(
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
#Append [0] the audio emebdding, [1] has output class probabilities
return self.clap.audio_encoder(preprocessed_audio)[0]
audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
return audio_embeddings
def compute_similarity(self, audio_embeddings, text_embeddings):
r"""Compute similarity between text and audio embeddings"""
logit_scale = self.clap.logit_scale.exp()
similarity = logit_scale*text_embeddings @ audio_embeddings.T
return similarity.T
def _generic_batch_inference(self, func, *args):
r"""Process audio and/or text per batch"""
@@ -219,15 +232,6 @@ class CLAPWrapper():
r"""Load preprocessed text and return text embeddings per batch"""
return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
def compute_similarity(self, audio_embeddings, text_embeddings):
r"""Compute similarity between text and audio embeddings"""
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
logit_scale = self.clap.logit_scale.exp()
similarity = logit_scale*text_embeddings @ audio_embeddings.T
return similarity.T
def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
r"""Compute classification probabilities for each audio recording in a batch and each class label"""
return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)