182 lines
6.0 KiB
Python
182 lines
6.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
def move_data_to_device(x, device):
|
|
if 'float' in str(x.dtype):
|
|
x = torch.Tensor(x)
|
|
elif 'int' in str(x.dtype):
|
|
x = torch.LongTensor(x)
|
|
else:
|
|
return x
|
|
|
|
return x.to(device)
|
|
|
|
|
|
def do_mixup(x, mixup_lambda):
|
|
"""Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes
|
|
(1, 3, 5, ...).
|
|
Args:
|
|
x: (batch_size * 2, ...)
|
|
mixup_lambda: (batch_size * 2,)
|
|
Returns:
|
|
out: (batch_size, ...)
|
|
"""
|
|
out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
|
|
x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
|
|
return out
|
|
|
|
|
|
def append_to_dict(dict, key, value):
|
|
if key in dict.keys():
|
|
dict[key].append(value)
|
|
else:
|
|
dict[key] = [value]
|
|
|
|
|
|
def interpolate(x, ratio):
|
|
"""Interpolate data in time domain. This is used to compensate the
|
|
resolution reduction in downsampling of a CNN.
|
|
|
|
Args:
|
|
x: (batch_size, time_steps, classes_num)
|
|
ratio: int, ratio to interpolate
|
|
Returns:
|
|
upsampled: (batch_size, time_steps * ratio, classes_num)
|
|
"""
|
|
(batch_size, time_steps, classes_num) = x.shape
|
|
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
|
|
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
|
|
return upsampled
|
|
|
|
|
|
def pad_framewise_output(framewise_output, frames_num):
|
|
"""Pad framewise_output to the same length as input frames. The pad value
|
|
is the same as the value of the last frame.
|
|
Args:
|
|
framewise_output: (batch_size, frames_num, classes_num)
|
|
frames_num: int, number of frames to pad
|
|
Outputs:
|
|
output: (batch_size, frames_num, classes_num)
|
|
"""
|
|
pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1)
|
|
"""tensor for padding"""
|
|
|
|
output = torch.cat((framewise_output, pad), dim=1)
|
|
"""(batch_size, frames_num, classes_num)"""
|
|
|
|
return output
|
|
|
|
|
|
def count_parameters(model):
|
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
|
|
|
def count_flops(model, audio_length):
|
|
"""Count flops. Code modified from others' implementation.
|
|
"""
|
|
multiply_adds = True
|
|
list_conv2d=[]
|
|
def conv2d_hook(self, input, output):
|
|
batch_size, input_channels, input_height, input_width = input[0].size()
|
|
output_channels, output_height, output_width = output[0].size()
|
|
|
|
kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
|
|
bias_ops = 1 if self.bias is not None else 0
|
|
|
|
params = output_channels * (kernel_ops + bias_ops)
|
|
flops = batch_size * params * output_height * output_width
|
|
|
|
list_conv2d.append(flops)
|
|
|
|
list_conv1d=[]
|
|
def conv1d_hook(self, input, output):
|
|
batch_size, input_channels, input_length = input[0].size()
|
|
output_channels, output_length = output[0].size()
|
|
|
|
kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1)
|
|
bias_ops = 1 if self.bias is not None else 0
|
|
|
|
params = output_channels * (kernel_ops + bias_ops)
|
|
flops = batch_size * params * output_length
|
|
|
|
list_conv1d.append(flops)
|
|
|
|
list_linear=[]
|
|
def linear_hook(self, input, output):
|
|
batch_size = input[0].size(0) if input[0].dim() == 2 else 1
|
|
|
|
weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
|
|
bias_ops = self.bias.nelement()
|
|
|
|
flops = batch_size * (weight_ops + bias_ops)
|
|
list_linear.append(flops)
|
|
|
|
list_bn=[]
|
|
def bn_hook(self, input, output):
|
|
list_bn.append(input[0].nelement() * 2)
|
|
|
|
list_relu=[]
|
|
def relu_hook(self, input, output):
|
|
list_relu.append(input[0].nelement() * 2)
|
|
|
|
list_pooling2d=[]
|
|
def pooling2d_hook(self, input, output):
|
|
batch_size, input_channels, input_height, input_width = input[0].size()
|
|
output_channels, output_height, output_width = output[0].size()
|
|
|
|
kernel_ops = self.kernel_size * self.kernel_size
|
|
bias_ops = 0
|
|
params = output_channels * (kernel_ops + bias_ops)
|
|
flops = batch_size * params * output_height * output_width
|
|
|
|
list_pooling2d.append(flops)
|
|
|
|
list_pooling1d=[]
|
|
def pooling1d_hook(self, input, output):
|
|
batch_size, input_channels, input_length = input[0].size()
|
|
output_channels, output_length = output[0].size()
|
|
|
|
kernel_ops = self.kernel_size[0]
|
|
bias_ops = 0
|
|
|
|
params = output_channels * (kernel_ops + bias_ops)
|
|
flops = batch_size * params * output_length
|
|
|
|
list_pooling2d.append(flops)
|
|
|
|
def foo(net):
|
|
childrens = list(net.children())
|
|
if not childrens:
|
|
if isinstance(net, nn.Conv2d):
|
|
net.register_forward_hook(conv2d_hook)
|
|
elif isinstance(net, nn.Conv1d):
|
|
net.register_forward_hook(conv1d_hook)
|
|
elif isinstance(net, nn.Linear):
|
|
net.register_forward_hook(linear_hook)
|
|
elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
|
|
net.register_forward_hook(bn_hook)
|
|
elif isinstance(net, nn.ReLU):
|
|
net.register_forward_hook(relu_hook)
|
|
elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
|
|
net.register_forward_hook(pooling2d_hook)
|
|
elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
|
|
net.register_forward_hook(pooling1d_hook)
|
|
else:
|
|
print('Warning: flop of module {} is not counted!'.format(net))
|
|
return
|
|
for c in childrens:
|
|
foo(c)
|
|
|
|
# Register hook
|
|
foo(model)
|
|
|
|
device = device = next(model.parameters()).device
|
|
input = torch.rand(1, audio_length).to(device)
|
|
|
|
out = model(input)
|
|
|
|
total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \
|
|
sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d)
|
|
|
|
return total_flops |