#!/usr/bin/env python3
import argparse
from pathlib import Path
from faster_whisper import WhisperModel, BatchedInferencePipeline
# import nemo.collections.asr as nemo_asr
# asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="nvidia/parakeet-tdt-1.1b")
# import nemo.collections.asr.models.rnnt_bpe_models.EncDecRNNTBPEModel
# asr_model.transcribe


# Whisper checkpoint passed to WhisperModel below; the commented-out
# assignments are alternative checkpoints that can be swapped in.
# model_name = 'deepdml/faster-whisper-large-v3-turbo-ct2'
model_name = 'large-v3'
# model_name = 'medium.en'
# NOTE: the model is constructed at import time, requesting the "cuda" device
# with float16 compute, so merely importing this module triggers the load.
m = WhisperModel(model_name, device="cuda", compute_type="float16")
# Batched wrapper around the model; transcribe() calls model.transcribe().
model = BatchedInferencePipeline(model=m)


def format_time(seconds):
    """Format a time offset as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond *before* being split into
    fields. Splitting first and truncating the fractional part afterwards loses
    a millisecond on values such as ``1.001`` (whose float remainder is
    ``0.000999...``), and can never carry a rounded-up ``,1000`` into the
    seconds field.

    Args:
        seconds: Offset from the start of the media, in seconds (int or
            float-like). Negative values are clamped to zero because SRT
            timestamps cannot be negative.

    Returns:
        The timestamp string, e.g. ``"01:01:01,500"`` for ``3661.5``.
    """
    # Work in integer milliseconds from here on so no float error can leak in.
    total_ms = max(0, int(round(float(seconds) * 1000)))
    total_seconds, milliseconds = divmod(total_ms, 1000)
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"


def transcribe(input_file: Path, lang: str):
    """Transcribe one media file and save the result beside it as ``.srt``.

    The output path is ``input_file`` with its suffix replaced by ``.srt``.
    If that file already exists the function prints a notice and returns
    without transcribing, so re-running over a directory only processes new
    files. Each cue is echoed to stdout as it is produced; the file itself is
    written once at the end, so an interrupted run leaves no partial ``.srt``
    behind that would later be mistaken for a finished one.

    Args:
        input_file: Path to the audio/video file to transcribe.
        lang: Language code forwarded to the model (e.g. ``"en"``), or
            ``None`` to leave the choice to the model.

    Returns:
        None.
    """
    ouf = input_file.with_suffix('.srt')
    if ouf.exists():
        print(f"Output file {ouf} already exists. Skipping transcription.")
        return

    # Announce the file before the (potentially slow) model call starts.
    print(f"Transcribing file {input_file}")

    # task="transcribe" keeps the spoken language; use task="translate" if a
    # translation is wanted instead.
    # The path is handed over as a plain string, the documented path type for
    # the audio argument, rather than relying on Path objects being accepted.
    segments, info = model.transcribe(
        str(input_file),
        beam_size=1,
        batch_size=8,
        without_timestamps=False,
        task="transcribe",
        vad_filter=True,
        language=lang,
    )
    print(f"Detected language '{info.language}' with probability {info.language_probability:.2f}")

    cues = []
    # SRT cues must be numbered 1, 2, 3, ...; count them here instead of
    # deriving the number from seg.id, whose starting value is library-defined.
    for index, seg in enumerate(segments, start=1):
        start_time = format_time(seg.start)
        end_time = format_time(seg.end)
        cue = f"{index}\n{start_time} --> {end_time}\n{seg.text.lstrip()}\n\n"
        print(cue)
        cues.append(cue)

    # Write explicitly as UTF-8: the locale default encoding can fail or
    # corrupt non-ASCII transcripts on some systems.
    ouf.write_text("".join(cues), encoding="utf-8")
    print(f"Transcription saved to {ouf}")


def main():
    """Command-line entry point: transcribe every file named on the command line.

    Accepts one or more positional input paths plus an optional ``-l/--lang``
    language code (default ``None``), and calls ``transcribe`` on each path in
    the order given.
    """
    cli = argparse.ArgumentParser(
        description="Transcribe audio from a video file and generate an SRT file."
    )
    cli.add_argument(
        "input_file",
        nargs="+",
        help="Path to the video file for transcription",
    )
    cli.add_argument(
        "-l",
        "--lang",
        default=None,
        help="Language code for transcription (e.g. 'en')",
    )
    options = cli.parse_args()

    language = options.lang
    for raw_path in options.input_file:
        transcribe(Path(raw_path), language)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
