.
This commit is contained in:
+5
-4
@@ -6,6 +6,7 @@ from helpers.utilities import *
|
||||
import shutil
|
||||
|
||||
|
||||
|
||||
def predict_from_times(wav_filename, preds_filename, begin, end):
|
||||
tmp_features_filename = tempfile._get_default_tempdir() + "/" + next(tempfile._get_candidate_names()) + ".txt"
|
||||
print tmp_features_filename
|
||||
@@ -38,16 +39,16 @@ def predict_from_textgrid(wav_filename, preds_filename, textgrid_filename, textg
|
||||
# run over all intervals in the tier
|
||||
for interval in textgrid[tier_index]:
|
||||
if re.search(r'\S', interval.mark()):
|
||||
tmp_features_filename = generate_tmp_filename()
|
||||
tmp_preds = generate_tmp_filename()
|
||||
tmp_features_filename = generate_tmp_filename("features")
|
||||
tmp_preds = generate_tmp_filename("preds")
|
||||
features.create_features(wav_filename, tmp_features_filename, interval.xmin(), interval.xmax())
|
||||
easy_call("th load_estimation_model.lua " + tmp_features_filename + ' ' + tmp_preds)
|
||||
csv_append_row(tmp_preds, preds_filename)
|
||||
else: # process first tier
|
||||
for interval in textgrid[0]:
|
||||
if re.search(r'\S', interval.mark()):
|
||||
tmp_features_filename = generate_tmp_filename()
|
||||
tmp_preds = generate_tmp_filename()
|
||||
tmp_features_filename = generate_tmp_filename("features")
|
||||
tmp_preds = generate_tmp_filename("preds")
|
||||
features.create_features(wav_filename, tmp_features_filename, interval.xmin(), interval.xmax())
|
||||
easy_call("th load_estimation_model.lua " + tmp_features_filename + ' ' + tmp_preds)
|
||||
csv_append_row(tmp_preds, preds_filename)
|
||||
|
||||
@@ -62,6 +62,7 @@ def logging_defaults(logging_level="INFO"):
|
||||
logging.basicConfig(level=logging_level, format='%(asctime)s.%(msecs)d [%(filename)s] %(levelname)s: %(message)s',
|
||||
datefmt='%H:%M:%S')
|
||||
|
||||
|
||||
def num_lines(filename):
|
||||
lines = 0
|
||||
for _ in open(filename, 'rU'):
|
||||
|
||||
Reference in New Issue
Block a user