diff --git a/src/CLAPWrapper.py b/src/CLAPWrapper.py index 06b3f67..542ab6f 100644 --- a/src/CLAPWrapper.py +++ b/src/CLAPWrapper.py @@ -7,8 +7,8 @@ import collections import re import numpy as np from transformers import AutoTokenizer, logging -from models.clap import CLAP -from models.mapper import get_clapcap +from .models.clap import CLAP +from .models.mapper import get_clapcap import math import torchaudio.transforms as T import os diff --git a/src/models/audio.py b/src/models/audio.py index 6735b55..9d3749a 100644 --- a/src/models/audio.py +++ b/src/models/audio.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torchlibrosa.stft import Spectrogram, LogmelFilterBank -from models.htsat import HTSATWrapper +from .htsat import HTSATWrapper def get_audio_encoder(name: str): if name == "Cnn14": diff --git a/src/models/htsat.py b/src/models/htsat.py index 7a4f528..1504c00 100644 --- a/src/models/htsat.py +++ b/src/models/htsat.py @@ -6,11 +6,8 @@ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf -import logging -import pdb import math import random -from numpy.core.fromnumeric import clip, reshape import torch import torch.nn as nn import torch.utils.checkpoint as checkpoint @@ -19,15 +16,10 @@ from torchlibrosa.stft import Spectrogram, LogmelFilterBank from torchlibrosa.augmentation import SpecAugmentation from itertools import repeat -from typing import List -try: - from models.pytorch_utils import do_mixup, interpolate - import models.config as config -except: - from CLAP_API.models.pytorch_utils import do_mixup, interpolate - from CLAP_API.models import config -import torch.nn.functional as F +from .pytorch_utils import do_mixup, interpolate +from . import config + import collections.abc import warnings diff --git a/src/models/mapper.py b/src/models/mapper.py index 6ad3733..fd62b93 100644 --- a/src/models/mapper.py +++ b/src/models/mapper.py @@ -2,10 +2,9 @@ import torch import torch.nn as nn from torch.nn import functional as nnf -from torch.utils.data import Dataset, DataLoader from enum import Enum from transformers import GPT2LMHeadModel -from typing import Tuple, Optional, Union +from typing import Tuple, Optional def get_clapcap(name: str): if name == "ClapCaption": diff --git a/src/models/pytorch_utils.py b/src/models/pytorch_utils.py index 453d3ce..eadf619 100644 --- a/src/models/pytorch_utils.py +++ b/src/models/pytorch_utils.py @@ -1,5 +1,3 @@ -import numpy as np -import time import torch import torch.nn as nn