diff --git a/load_estimation_model.py b/load_estimation_model.py index 6a18f52..51b8256 100644 --- a/load_estimation_model.py +++ b/load_estimation_model.py @@ -53,11 +53,16 @@ def load_estimation_model(inputfilename, outputfilename, begin, end, csv_export= model.load_state_dict(torch.load("em.pth")) my_prediction = model.forward(data) + my_prediction[0][0] = 1000 * float(my_prediction[0][0]) + my_prediction[0][1] = 1000 * float(my_prediction[0][1]) + my_prediction[0][2] = 1000 * float(my_prediction[0][2]) + my_prediction[0][3] = 1000 * float(my_prediction[0][3]) + if csv_export: with open(outputfilename, "w") as wf: wf.write("NAME,begin,end,F1,F2,F3,F4\n") wf.write(name + "," + str(begin) + "," + str(end) + "," + \ - str(1000 * float(my_prediction[0][0])) + "," + str(1000 * float(my_prediction[0][1])) + "," + \ - str(1000 * float(my_prediction[0][2])) + "," + str(1000 * float(my_prediction[0][3])) + "\n") + str(my_prediction[0][0]) + "," + str(my_prediction[0][1]) + "," + \ + str(my_prediction[0][2]) + "," + str(my_prediction[0][3]) + "\n") return my_prediction