Make it usable (#296)

* add api/static/* to .gitignore

* add init_prompt to params so that we can set n_ctx to len(init_prompt) + params.n_ctx

* change get_prompt so it properly builds an instruction prompt from the recent history

- keeps the prompt below n_ctx
- goes backwards and adds only the most recent sentences
- always prepends the init_prompt

* ignore eventSource.onerror

I don't know why, but this handler gets called locally for me and messes up the history; just doing nothing here works fine for me

* fix split on ! not *

* run black formatter on stream.py

* revert previous black format, just do the one change it wants manually

* third time's the charm

---------

Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
Andreas Raster 2023-06-03 18:47:27 +02:00 committed by GitHub
parent 13daab6880
commit 6ae405179d
5 changed files with 63 additions and 16 deletions

.gitignore
View File

@@ -6,4 +6,5 @@ api/weights/**
!.env.sample
node_modules
__pycache__
__pycache__
api/static/*

View File

@@ -1,3 +1,5 @@
from typing import Optional
from datetime import datetime
from uuid import uuid4
@@ -25,6 +27,7 @@ class ChatParameters(BaseModel):
repeat_penalty: float
top_k: int
# stream: bool
init_prompt: Optional[str] = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
class Chat(BaseModel):

View File

@@ -48,6 +48,7 @@ async def create_new_chat(
last_n_tokens_size=repeat_last_n,
repeat_penalty=repeat_penalty,
n_threads=n_threads,
init_prompt=init_prompt,
)
# create the chat
chat = Chat(params=params)
@@ -157,14 +158,14 @@ def stream_ask_a_question(chat_id: str, prompt: str):
logger.debug(f"adding question {prompt}")
history.add_user_message(prompt)
prompt = get_prompt(history)
prompt = get_prompt(history, chat.params)
prompt += "### Response:\n"
logger.debug("creating Llama client")
try:
client = Llama(
model_path="/usr/src/app/weights/" + chat.params.model_path + ".bin",
n_ctx=chat.params.n_ctx,
n_ctx=len(chat.params.init_prompt) + chat.params.n_ctx,
n_threads=chat.params.n_threads,
last_n_tokens_size=chat.params.last_n_tokens_size,
)
@@ -222,13 +223,13 @@ async def ask_a_question(chat_id: str, prompt: str):
history = RedisChatMessageHistory(chat.id)
history.add_user_message(prompt)
prompt = get_prompt(history)
prompt = get_prompt(history, chat.params)
prompt += "### Response:\n"
try:
client = Llama(
model_path="/usr/src/app/weights/" + chat.params.model_path + ".bin",
n_ctx=chat.params.n_ctx,
n_ctx=len(chat.params.init_prompt) + chat.params.n_ctx,
n_threads=chat.params.n_threads,
last_n_tokens_size=chat.params.last_n_tokens_size,
)
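
Both routes widen the llama.cpp context window by the character length of the seed prompt before constructing the client. A minimal sketch of that arithmetic, using a hypothetical value in place of chat.params.n_ctx:

# minimal sketch of the n_ctx calculation above; n_ctx = 2048 is a
# hypothetical stand-in for chat.params.n_ctx, not a value from this commit
init_prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
n_ctx = 2048
# len() counts the characters of the seed prompt, so the window grows
# by exactly the init_prompt's length before being passed as Llama(n_ctx=...)
effective_n_ctx = len(init_prompt) + n_ctx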

View File

@@ -1,3 +1,5 @@
import re
from typing import Any, Dict, List, Union
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -47,21 +49,60 @@ class ChainRedisHandler(StreamingStdOutCallbackHandler):
"""Run when LLM errors."""
def get_prompt(history: RedisChatMessageHistory):
def get_prompt(history: RedisChatMessageHistory, params):
"""
Get the prompt for the LLM from the chat history.
"""
prompt = ""
for message in history.messages:
def tokenize_content(content):
split_content = list(filter(None, re.split("([^\\n\.\?!]+[\\n\.\?! ]+)", content)))
split_content.reverse()
return split_content
def sum_prompts_lengths(prompts):
prompt_length = 0
for s in prompts:
prompt_length += len(s)
return prompt_length
dupes = {}
prompts = []
messages = history.messages.copy()
messages.reverse()
for message in messages:
if message.content in dupes:
continue
dupes[message.content] = True
instruction = ""
match message.type:
case "human":
prompt += "### Instruction:\n" + message.content + "\n"
instruction = "### Instruction: "
case "ai":
prompt += "### Response:\n" + message.content + "\n"
case "system":
prompt += "### System:\n" + message.content + "\n"
instruction = "### Response: "
# case "system":
# instruction = "### System: "
case _:
pass
continue
return prompt
stop = False
next_prompt = ""
tokens = tokenize_content(message.content)
prompt_length = sum_prompts_lengths(prompts)
for token in tokens:
if prompt_length + len(next_prompt) + len(token) < params.n_ctx:
next_prompt = token + next_prompt
else:
stop = True
if len(next_prompt) > 0:
prompts.append(instruction + next_prompt + "\n")
if stop:
break
message_prompt = ""
prompts.reverse()
for next_prompt in prompts:
message_prompt += next_prompt
final_prompt = params.init_prompt + "\n" + message_prompt[: params.n_ctx]
return final_prompt
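
Taken together, the rewritten get_prompt walks the history newest-first, skips duplicate messages, accumulates whole sentences until params.n_ctx characters would be exceeded, then reverses everything back into chronological order and prepends init_prompt. A rough usage sketch, assuming the get_prompt above is in scope; FakeParams, FakeMessage, and FakeHistory are hypothetical stand-ins for chat.params and the Redis-backed history that the real routes pass in:

# rough sketch only; these classes are stand-ins for chat.params and
# RedisChatMessageHistory, which the routes above provide in practice
class FakeParams:
    n_ctx = 2048
    init_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )

class FakeMessage:
    def __init__(self, type, content):
        self.type = type          # "human" or "ai", mirroring langchain message types
        self.content = content

class FakeHistory:
    messages = [
        FakeMessage("human", "What is the capital of France?"),
        FakeMessage("ai", "The capital of France is Paris."),
        FakeMessage("human", "And of Germany?"),
    ]

prompt = get_prompt(FakeHistory(), FakeParams())
prompt += "### Response:\n"
# prompt now begins with the init_prompt, followed by the surviving turns as
# "### Instruction: ..." / "### Response: ..." lines in chronological order,
# ending with the "### Response:\n" cue the routes append before calling Llama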

View File

@@ -65,9 +65,10 @@
});
eventSource.onerror = async (error) => {
console.log("error", error);
eventSource.close();
history[history.length - 1].data.content = "A server error occurred.";
await invalidate("/api/chat/" + $page.params.id);
//history[history.length - 1].data.content = "A server error occurred.";
//await invalidate("/api/chat/" + $page.params.id);
};
}
}