Compare commits

...

1 Commit

Author SHA1 Message Date
Juan Calderon-Perez
a6a69a17f7
Add support for dynamic threads 2023-09-19 23:06:35 -04:00
7 changed files with 5 additions and 43 deletions

View File

@@ -14,7 +14,6 @@ class ChatParameters(BaseModel):
# logits_all: bool
# vocab_only: bool
# use_mlock: bool
n_threads: int
# n_batch: int
last_n_tokens_size: int
max_tokens: int

View File

@@ -1,3 +1,5 @@
import os
from typing import Optional
from fastapi import APIRouter
from langchain.memory import RedisChatMessageHistory
@@ -28,7 +30,6 @@ async def create_new_chat(
repeat_last_n: int = 64,
repeat_penalty: float = 1.3,
init_prompt: str = "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
n_threads: int = 4,
):
try:
client = Llama(
@@ -51,7 +52,7 @@ async def create_new_chat(
n_gpu_layers=gpu_layers,
last_n_tokens_size=repeat_last_n,
repeat_penalty=repeat_penalty,
n_threads=n_threads,
n_threads=len(os.sched_getaffinity(0)),
init_prompt=init_prompt,
)
# create the chat

View File

@@ -119,7 +119,6 @@ class LlamaCpp(LLM):
"stop_sequences": self.stop_sequences,
"repeat_penalty": self.repeat_penalty,
"top_k": self.top_k,
"n_threads": self.n_threads,
"n_ctx": self.n_ctx,
"n_gpu_layers": self.n_gpu_layers,
"n_parts": self.n_parts,

View File

@@ -104,7 +104,7 @@
`/api/chat/?model=${dataCht.params.model_path}&temperature=${dataCht.params.temperature}&top_k=${dataCht.params.top_k}` +
`&top_p=${dataCht.params.top_p}&max_length=${dataCht.params.max_tokens}&context_window=${dataCht.params.n_ctx}` +
`&repeat_last_n=${dataCht.params.last_n_tokens_size}&repeat_penalty=${dataCht.params.repeat_penalty}` +
`&n_threads=${dataCht.params.n_threads}&init_prompt=${dataCht.history[0].data.content}` +
`&init_prompt=${dataCht.history[0].data.content}` +
`&gpu_layers=${dataCht.params.n_gpu_layers}`,
{

View File

@@ -23,7 +23,6 @@
let init_prompt =
"Below is an instruction that describes a task. Write a response that appropriately completes the request.";
let n_threads = 4;
let context_window = 2048;
let gpu_layers = 0;
@@ -226,20 +225,6 @@
{/each}
</select>
</div>
<div
class="tooltip flex flex-col"
data-tip="Number of threads to run LLaMA on."
>
<label for="n_threads" class="label-text pb-1">n_threads</label>
<input
class="input-bordered input w-full max-w-xs"
name="n_threads"
type="number"
bind:value={n_threads}
min="0"
max="64"
/>
</div>
<div
class="tooltip flex flex-col"
data-tip="The weight of the penalty to avoid repeating the last repeat_last_n tokens."

View File

@@ -125,7 +125,7 @@
`/api/chat/?model=${data.chat.params.model_path}&temperature=${data.chat.params.temperature}&top_k=${data.chat.params.top_k}` +
`&top_p=${data.chat.params.top_p}&max_length=${data.chat.params.max_tokens}&context_window=${data.chat.params.n_ctx}` +
`&repeat_last_n=${data.chat.params.last_n_tokens_size}&repeat_penalty=${data.chat.params.repeat_penalty}` +
`&n_threads=${data.chat.params.n_threads}&init_prompt=${data.chat.history[0].data.content}` +
`&init_prompt=${data.chat.history[0].data.content}` +
`&gpu_layers=${data.chat.params.n_gpu_layers}`,
{
@@ -337,27 +337,6 @@
{data.chat.params.n_ctx}/{data.chat.params.max_tokens}
</span>
</div>
{#if data.chat.params.n_threads > 0}
<div class="pl-4 hidden sm:flex flex-row items-center justify-center">
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="w-4 h-4"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M8.25 3v1.5M4.5 8.25H3m18 0h-1.5M4.5 12H3m18 0h-1.5m-15 3.75H3m18 0h-1.5M8.25 19.5V21M12 3v1.5m0 15V21m3.75-18v1.5m0 15V21m-9-1.5h10.5a2.25 2.25 0 002.25-2.25V6.75a2.25 2.25 0 00-2.25-2.25H6.75A2.25 2.25 0 004.5 6.75v10.5a2.25 2.25 0 002.25 2.25zm.75-12h9v9h-9v-9z"
/>
</svg>
<span class="ml-2 inline-block text-center text-sm font-semibold">
{data.chat.params.n_threads}
</span>
</div>
{/if}
{#if data.chat.params.n_gpu_layers > 0}
<div class="pl-4 hidden sm:flex flex-row items-center justify-center">
<svg

View File

@@ -15,7 +15,6 @@ interface Params {
model_path: string;
n_ctx: number;
n_gpu_layers: number;
n_threads: number;
last_n_tokens_size: number;
max_tokens: number;
temperature: number;