diff --git a/requirements.txt b/requirements.txt
index 08920b1..6d0a036 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -47,6 +47,7 @@ SoundFile==0.10.3.post1
 subprocess32==3.5.4
 threadpoolctl==2.1.0
 tokenizers==0.10.2
+--find-links https://download.pytorch.org/whl/torch_stable.html
 torch==1.8.1+cu111
 torchaudio==0.8.1
 torchlibrosa==0.0.9
diff --git a/src/CLAPWrapper.py b/src/CLAPWrapper.py
index 6900c3d..d69d3bf 100644
--- a/src/CLAPWrapper.py
+++ b/src/CLAPWrapper.py
@@ -6,8 +6,8 @@ import re
 import torch.nn.functional as F
 import numpy as np
 from transformers import AutoTokenizer
-from .models.utils import read_config_as_args
-from .models.clap import CLAP
+from models.utils import read_config_as_args
+from models.clap import CLAP
 import math
 import torchaudio.transforms as T
 import os
@@ -26,7 +26,7 @@ class CLAPWrapper():
         self.default_collate_err_msg_format = (
             "default_collate: batch must contain tensors, numpy arrays, numbers, "
             "dicts or lists; found {}")
-        self.config_as_str = files('CLAP_API.configs').joinpath('config.yml').read_text()
+        self.config_as_str = files('configs').joinpath('config.yml').read_text()
         self.model_fp = model_fp
         self.use_cuda = use_cuda
         self.clap, self.tokenizer, self.args = self.load_clap()
diff --git a/src/__init__.py b/src/__init__.py
index bd3b45b..e69de29 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -1 +0,0 @@
-from .CLAPWrapper import CLAPWrapper as CLAP
\ No newline at end of file
diff --git a/src/examples/esc50_dataset.py b/src/esc50_dataset.py
similarity index 100%
rename from src/examples/esc50_dataset.py
rename to src/esc50_dataset.py
diff --git a/src/examples/zero_shot_classification.py b/src/zero_shot_classification.py
similarity index 82%
rename from src/examples/zero_shot_classification.py
rename to src/zero_shot_classification.py
index 130a1ef..098184d 100644
--- a/src/examples/zero_shot_classification.py
+++ b/src/zero_shot_classification.py
@@ -3,7 +3,7 @@ This is an example using CLAP to perform zeroshot
     classification on ESC50 (https://github.com/karolpiczak/ESC-50).
 """
-from src.CLAPWrapper import CLAP
+from CLAPWrapper import CLAPWrapper
 from esc50_dataset import ESC50
 import torch.nn.functional as F
 import numpy as np
@@ -11,14 +11,14 @@ from tqdm import tqdm
 from sklearn.metrics import accuracy_score

 # Load dataset
-dataset = ESC50(root='data', download=False)
+dataset = ESC50(root="C:\\Users\\benjaminm\\Datasets", download=False)
 prompt = 'this is a sound of '
 y = [prompt + x for x in dataset.classes]

 # Load and initialize CLAP
-weights_path = ''
-clap_model = CLAP(weights_path, use_cuda=False)
+weights_path = "C:\\Users\\benjaminm\\OneDrive - Microsoft\\CLAP_shared\\CLAP_models\\best.pth"
+clap_model = CLAPWrapper(weights_path, use_cuda=False)

 # Computing text embeddings
diff --git a/src/examples/zero_shot_predictions.py b/src/zero_shot_predictions.py
similarity index 81%
rename from src/examples/zero_shot_predictions.py
rename to src/zero_shot_predictions.py
index 4333f51..e8297b5 100644
--- a/src/examples/zero_shot_predictions.py
+++ b/src/zero_shot_predictions.py
@@ -3,22 +3,22 @@ This is an example using CLAP for zero-shot
     inference using ESC50 (https://github.com/karolpiczak/ESC-50).
 """
-from src.CLAPWrapper import CLAP
+from CLAPWrapper import CLAPWrapper
 from esc50_dataset import ESC50
 import torch.nn.functional as F

 # Load ESC50 dataset
-dataset = ESC50(root='data', download=True) # set download=True when dataset is not downloaded
+dataset = ESC50(root="C:\\Users\\benjaminm\\Datasets", download=True) # set download=True when dataset is not downloaded
 audio_file, target, one_hot_target = dataset[1000]
 audio_file = [audio_file]
 prompt = 'this is a sound of '
 y = [prompt + x for x in dataset.classes]

 # Load and initialize CLAP
-weights_path = ''
+weights_path = "C:\\Users\\benjaminm\\OneDrive - Microsoft\\CLAP_shared\\CLAP_models\\best.pth"
 # Setting use_cuda = True will load the model on a GPU using CUDA
-clap_model = CLAP(weights_path, use_cuda=False)
+clap_model = CLAPWrapper(weights_path, use_cuda=False)

 # compute text embeddings from natural text
 text_embeddings = clap_model.get_text_embeddings(y)