Files
VITS-fast-fine-tuning/api.py
T
2024-07-14 06:19:04 +08:00

133 lines
4.0 KiB
Python

import argparse
import io
import json
from json import JSONDecodeError
from pathlib import Path
from urllib.parse import parse_qs
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from hypy_utils.logging_utils import setup_logger
from starlette.middleware.cors import CORSMiddleware
from torch import no_grad, LongTensor
import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence
# Module-wide logger and the ASGI application that the route decorators attach to.
log = setup_logger()
app = FastAPI()

# Use the first CUDA device when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Language name -> tag that tts_fn places before and after the input text.
# "Mix" maps to an empty tag, so the text is passed through unchanged.
language_marks = {
    "日本語": "[JA]",
    "简体中文": "[ZH]",
    "English": "[EN]",
    "Mix": "",
}

# CORS is left fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_text(text: str, is_symbol: bool):
    """Turn ``text`` into a ``LongTensor`` of symbol ids.

    Reads the module-level ``hps`` for the symbol table, the text cleaners and
    the ``add_blank`` flag.

    Args:
        text: Input text to convert.
        is_symbol: When true, ``text_to_sequence`` is called with an empty
            cleaner list instead of ``hps.data.text_cleaners``.

    Returns:
        A ``LongTensor`` built from the resulting id sequence.
    """
    cleaners = [] if is_symbol else hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        # Interleave id 0 through the sequence when the config asks for it.
        sequence = commons.intersperse(sequence, 0)
    return LongTensor(sequence)
def tts_fn(text: str, speaker: str, language: str, speed: float):
    """Run the loaded model on ``text`` and return the generated audio array.

    Args:
        text: Text to synthesize.
        speaker: Key into the module-level ``speaker_ids`` mapping.
        language: Key into ``language_marks``; when not ``None`` its tag is
            added before and after ``text``.
        speed: Divisor applied to the model's ``length_scale`` (``1.0 / speed``).

    Returns:
        The model output moved to the CPU, cast to float and converted to a
        NumPy array.
    """
    if language is not None:
        mark = language_marks[language]
        text = mark + text + mark
    speaker_id = speaker_ids[speaker]
    tokens = get_text(text, False)

    with no_grad():
        batch = tokens.unsqueeze(0).to(device)
        lengths = LongTensor([tokens.size(0)]).to(device)
        sid = LongTensor([speaker_id]).to(device)
        outputs = model.infer(
            batch,
            lengths,
            sid=sid,
            noise_scale=.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )
        audio = outputs[0][0, 0].data.cpu().float().numpy()

    # Drop the tensor references explicitly before returning.
    del tokens, batch, lengths, sid, outputs
    return audio
@app.get("/tts/options")
async def get_options():
    """Return the speaker names and language names known to this server."""
    speakers = list(speaker_ids.keys())
    languages = list(language_marks.keys())
    return {"speakers": speakers, "languages": languages}
@app.post("/tts")
async def generate(request: Request):
    """Synthesize speech for the posted text and return it as an MP3 attachment.

    The request body is parsed as JSON when it starts with ``{`` and as
    ``x-www-form-urlencoded`` otherwise (the first value of each key is used).

    Recognized fields:
        text: Required string. It is stripped, newlines become spaces, and the
            result may be at most 200 characters long.
        speaker: Required string; must be a key of ``speaker_ids``.
        language: Optional key of ``language_marks``; defaults to "日本語".
        speed: Optional positive, finite number; defaults to 1.0.

    Returns:
        A ``StreamingResponse`` with media type ``audio/mpeg``.

    Raises:
        HTTPException: Status 400 when the body is not UTF-8, the JSON is
            malformed, or any field is missing or invalid.
    """
    raw = await request.body()
    try:
        body = raw.decode()
    except UnicodeDecodeError as e:
        raise HTTPException(status_code=400, detail="Request body must be UTF-8") from e

    # Try parse json
    if body.startswith('{'):
        try:
            data = json.loads(body)
        except JSONDecodeError as e:
            raise HTTPException(status_code=400, detail="Invalid JSON format") from e
    # Try parse x-www-form-urlencoded
    else:
        data = {k: v[0] for k, v in parse_qs(body).items()}
    log.info(data)

    # A missing or non-string text used to crash on .strip(); reject it instead.
    text = data.get('text')
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="Missing text")
    text = text.strip().replace("\n", " ")
    if len(text) > 200:
        raise HTTPException(status_code=400, detail="TL;DR")

    speaker = data.get('speaker')
    language = data.get('language', '日本語')
    # The speaker is checked against speaker_ids here so that an unknown name
    # yields a 400 instead of failing inside tts_fn. The isinstance checks keep
    # unhashable JSON values (lists, objects) away from the membership tests.
    speaker_ok = isinstance(speaker, str) and bool(speaker) and speaker in speaker_ids.keys()
    language_ok = isinstance(language, str) and language in language_marks
    if not text or not speaker_ok or not language_ok:
        raise HTTPException(status_code=400, detail="Invalid speaker or language (please check /tts/options)")

    # Form values arrive as strings, and tts_fn divides by speed, so convert to
    # float and require a positive, finite value (the chained comparison is
    # also false for NaN).
    try:
        speed = float(data.get('speed', 1.0))
    except (TypeError, ValueError, OverflowError) as e:
        raise HTTPException(status_code=400, detail="Invalid speed") from e
    if not 0 < speed < float('inf'):
        raise HTTPException(status_code=400, detail="Invalid speed")

    # NOTE: tts_fn is called synchronously, so this coroutine does not yield
    # to the event loop while inference is running.
    audio = tts_fn(text, speaker, language, speed)

    audio_io = io.BytesIO()
    # MP3 is used instead of OGG because Safari does not support OGG.
    sf.write(audio_io, audio, hps.data.sampling_rate, format='MP3')
    audio_io.seek(0)
    return StreamingResponse(audio_io, media_type='audio/mpeg',
                             headers={'Content-Disposition': 'attachment; filename="output.mp3"'})
if __name__ == "__main__":
    # hps, model and speaker_ids are assigned at module level here because the
    # functions above (get_text, tts_fn, get_options, generate) read them as
    # globals.
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", default="./OUTPUT_MODEL",
                        help="directory to your fine-tuned model (contains G_latest.pth and config.json)")
    # Bind address and port were hard-coded; the defaults keep the old values.
    parser.add_argument("--host", default="0.0.0.0",
                        help="address to bind the HTTP server to")
    parser.add_argument("--port", type=int, default=27519,
                        help="TCP port to listen on")
    args = parser.parse_args()

    d_config = Path(args.d) / "config.json"
    d_model = Path(args.d) / "G_latest.pth"

    hps = utils.get_hparams_from_file(d_config)
    model = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).to(device)
    # Put the model in evaluation mode, then load the checkpoint weights
    # (no optimizer is passed).
    _ = model.eval()
    utils.load_checkpoint(d_model, model, None)

    speaker_ids = hps.speakers

    uvicorn.run(app, host=args.host, port=args.port)