diff --git a/formants.py b/formants.py index bdccbbf..719f721 100644 --- a/formants.py +++ b/formants.py @@ -9,17 +9,18 @@ def predict_from_times(wav_filename, preds_filename, begin, end, csv_export=True tmp_features_filename = "temp/" + next(tempfile._get_candidate_names()) + ".txt" print("Input Array Path: " + tmp_features_filename) + predictions = None if begin > 0.0 or end > 0.0: print(wav_filename + " interval " + str(begin) + "-" + str(end) + ":") features.create_features(wav_filename, tmp_features_filename, begin, end) - load_estimation_model(tmp_features_filename, preds_filename, begin, end, csv_export=csv_export) + predictions = load_estimation_model(tmp_features_filename, preds_filename, begin, end, csv_export=csv_export) #easy_call("luajit load_estimation_model.lua " + tmp_features_filename + ' ' + preds_filename) else: features.create_features(wav_filename, tmp_features_filename) easy_call("luajit load_tracking_model.lua " + tmp_features_filename + ' ' + preds_filename) delete_temp_files() - + return predictions def predict_from_textgrid(wav_filename, preds_filename, textgrid_filename, textgrid_tier): print(wav_filename)