Update to 2023

Benjamin Elizalde
2023-09-26 14:37:20 -07:00
parent 9d4006b773
commit 1ea0905552
16 changed files with 1858 additions and 125 deletions
+48 -57
@@ -1,59 +1,50 @@
appdirs==1.4.4
audioread==2.1.9
certifi==2020.12.5
cffi==1.14.5
chardet==4.0.0
click==7.1.2
configparser==5.0.2
cycler==0.10.0
decorator==5.0.7
docker-pycreds==0.4.0
filelock==3.0.12
gitdb==4.0.7
GitPython==3.1.14
h5py==3.2.1
idna==2.10
audioread==3.0.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==3.0.1
colorama==0.4.6
decorator==5.1.1
filelock==3.9.0
flit_core==3.6.0
huggingface-hub==0.12.1
idna==3.4
importlib-metadata==6.0.0
importlib-resources==5.12.0
jaraco.classes==3.2.3
joblib==1.2.0
kiwisolver==1.3.1
librosa==0.8.0
llvmlite==0.36.0
matplotlib==3.4.1
numba==0.53.1
numpy==1.22.0
packaging==20.9
pandas==1.2.4
pathtools==0.1.2
Pillow==9.0.1
pooch==1.3.0
promise==2.3
protobuf==3.18.3
psutil==5.8.0
pycparser==2.20
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2021.1
PyYAML==5.4.1
regex==2021.4.4
requests==2.25.1
resampy==0.2.2
sacremoses==0.0.45
scikit-learn==0.24.2
scipy==1.6.3
sentry-sdk==1.0.0
shortuuid==1.0.1
six==1.15.0
smmap==4.0.0
SoundFile==0.10.3.post1
subprocess32==3.5.4
threadpoolctl==2.1.0
tokenizers==0.10.2
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.8.1+cu111
torchaudio==0.8.1
torchlibrosa==0.0.9
torchvision==0.9.1+cu111
tqdm==4.60.0
transformers==4.5.1
typing-extensions==3.10.0.0
urllib3==1.26.5
importlib-resources==5.10.0
lazy_loader==0.1
librosa==0.10.0
llvmlite==0.39.1
mkl-service==2.4.0
more-itertools==9.0.0
msgpack==1.0.4
numba==0.56.4
numpy==1.23.5
packaging==23.0
pandas==1.4.2
pooch==1.6.0
pycparser==2.21
pywin32-ctypes==0.2.0
PyYAML==6.0
regex==2022.10.31
requests==2.28.2
scikit-learn==1.2.1
scipy==1.10.1
setuptools==65.6.3
six==1.16.0
soundfile==0.12.1
soxr==0.3.3
threadpoolctl==3.1.0
tokenizers==0.13.2
torch==1.13.1
torchaudio==0.13.1
torchlibrosa==0.1.0
torchvision==0.14.1
tqdm==4.64.1
transformers==4.26.1
typing_extensions==4.4.0
urllib3==1.26.14
wheel==0.38.4
wincertstore==0.2
zipp==3.14.0
+203 -32
@@ -1,18 +1,23 @@
import warnings
warnings.filterwarnings("ignore")
import random
import torchaudio
from torch._six import string_classes
import collections
import re
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer
from models.utils import read_config_as_args
from transformers import AutoTokenizer, logging
from models.clap import CLAP
from models.mapper import get_clapcap
import math
import torchaudio.transforms as T
import os
import torch
from importlib_resources import files
import argparse
import yaml
import sys
logging.set_verbosity_error()
class CLAPWrapper():
@@ -20,26 +25,59 @@ class CLAPWrapper():
A class for interfacing with the CLAP model.
"""
def __init__(self, model_fp, use_cuda=False):
def __init__(self, model_fp, version, use_cuda=False):
self.supported_versions = ['2022', '2023', 'clapcap']
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
self.file_path = os.path.realpath(__file__)
self.default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
self.config_as_str = files('configs').joinpath('config.yml').read_text()
self.config_as_str = self.get_config_path(version)
self.model_fp = model_fp
self.use_cuda = use_cuda
self.clap, self.tokenizer, self.args = self.load_clap()
if 'clapcap' in version:
self.clapcap, self.tokenizer, self.args = self.load_clapcap()
else:
self.clap, self.tokenizer, self.args = self.load_clap()
def get_config_path(self, version):
if version in self.supported_versions:
return files('configs').joinpath(f"config_{version}.yml").read_text()
else:
raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
def read_config_as_args(self,config_path,args=None,is_config_str=False):
return_dict = {}
if config_path is not None:
if is_config_str:
yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
else:
with open(config_path, "r") as f:
yml_config = yaml.load(f, Loader=yaml.FullLoader)
if args != None:
for k, v in yml_config.items():
if k in args.__dict__:
args.__dict__[k] = v
else:
sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
else:
for k, v in yml_config.items():
return_dict[k] = v
args = args if args != None else return_dict
return argparse.Namespace(**args)
def load_clap(self):
r"""Load CLAP model with args from config file"""
args = read_config_as_args(self.config_as_str, is_config_str=True)
args = self.read_config_as_args(self.config_as_str, is_config_str=True)
if 'bert' in args.text_model:
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
else:
if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
self.token_keys = ['input_ids', 'attention_mask']
elif 'bert' in args.text_model:
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
clap = CLAP(
audioenc_name=args.audioenc_name,
@@ -58,15 +96,65 @@ class CLAPWrapper():
# Load pretrained weights for model
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
# We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`:
# Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
clap.load_state_dict(model_state_dict)
clap.eval() # set clap in eval mode
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
if 'gpt' in args.text_model:
tokenizer.add_special_tokens({'pad_token': '!'})
if self.use_cuda and torch.cuda.is_available():
clap = clap.cuda()
return clap, tokenizer, args
def load_clapcap(self):
r"""Load CLAP model with args from config file"""
args = self.read_config_as_args(self.config_as_str, is_config_str=True)
args.prefix_dim = args.d_proj
text_model = args.text_model
args.text_model = args.text_decoder
args.cross_attention = True if 'cross' in args.clapcap_model.lower() else False
if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
self.token_keys = ['input_ids', 'attention_mask']
elif 'bert' in args.text_model:
self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
clap = CLAP(
audioenc_name=args.audioenc_name,
sample_rate=args.sampling_rate,
window_size=args.window_size,
hop_size=args.hop_size,
mel_bins=args.mel_bins,
fmin=args.fmin,
fmax=args.fmax,
classes_num=args.num_classes,
out_emb=args.out_emb,
text_model=text_model,
transformer_embed_dim=args.transformer_embed_dim,
d_proj=args.d_proj
)
clapcap = get_clapcap(args.clapcap_model)(clap, args.text_decoder, args.prefix_length, args.prefix_length_clip, args.prefix_dim,
args.num_layers, args.normalize_prefix, args.mapping_type, True, True)
model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
clapcap.load_state_dict(model_state_dict)
clapcap.eval() # set clapcap in eval mode
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
if 'gpt' in args.text_model:
tokenizer.add_special_tokens({'pad_token': '!'})
if self.use_cuda and torch.cuda.is_available():
clapcap = clapcap.cuda()
return clapcap, tokenizer, args
def default_collate(self, batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
@@ -113,15 +201,22 @@ class CLAPWrapper():
return [self.default_collate(samples) for samples in transposed]
raise TypeError(self.default_collate_err_msg_format.format(elem_type))
def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
r"""Loads audio file and returns raw audio."""
def read_audio(self, audio_path, resample=False):
r"""Loads audio file or array and returns a torch tensor"""
# Randomly sample a segment of audio_duration from the clip or pad to match duration
audio_time_series, sample_rate = torchaudio.load(audio_path)
resample_rate = self.args.sampling_rate
if resample:
resampler = T.Resample(sample_rate, resample_rate)
audio_time_series = resampler(audio_time_series)
return audio_time_series, sample_rate
def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
r"""Loads audio file and returns raw audio."""
# Randomly sample a segment of audio_duration from the clip or pad to match duration
audio_time_series, sample_rate = self.read_audio(audio_path, resample=resample)
audio_time_series = audio_time_series.reshape(-1)
# audio_time_series is shorter than predefined audio duration,
@@ -157,8 +252,10 @@ class CLAPWrapper():
r"""Load list of class labels and return tokenized text"""
tokenized_texts = []
for ttext in text_queries:
if 'gpt' in self.args.text_model:
ttext = ttext + ' <|endoftext|>'
tok = self.tokenizer.encode_plus(
text=ttext, add_special_tokens=True, max_length=self.args.text_len, pad_to_max_length=True, return_tensors="pt")
text=ttext, add_special_tokens=True, max_length=self.args.text_len, padding='max_length', return_tensors="pt")
for key in self.token_keys:
tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
tokenized_texts.append(tok)
@@ -167,23 +264,17 @@ class CLAPWrapper():
def get_text_embeddings(self, class_labels):
r"""Load list of class labels and return text embeddings"""
preprocessed_text = self.preprocess_text(class_labels)
text_embeddings = self._get_text_embeddings(preprocessed_text)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
return text_embeddings
return self._get_text_embeddings(preprocessed_text)
def get_audio_embeddings(self, audio_files, resample):
r"""Load list of audio files and return a audio embeddings"""
preprocessed_audio = self.preprocess_audio(audio_files, resample)
audio_embeddings = self._get_audio_embeddings(preprocessed_audio)
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
return audio_embeddings
return self._get_audio_embeddings(preprocessed_audio)
def _get_text_embeddings(self, preprocessed_text):
r"""Load preprocessed text and return text embeddings"""
with torch.no_grad():
text_embeddings = self.clap.caption_encoder(preprocessed_text)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
return text_embeddings
return self.clap.caption_encoder(preprocessed_text)
def _get_audio_embeddings(self, preprocessed_audio):
r"""Load preprocessed audio and return a audio embeddings"""
@@ -191,15 +282,7 @@ class CLAPWrapper():
preprocessed_audio = preprocessed_audio.reshape(
preprocessed_audio.shape[0], preprocessed_audio.shape[2])
# Index [0] is the audio embedding; [1] has the output class probabilities
audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
return audio_embeddings
def compute_similarity(self, audio_embeddings, text_embeddings):
r"""Compute similarity between text and audio embeddings"""
logit_scale = self.clap.logit_scale.exp()
similarity = logit_scale*text_embeddings @ audio_embeddings.T
return similarity.T
return self.clap.audio_encoder(preprocessed_audio)[0]
def _generic_batch_inference(self, func, *args):
r"""Process audio and/or text per batch"""
@@ -232,6 +315,94 @@ class CLAPWrapper():
r"""Load preprocessed text and return text embeddings per batch"""
return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
def compute_similarity(self, audio_embeddings, text_embeddings):
r"""Compute similarity between text and audio embeddings"""
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
logit_scale = self.clap.logit_scale.exp()
similarity = logit_scale*text_embeddings @ audio_embeddings.T
return similarity.T
def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
r"""Compute classification probabilities for each audio recording in a batch and each class label"""
return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
r"""Generate audio captions for each audio recording in a batch"""
captions = []
audio_tensors = self.preprocess_audio(audio_files, resample)
with torch.no_grad():
prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
if self.args.normalize_prefix:
prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
prefix_embed = self.clapcap.clap_project(prefix).view(-1, self.args.prefix_length, self.clapcap.gpt.transformer.wte.weight.shape[1])
for i in range(len(audio_tensors)):
gen_caption = self._generate_beam(embed=prefix_embed[i].unsqueeze(0),\
beam_size=beam_size,\
entry_length=entry_length,\
temperature=temperature)[0]
captions.append(gen_caption.capitalize())
return captions
def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
r"""Generate captions by beam search decoding"""
self.clapcap.eval()
stop_token_index = self.tokenizer.encode(stop_token)[0]
tokens = None
scores = None
device = next(self.clapcap.parameters()).device
seq_lengths = torch.ones(beam_size, device=device)
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
with torch.no_grad():
if embed is not None:
generated = embed
else:
if tokens is None:
tokens = torch.tensor(self.tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = self.clapcap.gpt.transformer.wte(tokens)
for i in range(entry_length):
outputs = self.clapcap.gpt(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
logits = logits.softmax(-1).log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
generated = generated.expand(beam_size, *generated.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if tokens is None:
tokens = next_tokens
else:
tokens = tokens.expand(beam_size, *tokens.shape[1:])
tokens = torch.cat((tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
tokens = tokens[next_tokens_source]
tokens = torch.cat((tokens, next_tokens), dim=1)
generated = generated[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
generated = torch.cat((generated, next_token_embed), dim=1)
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return output_texts
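A minimal zero-shot classification sketch against the updated wrapper; the weights path, audio file, and class labels below are placeholders:
from CLAPWrapper import CLAPWrapper
import torch.nn.functional as F
# Load the 2023 weights ("weights_path" and "audio_file" are placeholders)
clap_model = CLAPWrapper("weights_path", version='2023', use_cuda=False)
class_labels = ['dog barking', 'rain', 'siren']
text_embeddings = clap_model.get_text_embeddings(class_labels)
audio_embeddings = clap_model.get_audio_embeddings(['audio_file'], resample=True)
# compute_similarity normalizes the embeddings and applies the learned logit scale
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
probs = F.softmax(similarity, dim=-1)  # one row of class probabilities per audio file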
+25
@@ -0,0 +1,25 @@
"""
This is an example using CLAPCAP for audio captioning.
"""
from CLAPWrapper import CLAPWrapper
# Load and initialize CLAP
weights_path = "weights_path"
clap_model = CLAPWrapper(weights_path, version = 'clapcap', use_cuda=False)
#Load audio files
audio_files = ['audio_file']
# Generate captions for the recording
captions = clap_model.generate_caption(audio_files, resample=True, beam_size=5, entry_length=67, temperature=0.01)
# Print the result
for i in range(len(audio_files)):
print(f"Audio file: {audio_files[i]} \n")
print(f"Generated caption: {captions[i]} \n")
"""
The output (the exact caption may vary):
The birds are singing in the trees.
"""
+26
@@ -0,0 +1,26 @@
# TEXT ENCODER CONFIG
text_model: 'gpt2'
text_len: 77
transformer_embed_dim: 768
freeze_text_encoder_weights: True
# AUDIO ENCODER CONFIG
audioenc_name: 'HTSAT'
out_emb: 768
sampling_rate: 44100
duration: 7
fmin: 50
fmax: 8000 #14000
n_fft: 1024 # 1028
hop_size: 320
mel_bins: 64
window_size: 1024
# PROJECTION SPACE CONFIG
d_proj: 1024
temperature: 0.003
# TRAINING AND EVALUATION CONFIG
num_classes: 527
batch_size: 1024
demo: False
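A minimal sketch of how read_config_as_args above turns this YAML into the args namespace; config_text is a shortened stand-in for the text returned by get_config_path('2023'):
import yaml
import argparse
config_text = "text_model: 'gpt2'\nsampling_rate: 44100\nduration: 7"
args = argparse.Namespace(**yaml.load(config_text, Loader=yaml.FullLoader))
print(args.text_model, args.sampling_rate * args.duration)  # 'gpt2' and samples per clip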
+34
@@ -0,0 +1,34 @@
# TEXT ENCODER CONFIG
text_model: 'gpt2'
transformer_embed_dim: 768
freeze_text_encoder_weights: True
# AUDIO ENCODER CONFIG
audioenc_name: 'HTSAT'
out_emb: 768
sampling_rate: 44100
duration: 7
fmin: 50
fmax: 8000
n_fft: 1024
hop_size: 320
mel_bins: 64
window_size: 1024
# PROJECTION SPACE CONFIG
d_proj: 1024
temperature: 0.003
# TRAINING AND EVALUATION CONFIG
batch_size: 128
num_classes: 527
# CLAPCAP CONFIG
clapcap_model: 'ClapCaption'
text_decoder: 'gpt2'
prefix_length: 40
prefix_length_clip: 40
mapping_type: 'transformer'
num_layers: 8
normalize_prefix: True
freeze_gpt_weights: True
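A shape-only sketch of the caption prefix implied by this config: the d_proj=1024 audio embedding is mapped to prefix_length=40 GPT-2 token embeddings of width 768. The single Linear layer below is only a stand-in for the transformer mapper selected by mapping_type:
import torch
import torch.nn as nn
d_proj, prefix_length, gpt_embed = 1024, 40, 768
clap_project = nn.Linear(d_proj, prefix_length * gpt_embed)  # stand-in for the real mapper
audio_embedding = torch.randn(2, d_proj)  # batch of 2 CLAP audio embeddings
prefix_embed = clap_project(audio_embedding).view(-1, prefix_length, gpt_embed)
print(prefix_embed.shape)  # torch.Size([2, 40, 768])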
+2 -2
@@ -24,7 +24,7 @@ class AudioDataset(Dataset):
class ESC50(AudioDataset):
base_folder = 'ESC-50-master'
url = "https://github.com/karolpiczak/ESC-50/archive/refs/heads/master.zip"
url = "https://github.com/karoldvl/ESC-50/archive/master.zip"
filename = "ESC-50-master.zip"
num_files_in_dir = 2000
audio_dir = 'audio'
@@ -79,4 +79,4 @@ class ESC50(AudioDataset):
# extract file
from zipfile import ZipFile
with ZipFile(os.path.join(self.root, self.filename), 'r') as zip:
zip.extractall(path=self.root)
zip.extractall(path=self.root)
+4 -1
@@ -1,3 +1,6 @@
from . import clap
from . import audio
from . import utils
from . import htsat
from . import config
from . import pytorch_utils
from . import htsat
+3
@@ -2,10 +2,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from models.htsat import HTSATWrapper
def get_audio_encoder(name: str):
if name == "Cnn14":
return Cnn14
elif name == "HTSAT":
return HTSATWrapper
else:
raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
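A minimal sketch of the factory above, which resolves a config name to an encoder class (assuming this file is models/audio.py, as the package imports suggest):
from models.audio import get_audio_encoder
encoder_cls = get_audio_encoder("HTSAT")  # returns HTSATWrapper; "Cnn14" is the other supported name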
+21 -2
@@ -42,14 +42,33 @@ class AudioEncoder(nn.Module):
class TextEncoder(nn.Module):
def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
super().__init__()
self.text_model = text_model
self.base = AutoModel.from_pretrained(text_model)
if 'clip' in text_model:
self.clip_text_projection = self.base.text_projection
self.base = self.base.text_model
if 'base' in text_model:
transformer_embed_dim = 512
self.projection = Projection(transformer_embed_dim, d_out)
def forward(self, x):
out = self.base(**x)[0]
out = out[:, 0, :] # get CLS token output
if 'clip' in self.text_model:
pooled_output = self.base(**x)[1] # get pooled output
out = self.clip_text_projection(pooled_output) # get CLS token output
elif 'gpt' in self.text_model:
batch_size = x['input_ids'].shape[0]
hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)
sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])
out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]
else:
out = self.base(**x)[0]
out = out[:, 0, :] # get CLS token output
projected_vec = self.projection(out)
return projected_vec
class CLAP(nn.Module):
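A minimal sketch of the GPT-2 pooling added above: the hidden state of the last non-padding token is taken as the sentence embedding, assuming pad id 0 (the wrapper registers '!' as the pad token, which GPT-2 maps to id 0):
import torch
hidden_states = torch.randn(4, 20, 768)  # (batch, seq_len, embed_dim) from the GPT-2 backbone
input_ids = torch.randint(1, 50000, (4, 20))
input_ids[:, 15:] = 0  # simulate right padding
sequence_lengths = torch.ne(input_ids, 0).sum(-1) - 1  # index of the last non-pad token per sample
pooled = hidden_states[torch.arange(4), sequence_lengths]  # (batch, embed_dim)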
+128
@@ -0,0 +1,128 @@
# Ke Chen
# knutchen@ucsd.edu
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# The configuration for training the model
exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
workspace = "/home/kechen/Research/HTSAT" # the folder of your code
dataset_path = "/home/Research/audioset" # the dataset path
desed_folder = "/home/Research/DESED" # the desed file
dataset_type = "audioset" # "audioset" "esc-50" "scv2"
index_type = "full_train" # only works for audioset
balanced_data = True # only works for audioset
loss_type = "clip_bce" #
# AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
# trained from a checkpoint, or evaluate a single model
resume_checkpoint = None
# "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
debug = False
random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
learning_rate = 1e-3 # 1e-4 also workable
max_epoch = 100
num_workers = 3
lr_scheduler_epoch = [10,20,30]
lr_rate = [0.02, 0.05, 0.1]
# these data preparation optimizations do not bring many improvements, so deprecated
enable_token_label = False # token label
class_map_path = "class_hier_map.npy"
class_filter = None
retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
token_label_range = [0.2,0.6]
enable_time_shift = False # shift time
enable_label_enhance = False # enhance hierarchical label
enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
# for model's design
enable_tscam = True # enable the token-semantic layer
# for signal processing
sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
clip_samples = sample_rate * 10 # audio_set 10-sec clip
window_size = 1024
hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
mel_bins = 64
fmin = 50
fmax = 14000
shift_max = int(clip_samples * 0.5)
# for data collection
classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
patch_size = (25, 4) # deprecated
crop_size = None # int(clip_samples * 0.5) deprecated
# for htsat hyperparameters
htsat_window_size = 8
htsat_spec_size = 256
htsat_patch_size = 4
htsat_stride = (4, 4)
htsat_num_head = [4,8,16,32]
htsat_dim = 96
htsat_depth = [2,2,6,2]
swin_pretrain_path = None
# "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
# Some Deprecated Optimization in the model design, check the model code for details
htsat_attn_heatmap = False
htsat_hier_output = False
htsat_use_max = False
# for ensemble test
ensemble_checkpoints = []
ensemble_strides = []
# weight average folder
wa_folder = "/home/version_0/checkpoints/"
# weight average output filename
wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
esm_model_pathes = [
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
]
# for framewise localization
heatmap_dir = "/home/Research/heatmap_output"
test_file = "htsat-test-ensemble"
fl_local = False # indicate if we need to use this dataset for the framewise detection
fl_dataset = "/home/Research/desed/desed_eval.npy"
fl_class_num = [
"Speech", "Frying", "Dishes", "Running_water",
"Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
"Cat", "Dog", "Vacuum_cleaner"
]
# map 527 classes into 10 classes
fl_audioset_mapping = [
[0,1,2,3,4,5,6,7],
[366, 367, 368],
[364],
[288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
[369],
[382],
[310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
[81, 82, 83, 84, 85],
[74, 75, 76, 77, 78, 79],
[377]
]
+950
@@ -0,0 +1,950 @@
# Ke Chen
# knutchen@ucsd.edu
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# Model Core
# the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
import logging
import pdb
import math
import random
from numpy.core.fromnumeric import clip, reshape
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from itertools import repeat
from typing import List
try:
from models.pytorch_utils import do_mixup, interpolate
import models.config as config
except:
from CLAP_API.models.pytorch_utils import do_mixup, interpolate
from CLAP_API.models import config
import torch.nn.functional as F
import collections.abc
import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patch_stride = to_2tuple(patch_stride)
self.img_size = img_size
self.patch_size = patch_size
self.patch_stride = patch_stride
self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.in_chans = in_chans
self.embed_dim = embed_dim
padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in':
denom = fan_in
elif mode == 'fan_out':
denom = fan_out
elif mode == 'fan_avg':
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
# the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
def extra_repr(self):
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
# The model is built on the Swin Transformer block, so the Swin Transformer pretrained weights can be reused
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm_before_mlp = norm_before_mlp
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if self.norm_before_mlp == 'ln':
self.norm2 = nn.LayerNorm(dim)
elif self.norm_before_mlp == 'bn':
self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
else:
raise NotImplementedError
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
# pdb.set_trace()
H, W = self.input_resolution
# print("H: ", H)
# print("W: ", W)
# pdb.set_trace()
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn
def extra_repr(self):
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self):
return f"input_resolution={self.input_resolution}, dim={self.dim}"
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
norm_before_mlp='ln'):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
attns = []
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x, attn = blk(x)
if not self.training:
attns.append(attn.unsqueeze(0))
if self.downsample is not None:
x = self.downsample(x)
if not self.training:
attn = torch.cat(attns, dim = 0)
attn = torch.mean(attn, dim = 0)
return x, attn
def extra_repr(self):
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
# The Core of HTSAT
class HTSAT_Swin_Transformer(nn.Module):
r"""HTSAT based on the Swin Transformer
Args:
spec_size (int | tuple(int)): Input Spectrogram size. Default 256
patch_size (int | tuple(int)): Patch size. Default: 4
patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
in_chans (int): Number of input image channels. Default: 1 (mono)
num_classes (int): Number of classes for classification head. Default: 527
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 8
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
config (module): The configuration Module from config.py
"""
def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
in_chans=1, num_classes=527,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False, patch_norm=True,
use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
super(HTSAT_Swin_Transformer, self).__init__()
self.config = config
self.spec_size = spec_size
self.patch_stride = patch_stride
self.patch_size = patch_size
self.window_size = window_size
self.embed_dim = embed_dim
self.depths = depths
self.ape = ape
self.in_chans = in_chans
self.num_classes = num_classes
self.num_heads = num_heads
self.num_layers = len(self.depths)
self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.qkv_bias = qkv_bias
self.qk_scale = None
self.patch_norm = patch_norm
self.norm_layer = norm_layer if self.patch_norm else None
self.norm_before_mlp = norm_before_mlp
self.mlp_ratio = mlp_ratio
self.use_checkpoint = use_checkpoint
# process mel-spec ; used only once
self.freq_ratio = self.spec_size // self.config.mel_bins
window = 'hann'
center = True
pad_mode = 'reflect'
ref = 1.0
amin = 1e-10
top_db = None
self.interpolate_ratio = 32 # Downsampled ratio
# Spectrogram extractor
self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
freeze_parameters=True)
# Logmel feature extractor
self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
freeze_parameters=True)
# Spec augmenter
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
freq_drop_width=8, freq_stripes_num=2) # 2 2
self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
# split spectrogram into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.grid_size
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=self.drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=self.depths[i_layer],
num_heads=self.num_heads[i_layer],
window_size=self.window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate,
drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
norm_layer=self.norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
norm_before_mlp=self.norm_before_mlp)
self.layers.append(layer)
self.norm = self.norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.maxpool = nn.AdaptiveMaxPool1d(1)
if self.config.enable_tscam:
SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
self.tscam_conv = nn.Conv2d(
in_channels = self.num_features,
out_channels = self.num_classes,
kernel_size = (SF,3),
padding = (0,1)
)
self.head = nn.Linear(num_classes, num_classes)
else:
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
frames_num = x.shape[2]
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for i, layer in enumerate(self.layers):
x, attn = layer(x)
if self.config.enable_tscam:
# for x
x = self.norm(x)
B, N, C = x.shape
SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
B, C, F, T = x.shape
# group 2D CNN
c_freq_bin = F // self.freq_ratio
x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
# get latent_output
latent_output = self.avgpool(torch.flatten(x,2))
latent_output = torch.flatten(latent_output, 1)
# display the attention map, if needed
if self.config.htsat_attn_heatmap:
# for attn
attn = torch.mean(attn, dim = 1)
attn = torch.mean(attn, dim = 1)
attn = attn.reshape(B, SF, ST)
c_freq_bin = SF // self.freq_ratio
attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
attn = attn.mean(dim = 1)
attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
attn = attn.unsqueeze(dim = 2)
x = self.tscam_conv(x)
x = torch.flatten(x, 2) # B, C, T
if self.config.htsat_attn_heatmap:
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
else:
fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
x = self.avgpool(x)
x = torch.flatten(x, 1)
if self.config.loss_type == "clip_ce":
output_dict = {
'framewise_output': fpx, # already sigmoided
'clipwise_output': x,
'latent_output': latent_output
}
else:
output_dict = {
'framewise_output': fpx, # already sigmoided
'clipwise_output': torch.sigmoid(x),
'latent_output': latent_output
}
else:
x = self.norm(x) # B N C
B, N, C = x.shape
fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
B, C, F, T = fpx.shape
c_freq_bin = F // self.freq_ratio
fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
fpx = torch.sum(fpx, dim = 2)
fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
if self.num_classes > 0:
x = self.head(x)
fpx = self.head(fpx)
output_dict = {'framewise_output': torch.sigmoid(fpx),
'clipwise_output': torch.sigmoid(x)}
return output_dict
def crop_wav(self, x, crop_size, spe_pos = None):
time_steps = x.shape[2]
tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
for i in range(len(x)):
if spe_pos is None:
crop_pos = random.randint(0, time_steps - crop_size - 1)
else:
crop_pos = spe_pos
tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
return tx
# Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
def reshape_wav2img(self, x):
B, C, T, F = x.shape
target_T = int(self.spec_size * self.freq_ratio)
target_F = self.spec_size // self.freq_ratio
assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
# to avoid bicubic zero error
if T < target_T:
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
if F < target_F:
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
x = x.permute(0,1,3,2).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
# print(x.shape)
x = x.permute(0,1,3,2,4).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
return x
# Repeat the waveform to an image size, if you want to use the pretrained swin transformer model
def repeat_wat2img(self, x, cur_pos):
B, C, T, F = x.shape
target_T = int(self.spec_size * self.freq_ratio)
target_F = self.spec_size // self.freq_ratio
assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
# to avoid bicubic zero error
if T < target_T:
x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
if F < target_F:
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
x = x.permute(0,1,3,2).contiguous() # B C F T
x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
x = x.repeat(repeats = (1,1,4,1))
return x
def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
x = x.transpose(1, 3)
x = self.bn0(x)
x = x.transpose(1, 3)
if self.training:
x = self.spec_augmenter(x)
if self.training and mixup_lambda is not None:
x = do_mixup(x, mixup_lambda)
if infer_mode:
# in infer mode, we need to handle different-length audio input
frame_num = x.shape[2]
target_T = int(self.spec_size * self.freq_ratio)
repeat_ratio = math.floor(target_T / frame_num)
x = x.repeat(repeats=(1,1,repeat_ratio,1))
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
elif self.config.enable_repeat_mode:
if self.training:
cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
x = self.repeat_wat2img(x, cur_pos)
output_dict = self.forward_features(x)
else:
output_dicts = []
for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
tx = x.clone()
tx = self.repeat_wat2img(tx, cur_pos)
output_dicts.append(self.forward_features(tx))
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
for d in output_dicts:
clipwise_output += d["clipwise_output"]
framewise_output += d["framewise_output"]
clipwise_output = clipwise_output / len(output_dicts)
framewise_output = framewise_output / len(output_dicts)
output_dict = {
'framewise_output': framewise_output,
'clipwise_output': clipwise_output
}
else:
if x.shape[2] > self.freq_ratio * self.spec_size:
if self.training:
x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
else:
# Change: Hard code here
overlap_size = 344 #(x.shape[2] - 1) // 4
output_dicts = []
crop_size = 689 #(x.shape[2] - 1) // 2
for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
tx = self.reshape_wav2img(tx)
output_dicts.append(self.forward_features(tx))
clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
for d in output_dicts:
clipwise_output += d["clipwise_output"]
framewise_output += d["framewise_output"]
latent_output += d["latent_output"]
clipwise_output = clipwise_output / len(output_dicts)
framewise_output = framewise_output / len(output_dicts)
latent_output = latent_output / len(output_dicts)
output_dict = {
'framewise_output': framewise_output,
'clipwise_output': clipwise_output,
'latent_output': latent_output,
}
else: # this is the typical path and the simplest one
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
# x = self.head(x)
return output_dict
class HTSATWrapper(nn.Module):
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
fmax, classes_num, out_emb):
super().__init__()
# print("parameters are being overidden when using HTSAT")
# print("HTSAT only support loading a pretrained model on AudioSet")
# TODO: later look at which parameters are the same and can be merged
self.htsat = HTSAT_Swin_Transformer(config=config)
def forward(self, x):
out_dict = self.htsat(x)
out_dict['embedding'] = out_dict['latent_output']
return out_dict
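For reference, here is a minimal standalone sketch of the "wav-to-image" reshape that reshape_wav2img performs above, assuming the usual HTSAT settings spec_size=256 and freq_ratio=4; the helper name and the toy tensor are illustrative, not part of the repo.

import torch

def reshape_wav2img_sketch(x, freq_ratio=4):
    # x: (B, C, F, T) -- the log-mel spectrogram after the permute in the method above
    B, C, F, T = x.shape
    x = x.reshape(B, C, F, freq_ratio, T // freq_ratio)   # cut the time axis into freq_ratio chunks
    x = x.permute(0, 1, 3, 2, 4).contiguous()             # (B, C, freq_ratio, F, T // freq_ratio)
    return x.reshape(B, C, F * freq_ratio, T // freq_ratio)

mel = torch.randn(1, 1, 64, 1024)   # 64 mel bins x 1024 frames = spec_size * freq_ratio
img = reshape_wav2img_sketch(mel)
print(img.shape)                    # torch.Size([1, 1, 256, 256]), a square Swin-friendly input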
+200
@@ -0,0 +1,200 @@
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import GPT2LMHeadModel
from typing import Tuple, Optional, Union
def get_clapcap(name: str):
if name == "ClapCaption":
return ClapCaptionModel
else:
raise Exception('The ClapCap model {} is incorrect or not supported'.format(name))
class MappingType(Enum):
MLP = 'mlp'
Transformer = 'transformer'
class MLP(nn.Module):
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
super(MLP, self).__init__()
layers = []
for i in range(len(sizes) - 1):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
if i < len(sizes) - 2:
layers.append(act())
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
class MlpTransformer(nn.Module):
def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
super().__init__()
out_d = out_d if out_d is not None else in_dim
self.fc1 = nn.Linear(in_dim, h_dim)
self.act = act
self.fc2 = nn.Linear(h_dim, out_d)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class MultiHeadAttention(nn.Module):
def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim_self // num_heads
self.scale = head_dim ** -0.5
self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
self.project = nn.Linear(dim_self, dim_self)
self.dropout = nn.Dropout(dropout)
def forward(self, x, y=None, mask=None):
y = y if y is not None else x
b, n, c = x.shape
_, m, d = y.shape
# b n h dh
queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
# b m 2 h dh
keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(1)
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
attention = attention.softmax(dim=2)
out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
out = self.project(out)
return out, attention
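# A standalone shape check for the einsum convention used in MultiHeadAttention above
# (b = batch, n = query tokens, m = key/value tokens, h = heads, d = head dim);
# the numbers below are arbitrary.
import torch

b, n, m, h, d = 2, 5, 7, 8, 64
queries = torch.randn(b, n, h, d)
keys = torch.randn(b, m, h, d)
values = torch.randn(b, m, h, d)

attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * d ** -0.5
attention = attention.softmax(dim=2)                       # normalize over the key axis m
out = torch.einsum('bnmh,bmhd->bnhd', attention, values)   # (b, n, h, d)
print(attention.shape, out.reshape(b, n, h * d).shape)     # (2, 5, 7, 8) and (2, 5, 512)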
class TransformerLayer(nn.Module):
def forward_with_attention(self, x, y=None, mask=None):
x_, attention = self.attn(self.norm1(x), y, mask)
x = x + x_
x = x + self.mlp(self.norm2(x))
return x, attention
def forward(self, x, y=None, mask=None):
x = x + self.attn(self.norm1(x), y, mask)[0]
x = x + self.mlp(self.norm2(x))
return x
def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
norm_layer: nn.Module = nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim_self)
self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
self.norm2 = norm_layer(dim_self)
self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
class Transformer(nn.Module):
def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
super(Transformer, self).__init__()
dim_ref = dim_ref if dim_ref is not None else dim_self
self.enc_dec = enc_dec
if enc_dec:
num_layers = num_layers * 2
layers = []
for i in range(num_layers):
if i % 2 == 0 and enc_dec: # cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
elif enc_dec: # self
layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
else: # self or cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
self.layers = nn.ModuleList(layers)
def forward_with_attention(self, x, y=None, mask=None):
attentions = []
for layer in self.layers:
x, att = layer.forward_with_attention(x, y, mask)
attentions.append(att)
return x, attentions
def forward(self, x, y=None, mask=None):
for i, layer in enumerate(self.layers):
if i % 2 == 0 and self.enc_dec: # cross
x = layer(x, y)
elif self.enc_dec: # self
x = layer(x, x, mask)
else: # self or cross
x = layer(x, y, mask)
return x
class TransformerMapper(nn.Module):
def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
super(TransformerMapper, self).__init__()
self.clip_length = clip_length
self.transformer = Transformer(dim_embedding, 8, num_layers)
self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
def forward(self, x):
x = self.linear(x).view(x.shape[0], self.clip_length, -1)
prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
prefix = torch.cat((x, prefix), dim=1)
out = self.transformer(prefix)[:, self.clip_length:]
return out
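# A quick shape sketch for TransformerMapper, assuming the classes above are in scope;
# 1024 and 768 are illustrative sizes (roughly a pooled audio embedding and the GPT-2
# hidden size), not values read from a config, and num_layers is reduced to keep it light.
import torch

mapper = TransformerMapper(dim_clip=1024, dim_embedding=768, prefix_length=40,
                           clip_length=40, num_layers=2)
audio_emb = torch.randn(4, 1024)    # a batch of 4 pooled audio embeddings
prefix = mapper(audio_emb)
print(prefix.shape)                 # torch.Size([4, 40, 768]) -> 40 soft prefix tokens per clip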
class ClapCaptionModel(nn.Module):
def __init__(self, clap, text_decoder: str, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
num_layers: int = 8, normalize_prefix: bool = True, mapping_type: str = None,\
freeze_audio_encoder_weights: bool = True, freeze_gpt_weights: bool = True):
super(ClapCaptionModel, self).__init__()
self.clap = clap.audio_encoder
self.prefix_length = prefix_length
self.normalize_prefix = normalize_prefix
self.gpt = GPT2LMHeadModel.from_pretrained(text_decoder)
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
if mapping_type == 'mlp':
self.clap_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
self.gpt_embedding_size * prefix_length))
else:
self.clap_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
clip_length, num_layers)
# Freeze all CLAP parameters
if freeze_audio_encoder_weights:
for p in self.clap.parameters():
p.requires_grad = False
if freeze_gpt_weights:
for p in self.gpt.parameters():
p.requires_grad = False
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
def forward(self, audios: torch.Tensor, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None):
# get audio embeddings
prefix, _ = self.clap(audios)
# normalize prefix (audio embedding)
if self.normalize_prefix:
prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
embedding_text = self.gpt.transformer.wte(tokens['input_ids'])
prefix_projections = self.clap_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
if labels is not None:
dummy_token = self.get_dummy_token(tokens['input_ids'].shape[0], tokens['input_ids'].device)
labels = torch.cat((dummy_token, tokens['input_ids']), dim=1)
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
return out
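For orientation, a self-contained sketch of the prefix-captioning forward pass implemented above, using a randomly initialized, scaled-down GPT-2 so nothing has to be downloaded. The sizes, the plain nn.Linear standing in for the MLP/Transformer mapper, and the dummy tokens are all illustrative assumptions, not values from this repo's config.

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

prefix_length, prefix_size = 40, 1024
gpt = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=4, n_embd=256))
embed_dim = gpt.transformer.wte.weight.shape[1]               # 256 here, 768 for real GPT-2

project = nn.Linear(prefix_size, prefix_length * embed_dim)   # stand-in for the mapper above

audio_emb = torch.randn(2, prefix_size)                       # pooled audio embeddings (assumed shape)
audio_emb = audio_emb / audio_emb.norm(dim=-1, keepdim=True)  # the normalize_prefix=True branch
prefix_tokens = project(audio_emb).view(-1, prefix_length, embed_dim)

input_ids = torch.randint(0, gpt.config.vocab_size, (2, 12))  # dummy caption tokens
text_emb = gpt.transformer.wte(input_ids)
inputs_embeds = torch.cat((prefix_tokens, text_emb), dim=1)   # (2, 40 + 12, embed_dim)

out = gpt(inputs_embeds=inputs_embeds)
print(out.logits.shape)                                       # (2, 52, vocab_size)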
+184
@@ -0,0 +1,184 @@
import numpy as np
import time
import torch
import torch.nn as nn
def move_data_to_device(x, device):
if 'float' in str(x.dtype):
x = torch.Tensor(x)
elif 'int' in str(x.dtype):
x = torch.LongTensor(x)
else:
return x
return x.to(device)
def do_mixup(x, mixup_lambda):
"""Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
(1, 3, 5, ...).
Args:
x: (batch_size * 2, ...)
mixup_lambda: (batch_size * 2,)
Returns:
out: (batch_size, ...)
"""
out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
return out
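# Tiny usage sketch for do_mixup above: consecutive pairs of examples are blended with
# the given lambdas, halving the batch dimension. The numbers are arbitrary.
import torch

x = torch.arange(8, dtype=torch.float32).reshape(4, 2)   # batch of 4 = 2 mixup pairs
lam = torch.tensor([0.7, 0.3, 0.5, 0.5])
print(do_mixup(x, lam))   # row 0 = 0.7*x[0] + 0.3*x[1], row 1 = 0.5*x[2] + 0.5*x[3]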
def append_to_dict(dict, key, value):
if key in dict.keys():
dict[key].append(value)
else:
dict[key] = [value]
def interpolate(x, ratio):
"""Interpolate data in time domain. This is used to compensate the
resolution reduction in downsampling of a CNN.
Args:
x: (batch_size, time_steps, classes_num)
ratio: int, ratio to interpolate
Returns:
upsampled: (batch_size, time_steps * ratio, classes_num)
"""
(batch_size, time_steps, classes_num) = x.shape
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
return upsampled
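# Usage sketch for the interpolate helper above: undo a 4x temporal downsampling of
# framewise predictions (527 is just an illustrative class count).
import torch

framewise = torch.rand(2, 25, 527)            # (batch, time_steps, classes)
upsampled = interpolate(framewise, ratio=4)   # each frame is repeated 4 times along time
print(upsampled.shape)                        # torch.Size([2, 100, 527])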
def pad_framewise_output(framewise_output, frames_num):
"""Pad framewise_output to the same length as input frames. The pad value
is the same as the value of the last frame.
Args:
framewise_output: (batch_size, frames_num, classes_num)
frames_num: int, number of frames to pad
Outputs:
output: (batch_size, frames_num, classes_num)
"""
pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)  # tensor for padding
output = torch.cat((framewise_output, pad), dim=1)  # (batch_size, frames_num, classes_num)
return output
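# Usage sketch for pad_framewise_output above: pad 100 prediction frames up to the
# 128 input frames by repeating the last frame.
import torch

framewise = torch.rand(2, 100, 527)
padded = pad_framewise_output(framewise, frames_num=128)
print(padded.shape)                           # torch.Size([2, 128, 527])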
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def count_flops(model, audio_length):
"""Count flops. Code modified from others' implementation.
"""
multiply_adds = True
list_conv2d=[]
def conv2d_hook(self, input, output):
batch_size, input_channels, input_height, input_width = input[0].size()
output_channels, output_height, output_width = output[0].size()
kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
bias_ops = 1 if self.bias is not None else 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_height * output_width
list_conv2d.append(flops)
list_conv1d=[]
def conv1d_hook(self, input, output):
batch_size, input_channels, input_length = input[0].size()
output_channels, output_length = output[0].size()
kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
bias_ops = 1 if self.bias is not None else 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_length
list_conv1d.append(flops)
list_linear=[]
def linear_hook(self, input, output):
batch_size = input[0].size(0) if input[0].dim() == 2 else 1
weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
bias_ops = self.bias.nelement()
flops = batch_size * (weight_ops + bias_ops)
list_linear.append(flops)
list_bn=[]
def bn_hook(self, input, output):
list_bn.append(input[0].nelement() * 2)
list_relu=[]
def relu_hook(self, input, output):
list_relu.append(input[0].nelement() * 2)
list_pooling2d=[]
def pooling2d_hook(self, input, output):
batch_size, input_channels, input_height, input_width = input[0].size()
output_channels, output_height, output_width = output[0].size()
kernel_ops = self.kernel_size * self.kernel_size
bias_ops = 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_height * output_width
list_pooling2d.append(flops)
list_pooling1d=[]
def pooling1d_hook(self, input, output):
batch_size, input_channels, input_length = input[0].size()
output_channels, output_length = output[0].size()
kernel_ops = self.kernel_size[0]
bias_ops = 0
params = output_channels * (kernel_ops + bias_ops)
flops = batch_size * params * output_length
list_pooling1d.append(flops)
def foo(net):
childrens = list(net.children())
if not childrens:
if isinstance(net, nn.Conv2d):
net.register_forward_hook(conv2d_hook)
elif isinstance(net, nn.Conv1d):
net.register_forward_hook(conv1d_hook)
elif isinstance(net, nn.Linear):
net.register_forward_hook(linear_hook)
elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
net.register_forward_hook(bn_hook)
elif isinstance(net, nn.ReLU):
net.register_forward_hook(relu_hook)
elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
net.register_forward_hook(pooling2d_hook)
elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
net.register_forward_hook(pooling1d_hook)
else:
print('Warning: flop of module {} is not counted!'.format(net))
return
for c in childrens:
foo(c)
# Register hook
foo(model)
device = next(model.parameters()).device
input = torch.rand(1, audio_length).to(device)
out = model(input)
total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \
sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d)
return total_flops
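A hedged usage sketch for the two helpers above. ToyAudioNet is a made-up stand-in that, like the audio encoders in this repo, takes a raw (batch, samples) waveform, which is the input count_flops constructs internally; the FLOP estimate only covers the layer types hooked above.

import torch
import torch.nn as nn

class ToyAudioNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 8, kernel_size=9, stride=4)
        self.fc = nn.Linear(8, 10)

    def forward(self, wav):                       # wav: (batch, samples)
        x = torch.relu(self.conv(wav.unsqueeze(1)))
        return self.fc(x.mean(dim=-1))

net = ToyAudioNet()
print(count_parameters(net))                      # 170 trainable parameters
print(count_flops(net, audio_length=16000))       # rough FLOP estimate for 1 s of 16 kHz audio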
+7 -7
@@ -1,5 +1,5 @@
"""
This is an example using CLAP to perform zero-shot
classification on ESC50 (https://github.com/karolpiczak/ESC-50).
"""
@@ -11,15 +11,14 @@ from tqdm import tqdm
from sklearn.metrics import accuracy_score
# Load dataset
dataset = ESC50(root="data_path", download=False)
prompt = 'this is a sound of '
root_path = "root_path" # Folder with ESC-50-master/
dataset = ESC50(root=root_path, download=True)  # If download=False, the code assumes base_folder='ESC-50-master' in esc50_dataset.py
prompt = 'this is the sound of '
y = [prompt + x for x in dataset.classes]
# Load and initialize CLAP
weights_path = "weights_path"
clap_model = CLAPWrapper(weights_path, use_cuda=False)
clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
# Computing text embeddings
text_embeddings = clap_model.get_text_embeddings(y)
@@ -34,6 +33,7 @@ for i in tqdm(range(len(dataset))):
y_preds.append(y_pred)
y_labels.append(one_hot_target.detach().cpu().numpy())
y_labels, y_preds = np.concatenate(y_labels, axis=0), np.concatenate(y_preds, axis=0)
acc = accuracy_score(np.argmax(y_labels, axis=1), np.argmax(y_preds, axis=1))
print('ESC50 Accuracy {}'.format(acc))
@@ -41,6 +41,6 @@ print('ESC50 Accuracy {}'.format(acc))
"""
The output:
ESC50 Accuracy: 82.6%
ESC50 Accuracy: 93.9%
"""
+23 -24
@@ -1,30 +1,29 @@
"""
This is an example using CLAP for zero-shot
inference using ESC50 (https://github.com/karolpiczak/ESC-50).
This is an example using CLAP for zero-shot inference.
"""
from CLAPWrapper import CLAPWrapper
from esc50_dataset import ESC50
import torch.nn.functional as F
# Load ESC50 dataset
dataset = ESC50(root="data_path", download=True) # set download=True when dataset is not downloaded
audio_file, target, one_hot_target = dataset[1000]
audio_file = [audio_file]
# Define classes for zero-shot
# Should be in lower case and can be more than one word
classes = ['coughing','sneezing','drinking sipping', 'breathing', 'brushing teeth']
ground_truth = ['coughing']
# Add prompt
prompt = 'this is a sound of '
y = [prompt + x for x in dataset.classes]
class_prompts = [prompt + x for x in classes]
# Load audio files
audio_files = ['audio_file']
# Load and initialize CLAP
weights_path = "weights_path"
# Setting use_cuda = True will load the model on a GPU using CUDA
clap_model = CLAPWrapper(weights_path, use_cuda=False)
clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
# compute text embeddings from natural text
text_embeddings = clap_model.get_text_embeddings(y)
# compute text embeddings from natural text
text_embeddings = clap_model.get_text_embeddings(class_prompts)
# compute the audio embeddings from an audio file
audio_embeddings = clap_model.get_audio_embeddings(audio_file, resample=True)
# compute the audio embeddings from an audio file
audio_embeddings = clap_model.get_audio_embeddings(audio_files, resample=True)
# compute the similarity between audio_embeddings and text_embeddings
similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
@@ -32,11 +31,11 @@ similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)
similarity = F.softmax(similarity, dim=1)
values, indices = similarity[0].topk(5)
# view the results
print("Ground Truth: {}".format(target))
# Print the results
print("Ground Truth: {}".format(ground_truth))
print("Top predictions:\n")
for value, index in zip(values, indices):
print(f"{dataset.classes[index]:>16s}: {100 * value.item():.2f}%")
print(f"{classes[index]:>16s}: {100 * value.item():.2f}%")
"""
The output (the exact numbers may vary):
@@ -44,9 +43,9 @@ The output (the exact numbers may vary):
Ground Truth: coughing
Top predictions:
coughing: 86.34%
sneezing: 9.30%
drinking sipping: 1.31%
laughing: 1.20%
glass breaking: 0.81%
"""
coughing: 98.55%
sneezing: 1.24%
drinking sipping: 0.15%
breathing: 0.02%
brushing teeth: 0.01%
"""