diff --git a/.gitignore b/.gitignore
index 894f981..98688ce 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@
 api/weights/**
 !.env.sample
 node_modules
-__pycache__
\ No newline at end of file
+__pycache__
+api/static/*
diff --git a/api/src/serge/models/chat.py b/api/src/serge/models/chat.py
index a1978c6..09dfd7e 100644
--- a/api/src/serge/models/chat.py
+++ b/api/src/serge/models/chat.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from datetime import datetime
 from uuid import uuid4
 
@@ -25,6 +27,7 @@ class ChatParameters(BaseModel):
     repeat_penalty: float
     top_k: int
     # stream: bool
+    init_prompt: Optional[str] = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
 
 
 class Chat(BaseModel):
diff --git a/api/src/serge/routers/chat.py b/api/src/serge/routers/chat.py
index df8c18e..c91d903 100644
--- a/api/src/serge/routers/chat.py
+++ b/api/src/serge/routers/chat.py
@@ -48,6 +48,7 @@ async def create_new_chat(
         last_n_tokens_size=repeat_last_n,
         repeat_penalty=repeat_penalty,
         n_threads=n_threads,
+        init_prompt=init_prompt,
     )
     # create the chat
     chat = Chat(params=params)
@@ -157,14 +158,14 @@ def stream_ask_a_question(chat_id: str, prompt: str):
     logger.debug(f"adding question {prompt}")
     history.add_user_message(prompt)
 
-    prompt = get_prompt(history)
+    prompt = get_prompt(history, chat.params)
     prompt += "### Response:\n"
 
     logger.debug("creating Llama client")
     try:
         client = Llama(
             model_path="/usr/src/app/weights/" + chat.params.model_path + ".bin",
-            n_ctx=chat.params.n_ctx,
+            n_ctx=len(chat.params.init_prompt) + chat.params.n_ctx,
             n_threads=chat.params.n_threads,
             last_n_tokens_size=chat.params.last_n_tokens_size,
         )
@@ -222,13 +223,13 @@ async def ask_a_question(chat_id: str, prompt: str):
     history = RedisChatMessageHistory(chat.id)
     history.add_user_message(prompt)
 
-    prompt = get_prompt(history)
+    prompt = get_prompt(history, chat.params)
     prompt += "### Response:\n"
 
     try:
         client = Llama(
             model_path="/usr/src/app/weights/" + chat.params.model_path + ".bin",
-            n_ctx=chat.params.n_ctx,
+            n_ctx=len(chat.params.init_prompt) + chat.params.n_ctx,
             n_threads=chat.params.n_threads,
             last_n_tokens_size=chat.params.last_n_tokens_size,
         )
diff --git a/api/src/serge/utils/stream.py b/api/src/serge/utils/stream.py
index 85a0465..ee84204 100644
--- a/api/src/serge/utils/stream.py
+++ b/api/src/serge/utils/stream.py
@@ -1,3 +1,5 @@
+import re
+
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -47,21 +49,60 @@ class ChainRedisHandler(StreamingStdOutCallbackHandler):
         """Run when LLM errors."""
 
 
-def get_prompt(history: RedisChatMessageHistory):
+def get_prompt(history: RedisChatMessageHistory, params):
     """
     Get the prompt for the LLM from the chat history.
     """
-    prompt = ""
-    for message in history.messages:
+    def tokenize_content(content):
+        split_content = list(filter(None, re.split("([^\\n\.\?!]+[\\n\.\?! ]+)", content)))
+        split_content.reverse()
+        return split_content
+
+    def sum_prompts_lengths(prompts):
+        prompt_length = 0
+        for s in prompts:
+            prompt_length += len(s)
+        return prompt_length
+
+    dupes = {}
+    prompts = []
+    messages = history.messages.copy()
+    messages.reverse()
+    for message in messages:
+        if message.content in dupes:
+            continue
+        dupes[message.content] = True
+
+        instruction = ""
         match message.type:
             case "human":
-                prompt += "### Instruction:\n" + message.content + "\n"
+                instruction = "### Instruction: "
             case "ai":
-                prompt += "### Response:\n" + message.content + "\n"
-            case "system":
-                prompt += "### System:\n" + message.content + "\n"
+                instruction = "### Response: "
+            # case "system":
+            #     instruction = "### System: "
             case _:
-                pass
+                continue
 
-    return prompt
+        stop = False
+        next_prompt = ""
+        tokens = tokenize_content(message.content)
+        prompt_length = sum_prompts_lengths(prompts)
+        for token in tokens:
+            if prompt_length + len(next_prompt) + len(token) < params.n_ctx:
+                next_prompt = token + next_prompt
+            else:
+                stop = True
+        if len(next_prompt) > 0:
+            prompts.append(instruction + next_prompt + "\n")
+        if stop:
+            break
+
+    message_prompt = ""
+    prompts.reverse()
+    for next_prompt in prompts:
+        message_prompt += next_prompt
+
+    final_prompt = params.init_prompt + "\n" + message_prompt[: params.n_ctx]
+    return final_prompt
diff --git a/web/src/routes/chat/[id]/+page.svelte b/web/src/routes/chat/[id]/+page.svelte
index 2308c48..489c2dc 100644
--- a/web/src/routes/chat/[id]/+page.svelte
+++ b/web/src/routes/chat/[id]/+page.svelte
@@ -65,9 +65,10 @@
 		});
 
 		eventSource.onerror = async (error) => {
+			console.log("error", error);
 			eventSource.close();
-			history[history.length - 1].data.content = "A server error occurred.";
-			await invalidate("/api/chat/" + $page.params.id);
+			//history[history.length - 1].data.content = "A server error occurred.";
+			//await invalidate("/api/chat/" + $page.params.id);
 		};
 	}
 }