diff --git a/esc50.py b/esc50.py
index e69de29..d9c08a3 100644
--- a/esc50.py
+++ b/esc50.py
@@ -0,0 +1,82 @@
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+from tqdm import tqdm
+import pandas as pd
+import os
+import torch.nn as nn
+import torch
+
+
+class AudioDataset(Dataset):
+    def __init__(self, root: str, download: bool = True):
+        self.root = os.path.expanduser(root)
+        if download:
+            self.download()
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def download(self):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+class ESC50(AudioDataset):
+    base_folder = 'ESC-50-master'
+    url = "https://github.com/karolpiczak/ESC-50/archive/refs/heads/master.zip"
+    filename = "ESC-50-master.zip"
+    num_files_in_dir = 2000
+    audio_dir = 'audio'
+    label_col = 'category'
+    file_col = 'filename'
+    meta = {
+        'filename': os.path.join('meta', 'esc50.csv'),
+    }
+
+    def __init__(self, root, reading_transformations: nn.Module = None, download: bool = True):
+        # Forward the download flag so ESC50(root, download=False) actually skips the download.
+        super().__init__(root, download)
+        self._load_meta()
+
+        self.targets, self.audio_paths = [], []
+        self.pre_transformations = reading_transformations
+        print("Loading audio files")
+        self.df[self.label_col] = self.df[self.label_col].str.replace('_', ' ')
+
+        for _, row in tqdm(self.df.iterrows()):
+            file_path = os.path.join(self.root, self.base_folder, self.audio_dir, row[self.file_col])
+            self.targets.append(row[self.label_col])
+            self.audio_paths.append(file_path)
+
+    def _load_meta(self):
+        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
+
+        self.df = pd.read_csv(path)
+        self.class_to_idx = {}
+        self.classes = [x.replace('_', ' ') for x in sorted(self.df[self.label_col].unique())]
+        for i, category in enumerate(self.classes):
+            self.class_to_idx[category] = i
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (file_path, target, one_hot_target) where target is the class name.
+        """
+        file_path, target = self.audio_paths[index], self.targets[index]
+        idx = torch.tensor(self.class_to_idx[target])
+        one_hot_target = torch.zeros(len(self.classes)).scatter_(0, idx, 1).reshape(1, -1)
+        return file_path, target, one_hot_target
+
+    def __len__(self):
+        return len(self.audio_paths)
+
+    def download(self):
+        download_url(self.url, self.root, self.filename)
+
+        # extract file
+        from zipfile import ZipFile
+        with ZipFile(os.path.join(self.root, self.filename), 'r') as zip_file:
+            zip_file.extractall(path=self.root)
\ No newline at end of file
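
Before the evaluation script below, here is a minimal sanity-check sketch of how this dataset class is meant to be used. It is an illustration, not part of the diff, and assumes the archive has already been extracted under ./data (pass download=True otherwise):

    from esc50 import ESC50

    dataset = ESC50(root='data', download=False)
    print(len(dataset))          # 2000 clips
    print(len(dataset.classes))  # 50 categories, underscores replaced with spaces

    # Each item is (file_path, class_name, one-hot tensor of shape (1, 50)).
    path, label, one_hot = dataset[0]
    print(path, label, one_hot.shape)
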
Total time: {}".format(time.time() - start_time)) + +esc50_dataset = ESC50(root='data', download=False) +x, target = esc50_dataset[1000] +x = [x] +y = esc50_dataset.classes + +print('Computing text embeddings') +text_embeddings = clap_model.get_text_embeddings(y) +print('Computing audio embeddings') +audio_embeddings = clap_model.get_audio_embeddings(x, resample=True) +similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings) + +similarity = F.softmax(similarity, dim=1) +values, indices = similarity[0].topk(5) +# Print the result +print("Ground Truth: {}".format(target)) +print("Top predictions:\n") +for value, index in zip(values, indices): + print(f"{y[index]:>16s}: {100 * value.item():.2f}%") \ No newline at end of file