diff --git a/formants.py b/formants.py index 020ec2c..bdccbbf 100644 --- a/formants.py +++ b/formants.py @@ -5,14 +5,14 @@ from helpers.utilities import * from load_estimation_model import load_estimation_model -def predict_from_times(wav_filename, preds_filename, begin, end): +def predict_from_times(wav_filename, preds_filename, begin, end, csv_export=True): tmp_features_filename = "temp/" + next(tempfile._get_candidate_names()) + ".txt" print("Input Array Path: " + tmp_features_filename) if begin > 0.0 or end > 0.0: print(wav_filename + " interval " + str(begin) + "-" + str(end) + ":") features.create_features(wav_filename, tmp_features_filename, begin, end) - load_estimation_model(tmp_features_filename, preds_filename, begin, end) + load_estimation_model(tmp_features_filename, preds_filename, begin, end, csv_export=csv_export) #easy_call("luajit load_estimation_model.lua " + tmp_features_filename + ' ' + preds_filename) else: features.create_features(wav_filename, tmp_features_filename) diff --git a/load_estimation_model.py b/load_estimation_model.py index 5c04657..6a18f52 100644 --- a/load_estimation_model.py +++ b/load_estimation_model.py @@ -26,7 +26,7 @@ class LambdaReduce(LambdaBase): return reduce(self.lambda_func,self.forward_prepare(input)) -def load_estimation_model(inputfilename, outputfilename, begin, end): +def load_estimation_model(inputfilename, outputfilename, begin, end, csv_export=True): with open(inputfilename, "r") as rf: contents = rf.read() contents = contents.split(",") @@ -53,8 +53,11 @@ def load_estimation_model(inputfilename, outputfilename, begin, end): model.load_state_dict(torch.load("em.pth")) my_prediction = model.forward(data) - with open(outputfilename, "w") as wf: - wf.write("NAME,begin,end,F1,F2,F3,F4\n") - wf.write(name + "," + str(begin) + "," + str(end) + "," + \ - str(1000 * float(my_prediction[0][0])) + "," + str(1000 * float(my_prediction[0][1])) + "," + \ - str(1000 * float(my_prediction[0][2])) + "," + str(1000 * float(my_prediction[0][3])) + "\n") + if csv_export: + with open(outputfilename, "w") as wf: + wf.write("NAME,begin,end,F1,F2,F3,F4\n") + wf.write(name + "," + str(begin) + "," + str(end) + "," + \ + str(1000 * float(my_prediction[0][0])) + "," + str(1000 * float(my_prediction[0][1])) + "," + \ + str(1000 * float(my_prediction[0][2])) + "," + str(1000 * float(my_prediction[0][3])) + "\n") + + return my_prediction