90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from transformers import AutoModel
|
|
from .audio import get_audio_encoder
|
|
|
|
class Projection(nn.Module):
    """Residual two-layer projection head followed by LayerNorm.

    The input is first mapped to ``d_out`` features by a bias-free linear
    layer. That intermediate value is then fed through GELU, a second
    bias-free ``d_out -> d_out`` linear layer and dropout. The result is
    added back onto the intermediate value and the sum is layer-normalised.

    Args:
        d_in: Size of the input's last dimension.
        d_out: Size of the output's last dimension.
        p: Dropout probability, applied only to the non-residual branch.
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        # Keep this creation order (linear1, linear2, layer_norm, drop): the
        # attribute names are the state_dict keys, and the order fixes how
        # random initialisation consumes the global RNG.
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` from ``d_in`` to ``d_out``."""
        shortcut = self.linear1(x)
        branch = F.gelu(shortcut)
        branch = self.drop(self.linear2(branch))
        return self.layer_norm(shortcut + branch)
|
|
|
|
class AudioEncoder(nn.Module):
    """Named audio backbone followed by a :class:`Projection` head.

    The backbone class is resolved by name through ``get_audio_encoder`` and
    constructed from the spectrogram / classifier settings. The backbone's
    ``'embedding'`` output is projected from ``d_in`` to ``d_out``
    features. Its ``'clipwise_output'`` is returned unchanged.

    Args:
        audioenc_name: Name handed to ``get_audio_encoder``.
        d_in: Backbone embedding size. It is passed as the backbone's last
            positional argument and used as the projection's input size.
        d_out: Output size of the projection.
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num:
            Forwarded positionally, in this order, to the backbone constructor.
    """

    def __init__(self, audioenc_name:str, d_in: int, d_out: int, sample_rate: int, window_size: int,
                 hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
        super().__init__()

        encoder_cls = get_audio_encoder(audioenc_name)
        # The backbone constructor is called positionally; d_in goes last.
        backbone_args = (
            sample_rate, window_size, hop_size, mel_bins,
            fmin, fmax, classes_num, d_in,
        )
        self.base = encoder_cls(*backbone_args)

        self.projection = Projection(d_in, d_out)

    def forward(self, x):
        """Encode ``x``; return ``(projected_embedding, clipwise_output)``.

        ``self.base(x)`` is indexed with the keys ``'embedding'`` and
        ``'clipwise_output'``. Only the embedding goes through the projection.
        """
        outputs = self.base(x)
        embedding = outputs['embedding']
        clipwise = outputs['clipwise_output']
        return self.projection(embedding), clipwise
|
|
|
|
class TextEncoder(nn.Module):
    """Pretrained Hugging Face text model followed by a :class:`Projection` head.

    Args:
        d_out: Output size of the projection.
        text_model: Model name or path passed to ``AutoModel.from_pretrained``.
        transformer_embed_dim: Hidden size of the text model, used as the
            projection's input size.
    """

    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained(text_model)
        self.projection = Projection(transformer_embed_dim, d_out)

    def forward(self, x):
        """Embed a batch of text and project it to ``d_out`` features.

        ``x`` is unpacked as keyword arguments into the text model
        (presumably tokenizer output; confirm with callers). From the first
        element of the model's output, position 0 along the second axis is
        taken via ``[:, 0, :]``. This is the CLS token for BERT-style models.
        """
        first_output = self.base(**x)[0]
        cls_token = first_output[:, 0, :]
        return self.projection(cls_token)
|
|
|
|
class CLAP(nn.Module):
    """Contrastive language-audio model built from paired audio and text encoders.

    Both encoders project their inputs into a shared ``d_proj``-dimensional
    space. A learnable scalar ``logit_scale`` is stored in log space and
    initialised to ``log(1 / 0.07)``. ``forward`` returns its exponential.

    Audio args:
        audioenc_name, sample_rate, window_size, hop_size, mel_bins, fmin,
        fmax, classes_num: Forwarded to :class:`AudioEncoder`.
        out_emb: Audio backbone embedding size (``AudioEncoder``'s ``d_in``).

    Text args:
        text_model: Pretrained model name or path for :class:`TextEncoder`.
        transformer_embed_dim: Hidden size of the text model.

    Common args:
        d_proj: Size of the shared projection space.
    """

    def __init__(self,
                 # audio
                 audioenc_name: str,
                 sample_rate: int,
                 window_size: int,
                 hop_size: int,
                 mel_bins: int,
                 fmin: int,
                 fmax: int,
                 classes_num: int,
                 out_emb: int,
                 # text
                 text_model: str,
                 transformer_embed_dim: int,
                 # common
                 d_proj: int,
                 ):
        super().__init__()

        # Submodules are created in the order audio -> text -> logit scale, so
        # parameter initialisation draws from the RNG in that order.
        self.audio_encoder = AudioEncoder(
            audioenc_name=audioenc_name,
            d_in=out_emb,
            d_out=d_proj,
            sample_rate=sample_rate,
            window_size=window_size,
            hop_size=hop_size,
            mel_bins=mel_bins,
            fmin=fmin,
            fmax=fmax,
            classes_num=classes_num,
        )

        self.caption_encoder = TextEncoder(
            d_out=d_proj,
            text_model=text_model,
            transformer_embed_dim=transformer_embed_dim,
        )

        # Temperature kept in log space; exp() is applied in forward().
        initial_log_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_log_scale)

    def forward(self, audio, text):
        """Return ``(caption_embed, audio_embed, logit_scale.exp())``.

        Audio is encoded before text. The audio encoder's classification
        output is discarded.
        """
        audio_embed, _clipwise = self.audio_encoder(audio)
        caption_embed = self.caption_encoder(text)
        scale = self.logit_scale.exp()
        return caption_embed, audio_embed, scale