adding examples and minor fixes
This commit is contained in:
+25
-21
@@ -36,6 +36,11 @@ class CLAPWrapper():
|
||||
|
||||
args = read_config_as_args(self.config_as_str, is_config_str=True)
|
||||
|
||||
if 'bert' in args.text_model:
|
||||
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
|
||||
else:
|
||||
self.token_keys = ['input_ids', 'attention_mask']
|
||||
|
||||
clap = CLAP(
|
||||
audioenc_name=args.audioenc_name,
|
||||
sample_rate=args.sampling_rate,
|
||||
@@ -154,31 +159,31 @@ class CLAPWrapper():
|
||||
for ttext in text_queries:
|
||||
tok = self.tokenizer.encode_plus(
|
||||
text=ttext, add_special_tokens=True, max_length=self.args.text_len, pad_to_max_length=True, return_tensors="pt")
|
||||
tok['input_ids'] = tok['input_ids'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['input_ids'].reshape(-1)
|
||||
tok['token_type_ids'] = tok['token_type_ids'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['token_type_ids'].reshape(-1)
|
||||
tok['token_type_ids'] = tok['token_type_ids'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['token_type_ids'].reshape(-1)
|
||||
tok['attention_mask'] = tok['attention_mask'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['attention_mask'].reshape(-1)
|
||||
for key in self.token_keys:
|
||||
tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
|
||||
tokenized_texts.append(tok)
|
||||
return self.default_collate(tokenized_texts)
|
||||
|
||||
def get_text_embeddings(self, class_labels):
|
||||
r"""Load list of class labels and return text embeddings"""
|
||||
preprocessed_text = self.preprocess_text(class_labels)
|
||||
return self._get_text_embeddings(preprocessed_text)
|
||||
text_embeddings = self._get_text_embeddings(preprocessed_text)
|
||||
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
||||
return text_embeddings
|
||||
|
||||
def get_audio_embeddings(self, audio_files, resample):
|
||||
r"""Load list of audio files and return audio embeddings"""
|
||||
preprocessed_audio = self.preprocess_audio(audio_files, resample)
|
||||
return self._get_audio_embeddings(preprocessed_audio)
|
||||
audio_embeddings = self._get_audio_embeddings(preprocessed_audio)
|
||||
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
||||
return audio_embeddings
|
||||
|
||||
def _get_text_embeddings(self, preprocessed_text):
|
||||
r"""Load preprocessed text and return text embeddings"""
|
||||
with torch.no_grad():
|
||||
return self.clap.caption_encoder(preprocessed_text)
|
||||
text_embeddings = self.clap.caption_encoder(preprocessed_text)
|
||||
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
||||
return text_embeddings
|
||||
|
||||
def _get_audio_embeddings(self, preprocessed_audio):
|
||||
r"""Load preprocessed audio and return audio embeddings"""
|
||||
@@ -186,7 +191,15 @@ class CLAPWrapper():
|
||||
preprocessed_audio = preprocessed_audio.reshape(
|
||||
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
|
||||
# Take [0], the audio embedding; [1] has the output class probabilities
|
||||
return self.clap.audio_encoder(preprocessed_audio)[0]
|
||||
audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
|
||||
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
||||
return audio_embeddings
|
||||
|
||||
def compute_similarity(self, audio_embeddings, text_embeddings):
|
||||
r"""Compute similarity between text and audio embeddings"""
|
||||
logit_scale = self.clap.logit_scale.exp()
|
||||
similarity = logit_scale*text_embeddings @ audio_embeddings.T
|
||||
return similarity.T
|
||||
|
||||
def _generic_batch_inference(self, func, *args):
|
||||
r"""Process audio and/or text per batch"""
|
||||
@@ -219,15 +232,6 @@ class CLAPWrapper():
|
||||
r"""Load preprocessed text and return text embeddings per batch"""
|
||||
return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
|
||||
|
||||
def compute_similarity(self, audio_embeddings, text_embeddings):
|
||||
r"""Compute similarity between text and audio embeddings"""
|
||||
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
||||
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
||||
|
||||
logit_scale = self.clap.logit_scale.exp()
|
||||
similarity = logit_scale*text_embeddings @ audio_embeddings.T
|
||||
return similarity.T
|
||||
|
||||
def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
|
||||
r"""Compute classification probabilities for each audio recording in a batch and each class label"""
|
||||
return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
|
||||
|
||||
Reference in New Issue
Block a user