diff --git a/backend/.gitignore b/backend/.gitignore new file mode 100644 index 0000000..29df4a8 --- /dev/null +++ b/backend/.gitignore @@ -0,0 +1,4 @@ +.DS_Store +.venv +venv +.env \ No newline at end of file diff --git a/backend/README.md b/backend/README.md new file mode 100644 index 0000000..f7da530 --- /dev/null +++ b/backend/README.md @@ -0,0 +1,35 @@ + +> For the demo, language ∈ {EN_JA, EN_ZH-CN} + +POST /ai-mark { + question: str, + user_answer: str, + expected: str, + chapter: str, + language: str +} +-> {correct: bool, reason: str} + +POST /human-chat { + target_name: str, + target_hobbies: list\[str], + history: list\[{from\_me: bool, msg: str}], + language: str +} +-> {msg: str} + +> Chat histories will be stored in localStorage + +POST /recognize (Body is an audio file) +-> {text: str} + +POST /character-chat { + character: str + history: list\[{from_me: bool, msg: str}] + language: str +} +-> {msg: str, audio: UUID} + +GET /audio/{UUID} + +Run `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365` to generate a self-signed certificate for HTTPS. \ No newline at end of file diff --git a/backend/api.py b/backend/api.py new file mode 100644 index 0000000..5cad18e --- /dev/null +++ b/backend/api.py @@ -0,0 +1,234 @@ +from fastapi import FastAPI, UploadFile, File, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, UUID4 +from typing import List, Dict, Union, Optional +import uuid +from dotenv import load_dotenv +from openai import OpenAI +from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, \ + ChatCompletionAssistantMessageParam +from tempfile import TemporaryDirectory +from starlette.responses import FileResponse +import json +from dataclasses import dataclass +import torch +from transformers import pipeline + +load_dotenv() + +app = FastAPI() +openai_client = OpenAI() + + +class AIMarkRequest(BaseModel): + question: str + user_answer: str + expected: str + chapter: str + language: str + + +class AIMarkResponse(BaseModel): + correct: bool + reason: str + + +@app.post("/ai-mark", response_model=AIMarkResponse) +def ai_mark(request: AIMarkRequest): + marking_system_prompt = f""" + You are a marking system for a language learning app. + You are marking a question from a chapter on {request.chapter} in {request.language}. + Please mark the user's answer as correct or incorrect and give a reason for your marking. + Output in the following JSON format: {{"correct": bool, "reason": str}} + """ + user_prompt = f""" + The question is: {request.question} + The user's answer is: {request.user_answer} + The expected answer is: {request.expected} + """ + completion = openai_client.chat.completions.create( + model="gpt-3.5-turbo-1106", + messages=[ + {"role": "system", "content": marking_system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"} + ) + response = json.loads(completion.choices[0].message.content) + if "correct" not in response or "reason" not in response: + raise HTTPException(status_code=500, detail="Invalid response from OpenAI") + return AIMarkResponse(correct=response["correct"], reason=response["reason"]) + + +CHAT_HISTORIES = {} + + +class ChatHistory: + def __init__(self, session_id: UUID4, system_prompt: str): + self.session_id = session_id + self.history = [] + self.add_message(ChatCompletionSystemMessageParam(content=system_prompt, role="system")) + + def add_human_message(self, message: str): + self.add_message(ChatCompletionUserMessageParam(content=message, role="user")) + + def add_message(self, message: Union[ + ChatCompletionAssistantMessageParam, ChatCompletionUserMessageParam, ChatCompletionSystemMessageParam]): + self.history.append(message) + + def generate_and_record_message(self): + completion = openai_client.chat.completions.create( + model="gpt-3.5-turbo-1106", + messages=self.history, + ) + message = completion.choices[0].message + self.add_message(ChatCompletionAssistantMessageParam(content=message.content, role="assistant")) + return message + + +class HumanChatCreationRequest(BaseModel): + user_name: str + user_hobbies: List[str] + target_name: str + target_hobbies: List[str] + language: str + + +class HumanChatCreationResponse(BaseModel): + session_id: UUID4 + + +@app.post("/human-chat", response_model=HumanChatCreationResponse) +def human_chat_creation(request: HumanChatCreationRequest): + session_id = uuid.uuid4() + system_prompt = f""" + Your name is {request.target_name} and are playing the role of a human chatting to a person named {request.user_name} who likes {request.user_hobbies}. + Your hobbies are {request.target_hobbies}. + They are a language learner trying to learn {request.language}. Please help them learn by chatting with them + in the language but do not act like a robot. Try to be as human as possible. + """ + CHAT_HISTORIES[session_id] = ChatHistory(session_id, system_prompt) + return HumanChatCreationResponse(session_id=session_id) + + +class HumanChatMessageRequest(BaseModel): + msg: str + + +class HumanChatMessageResponse(BaseModel): + msg: str + + +@app.post("/human-chat/{session_id}", response_model=HumanChatMessageResponse) +def human_chat_message(session_id: UUID4, request: HumanChatMessageRequest): + if session_id not in CHAT_HISTORIES: + raise HTTPException(status_code=404, detail="Session ID not found") + history = CHAT_HISTORIES[session_id] + history.add_human_message(request.msg) + message = history.generate_and_record_message() + return HumanChatMessageResponse(msg=message.content) + + +class CharacterChatCreationRequest(BaseModel): + character: str + user_name: str + language: str + + +class CharacterChatCreationResponse(BaseModel): + session_id: UUID4 + + +def get_character_prompt(character: str, user_name: str, language: str): + if character == "ash_pokemon": + return f""" + You are playing the role of Ash Ketchum, a Pokemon trainer chatting to a person named {user_name}. + They are a language learner trying to learn {language}. Please help them learn by chatting with them in the language + but do not act like a robot. Try to be as human as possible. + """ + + raise HTTPException(status_code=404, detail="Character not found") + + +@app.post("/character-chat", response_model=CharacterChatCreationResponse) +def character_chat(request: CharacterChatCreationRequest): + session_id = uuid.uuid4() + system_prompt = get_character_prompt(request.character, request.user_name, request.language) + CHAT_HISTORIES[session_id] = ChatHistory(session_id, system_prompt) + return HumanChatCreationResponse(session_id=session_id) + + +class CharacterChatMessageRequest(BaseModel): + msg: str + + +class CharacterChatMessageResponse(BaseModel): + msg: str + audio_id: UUID4 + + +PENDING_AUDIO = {} + + +@app.post("/character-chat/{session_id}", response_model=CharacterChatMessageResponse) +def character_chat_message(session_id: UUID4, request: CharacterChatMessageRequest): + if session_id not in CHAT_HISTORIES: + raise HTTPException(status_code=404, detail="Session ID not found") + history = CHAT_HISTORIES[session_id] + history.add_human_message(request.msg) + message = history.generate_and_record_message() + audio_id = uuid.uuid4() + PENDING_AUDIO[audio_id] = message.content + return CharacterChatMessageResponse(msg=message.content, audio_id=audio_id) + + +class RecognizeResponse(BaseModel): + text: str + + +pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v2", + torch_dtype=torch.float16, + device="mps", +) + + +@app.post("/recognize", response_model=RecognizeResponse) +def recognize(audio_file: UploadFile = File(...)): + with TemporaryDirectory() as tmpdir: + filename = tmpdir + "/audio.wav" + with open(filename, "wb") as f: + f.write(audio_file.file.read()) + outputs = pipe(filename, + chunk_length_s=30, + batch_size=1, + return_timestamps=True) + return RecognizeResponse(text=outputs["text"].strip()) + + +@app.get("/audio/{UUID}") +def get_audio(UUID: UUID4): + if UUID not in PENDING_AUDIO: + raise HTTPException(status_code=404, detail="Audio not found") + if isinstance(PENDING_AUDIO[UUID], str): + audio_text = PENDING_AUDIO[UUID] + del PENDING_AUDIO[UUID] + response = openai_client.audio.speech.create( + model="tts-1", + voice="alloy", + input=audio_text + ) + PENDING_AUDIO[UUID] = response + return StreamingResponse(PENDING_AUDIO[UUID].iter_bytes(chunk_size=1024), media_type="audio/mpeg") + + +@app.get("/") +def read_root(): + return FileResponse("test.html") + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000, ssl_keyfile="key.pem", ssl_certfile="cert.pem") diff --git a/backend/test.html b/backend/test.html new file mode 100644 index 0000000..e69de29