estimation in pytorch
This commit is contained in:
+10
-8
@@ -1,26 +1,26 @@
|
||||
|
||||
from json import load
|
||||
import extract_features as features
|
||||
import argparse
|
||||
from helpers.textgrid import *
|
||||
from helpers.utilities import *
|
||||
import shutil
|
||||
|
||||
from load_estimation_model import load_estimation_model
|
||||
|
||||
|
||||
def predict_from_times(wav_filename, preds_filename, begin, end):
|
||||
tmp_features_filename = tempfile._get_default_tempdir() + "/" + next(tempfile._get_candidate_names()) + ".txt"
|
||||
print(tmp_features_filename)
|
||||
print("Input Array Path: " + tmp_features_filename)
|
||||
|
||||
if begin > 0.0 or end > 0.0:
|
||||
print(wav_filename + " interval " + str(begin) + "-" + str(end) + ":")
|
||||
features.create_features(wav_filename, tmp_features_filename, begin, end)
|
||||
easy_call("luajit load_estimation_model.lua " + tmp_features_filename + ' ' + preds_filename)
|
||||
load_estimation_model(tmp_features_filename, preds_filename)
|
||||
#easy_call("luajit load_estimation_model.lua " + tmp_features_filename + ' ' + preds_filename)
|
||||
else:
|
||||
features.create_features(wav_filename, tmp_features_filename)
|
||||
easy_call("luajit load_tracking_model.lua " + tmp_features_filename + ' ' + preds_filename)
|
||||
|
||||
|
||||
def predict_from_textgrid(wav_filename, preds_filename, textgrid_filename, textgrid_tier):
|
||||
|
||||
print(wav_filename)
|
||||
|
||||
if os.path.exists(preds_filename):
|
||||
@@ -42,7 +42,8 @@ def predict_from_textgrid(wav_filename, preds_filename, textgrid_filename, textg
|
||||
tmp_features_filename = generate_tmp_filename("features")
|
||||
tmp_preds = generate_tmp_filename("preds")
|
||||
features.create_features(wav_filename, tmp_features_filename, interval.xmin(), interval.xmax())
|
||||
easy_call("th load_estimation_model.lua " + tmp_features_filename + ' ' + tmp_preds)
|
||||
load_estimation_model(tmp_features_filename, tmp_preds)
|
||||
#easy_call("th load_estimation_model.lua " + tmp_features_filename + ' ' + tmp_preds)
|
||||
csv_append_row(tmp_preds, preds_filename)
|
||||
else: # process first tier
|
||||
for interval in textgrid[0]:
|
||||
@@ -50,7 +51,8 @@ def predict_from_textgrid(wav_filename, preds_filename, textgrid_filename, textg
|
||||
tmp_features_filename = generate_tmp_filename("features")
|
||||
tmp_preds = generate_tmp_filename("preds")
|
||||
features.create_features(wav_filename, tmp_features_filename, interval.xmin(), interval.xmax())
|
||||
easy_call("th load_estimation_model.lua " + tmp_features_filename + ' ' + tmp_preds)
|
||||
load_estimation_model(tmp_features_filename, tmp_preds)
|
||||
#easy_call("th load_estimation_model.lua " + tmp_features_filename + ' ' + tmp_preds)
|
||||
csv_append_row(tmp_preds, preds_filename)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -28,17 +28,17 @@ class LambdaReduce(LambdaBase):
|
||||
|
||||
def load_estimation_model(inputfilename, outputfilename):
|
||||
with open(inputfilename, "r") as rf:
|
||||
contents = rf.read()
|
||||
contents = contents.split(",")
|
||||
contents = rf.read()
|
||||
contents = contents.split(",")
|
||||
|
||||
data = torch.Tensor(1,350)
|
||||
name = ""
|
||||
for i in range(len(contents)):
|
||||
if i == 0:
|
||||
name = contents[i].strip()
|
||||
else:
|
||||
val = float(contents[i].strip())
|
||||
data[0][i-1] = val
|
||||
if i == 0:
|
||||
name = contents[i].strip()
|
||||
else:
|
||||
val = float(contents[i].strip())
|
||||
data[0][i-1] = val
|
||||
|
||||
model = nn.Sequential( # Sequential,
|
||||
nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(350,1024)), # Linear,
|
||||
|
||||
Reference in New Issue
Block a user