optional csv export

This commit is contained in:
Jason
2022-04-03 11:43:20 -07:00
parent ca26e3f058
commit 826f5bc486
2 changed files with 11 additions and 8 deletions
+2 -2
View File
@@ -5,14 +5,14 @@ from helpers.utilities import *
from load_estimation_model import load_estimation_model
def predict_from_times(wav_filename, preds_filename, begin, end):
def predict_from_times(wav_filename, preds_filename, begin, end, csv_export=True):
tmp_features_filename = "temp/" + next(tempfile._get_candidate_names()) + ".txt"
print("Input Array Path: " + tmp_features_filename)
if begin > 0.0 or end > 0.0:
print(wav_filename + " interval " + str(begin) + "-" + str(end) + ":")
features.create_features(wav_filename, tmp_features_filename, begin, end)
load_estimation_model(tmp_features_filename, preds_filename, begin, end)
load_estimation_model(tmp_features_filename, preds_filename, begin, end, csv_export=csv_export)
#easy_call("luajit load_estimation_model.lua " + tmp_features_filename + ' ' + preds_filename)
else:
features.create_features(wav_filename, tmp_features_filename)
+9 -6
View File
@@ -26,7 +26,7 @@ class LambdaReduce(LambdaBase):
return reduce(self.lambda_func,self.forward_prepare(input))
def load_estimation_model(inputfilename, outputfilename, begin, end):
def load_estimation_model(inputfilename, outputfilename, begin, end, csv_export=True):
with open(inputfilename, "r") as rf:
contents = rf.read()
contents = contents.split(",")
@@ -53,8 +53,11 @@ def load_estimation_model(inputfilename, outputfilename, begin, end):
model.load_state_dict(torch.load("em.pth"))
my_prediction = model.forward(data)
with open(outputfilename, "w") as wf:
wf.write("NAME,begin,end,F1,F2,F3,F4\n")
wf.write(name + "," + str(begin) + "," + str(end) + "," + \
str(1000 * float(my_prediction[0][0])) + "," + str(1000 * float(my_prediction[0][1])) + "," + \
str(1000 * float(my_prediction[0][2])) + "," + str(1000 * float(my_prediction[0][3])) + "\n")
if csv_export:
with open(outputfilename, "w") as wf:
wf.write("NAME,begin,end,F1,F2,F3,F4\n")
wf.write(name + "," + str(begin) + "," + str(end) + "," + \
str(1000 * float(my_prediction[0][0])) + "," + str(1000 * float(my_prediction[0][1])) + "," + \
str(1000 * float(my_prediction[0][2])) + "," + str(1000 * float(my_prediction[0][3])) + "\n")
return my_prediction