adding examples and minor fixes
This commit is contained in:
+25
-21
@@ -36,6 +36,11 @@ class CLAPWrapper():
|
||||
|
||||
args = read_config_as_args(self.config_as_str, is_config_str=True)
|
||||
|
||||
if 'bert' in args.text_model:
|
||||
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
|
||||
else:
|
||||
self.token_keys = ['input_ids', 'attention_mask']
|
||||
|
||||
clap = CLAP(
|
||||
audioenc_name=args.audioenc_name,
|
||||
sample_rate=args.sampling_rate,
|
||||
@@ -154,31 +159,31 @@ class CLAPWrapper():
|
||||
for ttext in text_queries:
|
||||
tok = self.tokenizer.encode_plus(
|
||||
text=ttext, add_special_tokens=True, max_length=self.args.text_len, pad_to_max_length=True, return_tensors="pt")
|
||||
tok['input_ids'] = tok['input_ids'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['input_ids'].reshape(-1)
|
||||
tok['token_type_ids'] = tok['token_type_ids'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['token_type_ids'].reshape(-1)
|
||||
tok['token_type_ids'] = tok['token_type_ids'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['token_type_ids'].reshape(-1)
|
||||
tok['attention_mask'] = tok['attention_mask'].reshape(-1).cuda(
|
||||
) if self.use_cuda and torch.cuda.is_available() else tok['attention_mask'].reshape(-1)
|
||||
for key in self.token_keys:
|
||||
tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
|
||||
tokenized_texts.append(tok)
|
||||
return self.default_collate(tokenized_texts)
|
||||
|
||||
def get_text_embeddings(self, class_labels):
|
||||
r"""Load list of class labels and return text embeddings"""
|
||||
preprocessed_text = self.preprocess_text(class_labels)
|
||||
return self._get_text_embeddings(preprocessed_text)
|
||||
text_embeddings = self._get_text_embeddings(preprocessed_text)
|
||||
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
||||
return text_embeddings
|
||||
|
||||
def get_audio_embeddings(self, audio_files, resample):
|
||||
r"""Load list of audio files and return a audio embeddings"""
|
||||
preprocessed_audio = self.preprocess_audio(audio_files, resample)
|
||||
return self._get_audio_embeddings(preprocessed_audio)
|
||||
audio_embeddings = self._get_audio_embeddings(preprocessed_audio)
|
||||
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
||||
return audio_embeddings
|
||||
|
||||
def _get_text_embeddings(self, preprocessed_text):
|
||||
r"""Load preprocessed text and return text embeddings"""
|
||||
with torch.no_grad():
|
||||
return self.clap.caption_encoder(preprocessed_text)
|
||||
text_embeddings = self.clap.caption_encoder(preprocessed_text)
|
||||
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
||||
return text_embeddings
|
||||
|
||||
def _get_audio_embeddings(self, preprocessed_audio):
|
||||
r"""Load preprocessed audio and return a audio embeddings"""
|
||||
@@ -186,7 +191,15 @@ class CLAPWrapper():
|
||||
preprocessed_audio = preprocessed_audio.reshape(
|
||||
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
|
||||
#Append [0] the audio emebdding, [1] has output class probabilities
|
||||
return self.clap.audio_encoder(preprocessed_audio)[0]
|
||||
audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
|
||||
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
||||
return audio_embeddings
|
||||
|
||||
def compute_similarity(self, audio_embeddings, text_embeddings):
|
||||
r"""Compute similarity between text and audio embeddings"""
|
||||
logit_scale = self.clap.logit_scale.exp()
|
||||
similarity = logit_scale*text_embeddings @ audio_embeddings.T
|
||||
return similarity.T
|
||||
|
||||
def _generic_batch_inference(self, func, *args):
|
||||
r"""Process audio and/or text per batch"""
|
||||
@@ -219,15 +232,6 @@ class CLAPWrapper():
|
||||
r"""Load preprocessed text and return text embeddings per batch"""
|
||||
return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
|
||||
|
||||
def compute_similarity(self, audio_embeddings, text_embeddings):
|
||||
r"""Compute similarity between text and audio embeddings"""
|
||||
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
||||
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
||||
|
||||
logit_scale = self.clap.logit_scale.exp()
|
||||
similarity = logit_scale*text_embeddings @ audio_embeddings.T
|
||||
return similarity.T
|
||||
|
||||
def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
|
||||
r"""Compute classification probabilities for each audio recording in a batch and each class label"""
|
||||
return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
|
||||
|
||||
@@ -15,10 +15,8 @@ https://arxiv.org/pdf/2206.04769.pdf
|
||||
}
|
||||
```
|
||||
|
||||
## Request CLAP weights:
|
||||
```
|
||||
https://forms.office.com/r/ULb4k9GL1F
|
||||
```
|
||||
## CLAP weights:
|
||||
Request CLAP weights by filling this form: [link](https://forms.office.com/r/ULb4k9GL1F)
|
||||
|
||||
|
||||
### Usage
|
||||
@@ -31,27 +29,71 @@ clap_model = CLAP("<PATH TO WEIGHTS>", use_cuda=False)
|
||||
|
||||
- Extract text embeddings
|
||||
```python
|
||||
|
||||
text_embeddings = clap_model.get_text_embeddings(class_labels: List[str])
|
||||
```
|
||||
|
||||
- Extract audio embeddings
|
||||
```python
|
||||
|
||||
audio_embeddings = clap_model.get_audio_embeddings(file_paths: List[str])
|
||||
```
|
||||
|
||||
- Compute similarity
|
||||
```python
|
||||
# For using the below function, DO NOT normalize the text and audio embeddings
|
||||
sim = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
||||
```
|
||||
|
||||
### Zero-Shot inference on an audio file from [ESC50 dataset](https://github.com/karolpiczak/ESC-50)
|
||||
|
||||
```python
|
||||
from CLAP_API import CLAP
|
||||
from esc50_dataset import ESC50
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Load CLAP
|
||||
weights_path = 'best.pth' # Add weight path here
|
||||
clap_model = CLAP(weights_path, use_cuda=False)
|
||||
|
||||
# Load dataset
|
||||
dataset = ESC50(root='data', download=True) # set download=True when dataset is not downloaded
|
||||
audio_file, target, one_hot_target = dataset[1000]
|
||||
audio_file = [audio_file]
|
||||
prompt = 'this is a sound of '
|
||||
y = [prompt + x for x in dataset.classes]
|
||||
|
||||
print('Computing text embeddings')
|
||||
text_embeddings = clap_model.get_text_embeddings(y)
|
||||
print('Computing audio embeddings')
|
||||
audio_embeddings = clap_model.get_audio_embeddings(audio_file, resample=True)
|
||||
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
||||
|
||||
similarity = F.softmax(similarity, dim=1)
|
||||
values, indices = similarity[0].topk(5)
|
||||
# Print the result
|
||||
print("Ground Truth: {}".format(target))
|
||||
print("Top predictions:\n")
|
||||
for value, index in zip(values, indices):
|
||||
print(f"{dataset.classes[index]:>16s}: {100 * value.item():.2f}%")
|
||||
```
|
||||
|
||||
The output (the exact numbers may vary):
|
||||
|
||||
```
|
||||
Ground Truth: coughing
|
||||
Top predictions:
|
||||
|
||||
coughing: 86.34%
|
||||
sneezing: 9.30%
|
||||
drinking sipping: 1.31%
|
||||
laughing: 1.20%
|
||||
glass breaking: 0.81%
|
||||
```
|
||||
|
||||
### Zero-Shot Classification of [ESC50 dataset](https://github.com/karolpiczak/ESC-50)
|
||||
|
||||
```python
|
||||
from CLAP_API import CLAP
|
||||
from esc50 import ESC50
|
||||
from esc50_dataset import ESC50
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
@@ -62,7 +104,7 @@ weights_path = # Add weight path here
|
||||
clap_model = CLAP(weights_path, use_cuda=False)
|
||||
|
||||
# Load dataset
|
||||
dataset = ESC50(root='path/ESC-50-master', download=False)
|
||||
dataset = ESC50(root='data', download=False)
|
||||
prompt = 'this is a sound of '
|
||||
Y = [prompt + x for x in dataset.classes]
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from CLAP_API import CLAP
|
||||
from esc50_dataset import ESC50
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
# Load CLAP
|
||||
weights_path = 'C:\\Users\\sdeshmukh\\Desktop\\CLAP_package\\model\\new\\best.pth' # Add weight path here
|
||||
clap_model = CLAP(weights_path, use_cuda=False)
|
||||
|
||||
# Load dataset
|
||||
dataset = ESC50(root='data', download=True)
|
||||
prompt = 'this is a sound of '
|
||||
Y = [prompt + x for x in dataset.classes]
|
||||
|
||||
# Computing text embeddings
|
||||
text_embeddings = clap_model.get_text_embeddings(Y)
|
||||
|
||||
# Computing audio embeddings
|
||||
y_preds, y_labels = [], []
|
||||
for i in tqdm(range(len(dataset))):
|
||||
x, _, one_hot_target = dataset.__getitem__(i)
|
||||
audio_embeddings = clap_model.get_audio_embeddings([x], resample=True)
|
||||
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
||||
y_pred = F.softmax(similarity.detach().cpu(), dim=1).numpy()
|
||||
y_preds.append(y_pred)
|
||||
y_labels.append(one_hot_target.detach().cpu().numpy())
|
||||
|
||||
y_labels, y_preds = np.concatenate(y_labels, axis=0), np.concatenate(y_preds, axis=0)
|
||||
acc = accuracy_score(np.argmax(y_labels, axis=1), np.argmax(y_preds, axis=1))
|
||||
print('ESC50 Accuracy {}'.format(acc))
|
||||
+11
-10
@@ -1,22 +1,23 @@
|
||||
from CLAP_API import CLAP
|
||||
from esc50 import ESC50
|
||||
from esc50_dataset import ESC50
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
|
||||
start_time = time.time()
|
||||
weights_path = 'C:\\Users\\sdeshmukh\\Desktop\\CLAP_package\\model\\new\\best.pth'
|
||||
# Load CLAP
|
||||
weights_path = 'C:\\Users\\sdeshmukh\\Desktop\\CLAP_package\\model\\new\\best.pth' # Add weight path here
|
||||
clap_model = CLAP(weights_path, use_cuda=False)
|
||||
print("Finished loading CLAP. Total time: {}".format(time.time() - start_time))
|
||||
|
||||
esc50_dataset = ESC50(root='data', download=False)
|
||||
x, target = esc50_dataset[1000]
|
||||
x = [x]
|
||||
y = esc50_dataset.classes
|
||||
# Load dataset
|
||||
dataset = ESC50(root='data', download=True) # set download=True when dataset is not downloaded
|
||||
audio_file, target, one_hot_target = dataset[1000]
|
||||
audio_file = [audio_file]
|
||||
prompt = 'this is a sound of '
|
||||
y = [prompt + x for x in dataset.classes]
|
||||
|
||||
print('Computing text embeddings')
|
||||
text_embeddings = clap_model.get_text_embeddings(y)
|
||||
print('Computing audio embeddings')
|
||||
audio_embeddings = clap_model.get_audio_embeddings(x, resample=True)
|
||||
audio_embeddings = clap_model.get_audio_embeddings(audio_file, resample=True)
|
||||
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
|
||||
|
||||
similarity = F.softmax(similarity, dim=1)
|
||||
@@ -25,4 +26,4 @@ values, indices = similarity[0].topk(5)
|
||||
print("Ground Truth: {}".format(target))
|
||||
print("Top predictions:\n")
|
||||
for value, index in zip(values, indices):
|
||||
print(f"{y[index]:>16s}: {100 * value.item():.2f}%")
|
||||
print(f"{dataset.classes[index]:>16s}: {100 * value.item():.2f}%")
|
||||
Reference in New Issue
Block a user