Make it useable (#296)
* add api/static/* to .gitignore
* add init_prompt to params so that we can set n_ctx to len(init_prompt) + params.n_ctx
* change get_prompt so it properly builds an instruction prompt from the recent history:
  - keeps the prompt below n_ctx
  - goes backwards and adds only the most recent sentences
  - always prepends the init_prompt
* ignore eventSource.onerror: I don't know why it fires, but it gets called locally for me and messes up the history; just doing nothing here works fine for me
* fix split on ! not *
* run black formatter on stream.py
* revert previous black format, just make the one change it wants manually
* third time's the charm

---------

Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
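For orientation, here is a minimal sketch of the windowing behaviour described above: walk the chat history from newest to oldest, keep only the most recent sentences that still fit under a character budget, and always prepend the init_prompt. The `Message` dataclass, the `budget` argument, and the simplified sentence splitter are illustrative stand-ins, not the project's actual types; the real implementation is the `get_prompt` change further down in this diff.

```python
# Minimal sketch of the windowing idea (illustrative only; see the real
# get_prompt diff below). Assumes a plain Message type and a character budget.
import re
from dataclasses import dataclass


@dataclass
class Message:
    role: str      # "human" or "ai"
    content: str


def build_prompt(init_prompt: str, history: list[Message], budget: int) -> str:
    kept: list[str] = []
    used = 0
    for message in reversed(history):                  # newest message first
        tag = "### Instruction: " if message.role == "human" else "### Response: "
        sentences = re.split(r"(?<=[.?!\n])\s+", message.content.strip())
        piece = ""
        truncated = False
        for sentence in reversed(sentences):           # newest sentence first
            if used + len(piece) + len(sentence) >= budget:
                truncated = True
                break
            piece = sentence + " " + piece
        if piece:
            entry = tag + piece.strip() + "\n"
            kept.append(entry)
            used += len(entry)
        if truncated:
            break
    kept.reverse()                                     # back to chronological order
    return init_prompt + "\n" + "".join(kept)
```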
parent 13daab6880
commit 6ae405179d
.gitignore (vendored)

@@ -6,4 +6,5 @@ api/weights/**
 !.env.sample
 node_modules
 __pycache__
 __pycache__
+api/static/*
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from datetime import datetime
 from uuid import uuid4
@@ -25,6 +27,7 @@ class ChatParameters(BaseModel):
     repeat_penalty: float
     top_k: int
     # stream: bool
+    init_prompt: Optional[str] = "Below is an instruction that describes a task. Write a response that appropriately completes the request."


 class Chat(BaseModel):
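As a quick illustration of how the new optional field behaves, a trimmed-down model (not the project's full ChatParameters, which declares many more fields) shows the default being used when init_prompt is omitted:

```python
# Trimmed-down illustration of the Optional-with-default pattern used for
# init_prompt; the real ChatParameters has many more fields.
from typing import Optional

from pydantic import BaseModel


class ParamsSketch(BaseModel):
    n_ctx: int
    init_prompt: Optional[str] = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )


print(ParamsSketch(n_ctx=512).init_prompt)                            # default instruction text
print(ParamsSketch(n_ctx=512, init_prompt="Be terse.").init_prompt)   # "Be terse."
```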
@@ -48,6 +48,7 @@ async def create_new_chat(
         last_n_tokens_size=repeat_last_n,
         repeat_penalty=repeat_penalty,
         n_threads=n_threads,
+        init_prompt=init_prompt,
     )
     # create the chat
     chat = Chat(params=params)
@@ -157,14 +158,14 @@ def stream_ask_a_question(chat_id: str, prompt: str):
     logger.debug(f"adding question {prompt}")
 
     history.add_user_message(prompt)
-    prompt = get_prompt(history)
+    prompt = get_prompt(history, chat.params)
     prompt += "### Response:\n"
 
     logger.debug("creating Llama client")
     try:
         client = Llama(
             model_path="/usr/src/app/weights/" + chat.params.model_path + ".bin",
-            n_ctx=chat.params.n_ctx,
+            n_ctx=len(chat.params.init_prompt) + chat.params.n_ctx,
             n_threads=chat.params.n_threads,
             last_n_tokens_size=chat.params.last_n_tokens_size,
         )
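One thing worth noting about the widened window: `len(chat.params.init_prompt)` counts characters, while `n_ctx` is the token budget handed to llama.cpp, so the sum simply leaves generous headroom for the prepended instruction text rather than measuring it exactly. A rough illustration (values are made up):

```python
# Rough illustration of the widened context window (values are made up).
# len() counts characters of the init_prompt, while n_ctx is in tokens, so
# the sum just guarantees headroom for the prepended instruction text.
init_prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
n_ctx = 512

effective_ctx = len(init_prompt) + n_ctx
print(len(init_prompt), effective_ctx)   # prompt length in characters, widened window
```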
@@ -222,13 +223,13 @@ async def ask_a_question(chat_id: str, prompt: str):
     history = RedisChatMessageHistory(chat.id)
     history.add_user_message(prompt)
 
-    prompt = get_prompt(history)
+    prompt = get_prompt(history, chat.params)
     prompt += "### Response:\n"
 
     try:
         client = Llama(
             model_path="/usr/src/app/weights/" + chat.params.model_path + ".bin",
-            n_ctx=chat.params.n_ctx,
+            n_ctx=len(chat.params.init_prompt) + chat.params.n_ctx,
             n_threads=chat.params.n_threads,
             last_n_tokens_size=chat.params.last_n_tokens_size,
         )
@@ -1,3 +1,5 @@
+import re
+
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -47,21 +49,60 @@ class ChainRedisHandler(StreamingStdOutCallbackHandler):
         """Run when LLM errors."""
 
 
-def get_prompt(history: RedisChatMessageHistory):
+def get_prompt(history: RedisChatMessageHistory, params):
     """
     Get the prompt for the LLM from the chat history.
     """
-    prompt = ""
-
-    for message in history.messages:
+    def tokenize_content(content):
+        split_content = list(filter(None, re.split("([^\\n\.\?!]+[\\n\.\?! ]+)", content)))
+        split_content.reverse()
+        return split_content
+
+    def sum_prompts_lengths(prompts):
+        prompt_length = 0
+        for s in prompts:
+            prompt_length += len(s)
+        return prompt_length
+
+    dupes = {}
+    prompts = []
+    messages = history.messages.copy()
+    messages.reverse()
+    for message in messages:
+        if message.content in dupes:
+            continue
+        dupes[message.content] = True
+
+        instruction = ""
         match message.type:
             case "human":
-                prompt += "### Instruction:\n" + message.content + "\n"
+                instruction = "### Instruction: "
             case "ai":
-                prompt += "### Response:\n" + message.content + "\n"
-            case "system":
-                prompt += "### System:\n" + message.content + "\n"
+                instruction = "### Response: "
+            # case "system":
+            #     instruction = "### System: "
             case _:
-                pass
+                continue
 
-    return prompt
+        stop = False
+        next_prompt = ""
+        tokens = tokenize_content(message.content)
+        prompt_length = sum_prompts_lengths(prompts)
+        for token in tokens:
+            if prompt_length + len(next_prompt) + len(token) < params.n_ctx:
+                next_prompt = token + next_prompt
+            else:
+                stop = True
+        if len(next_prompt) > 0:
+            prompts.append(instruction + next_prompt + "\n")
+        if stop:
+            break
+
+    message_prompt = ""
+    prompts.reverse()
+    for next_prompt in prompts:
+        message_prompt += next_prompt
+
+    final_prompt = params.init_prompt + "\n" + message_prompt[: params.n_ctx]
+    return final_prompt
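To make the splitting behaviour of tokenize_content concrete: re.split with a capturing group keeps each sentence together with its terminator, filter(None, ...) drops the empty strings between matches, and the reverse puts the newest sentence first. A small standalone check (the sample text is made up; the pattern is the same regex written as a raw string):

```python
# Standalone demo of the sentence-splitting regex used in tokenize_content.
import re

content = "Hello there. How are you today? I have one more question!"
pattern = r"([^\n\.\?!]+[\n\.\?! ]+)"   # equivalent to the pattern in the diff

parts = list(filter(None, re.split(pattern, content)))
print(parts)
# ['Hello there. ', 'How are you today? ', 'I have one more question!']

parts.reverse()   # get_prompt then walks these newest-sentence-first
print(parts[0])
# 'I have one more question!'
```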
@@ -65,9 +65,10 @@
     });
 
     eventSource.onerror = async (error) => {
       console.log("error", error);
       eventSource.close();
-      history[history.length - 1].data.content = "A server error occurred.";
-      await invalidate("/api/chat/" + $page.params.id);
+      //history[history.length - 1].data.content = "A server error occurred.";
+      //await invalidate("/api/chat/" + $page.params.id);
     };
   }
 }