diff --git a/src/CLAPWrapper.py b/src/CLAPWrapper.py index d69d3bf..7db3622 100644 --- a/src/CLAPWrapper.py +++ b/src/CLAPWrapper.py @@ -149,7 +149,7 @@ class CLAPWrapper(): audio_tensor = self.load_audio_into_tensor( audio_file, self.args.duration, resample) audio_tensor = audio_tensor.reshape( - 1, -1).cuda if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1) + 1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1) audio_tensors.append(audio_tensor) return self.default_collate(audio_tensors)