return predictions
This commit is contained in:
+3
-2
@@ -9,17 +9,18 @@ def predict_from_times(wav_filename, preds_filename, begin, end, csv_export=True
|
||||
tmp_features_filename = "temp/" + next(tempfile._get_candidate_names()) + ".txt"
|
||||
print("Input Array Path: " + tmp_features_filename)
|
||||
|
||||
predictions = None
|
||||
if begin > 0.0 or end > 0.0:
|
||||
print(wav_filename + " interval " + str(begin) + "-" + str(end) + ":")
|
||||
features.create_features(wav_filename, tmp_features_filename, begin, end)
|
||||
load_estimation_model(tmp_features_filename, preds_filename, begin, end, csv_export=csv_export)
|
||||
predictions = load_estimation_model(tmp_features_filename, preds_filename, begin, end, csv_export=csv_export)
|
||||
#easy_call("luajit load_estimation_model.lua " + tmp_features_filename + ' ' + preds_filename)
|
||||
else:
|
||||
features.create_features(wav_filename, tmp_features_filename)
|
||||
easy_call("luajit load_tracking_model.lua " + tmp_features_filename + ' ' + preds_filename)
|
||||
|
||||
delete_temp_files()
|
||||
|
||||
return predictions
|
||||
|
||||
def predict_from_textgrid(wav_filename, preds_filename, textgrid_filename, textgrid_tier):
|
||||
print(wav_filename)
|
||||
|
||||
Reference in New Issue
Block a user