API Refactor & Model Manager (#101)

* API refactoring

* delete partially downloaded files on startup

* remove unused deps
Nathan Sarrazin 2023-03-28 23:56:41 +02:00 committed by GitHub
parent bad45112c2
commit b5c423fe59
28 changed files with 670 additions and 342 deletions
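The "delete partially downloaded files on startup" bullet refers to the new startup hook in api/src/serge/main.py below: an interrupted model download leaves a *.tmp file in the weights directory, and those are now swept away before the API starts serving. A shell equivalent of that cleanup, for illustration only (the actual implementation is the Python loop in the startup handler):

```bash
# Remove leftovers from interrupted model downloads (see main.py's startup hook).
rm -f /usr/src/app/weights/*.tmp
```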


@ -12,7 +12,7 @@ ENV TZ=Europe/Amsterdam
WORKDIR /usr/src/app
-COPY --chmod=0755 compile.sh .
+COPY --chmod=0755 scripts/compile.sh .
# Install MongoDB and necessary tools
RUN apt update && \
@ -23,11 +23,7 @@ RUN apt update && \
apt-get install -y mongodb-org && \
git clone https://github.com/ggerganov/llama.cpp.git --branch master-d5850c5
# copy & install python reqs
-COPY ./api/requirements.txt api/requirements.txt
-RUN pip install --upgrade pip && \
-pip install --no-cache-dir -r ./api/requirements.txt
+RUN pip install --upgrade pip
# Dev environment
FROM base as dev
@ -38,11 +34,7 @@ COPY --from=node_base /usr/local /usr/local
COPY ./web/package*.json ./
RUN npm ci
# Copy the rest of the project files
-COPY web /usr/src/app/web
-COPY ./api /usr/src/app/api
-COPY --chmod=0755 dev.sh /usr/src/app/dev.sh
+COPY --chmod=0755 scripts/dev.sh /usr/src/app/dev.sh
CMD ./dev.sh
# Build frontend
@ -64,6 +56,7 @@ WORKDIR /usr/src/app
COPY --from=frontend_builder /usr/src/app/web/build /usr/src/app/api/static/
COPY ./api /usr/src/app/api
-COPY --chmod=0755 deploy.sh /usr/src/app/deploy.sh
+RUN pip install ./api
+COPY --chmod=0755 scripts/deploy.sh /usr/src/app/deploy.sh
CMD ./deploy.sh


@ -22,7 +22,6 @@ cd serge
docker compose up --build -d
-docker compose exec serge python3 /usr/src/app/api/utils/download.py tokenizer 7B
```
Please note the storage space the models require: 7B needs 4.21 GB, 13B needs 8.14 GB, and 30B needs 20.3 GB.
#### Windows

api/.gitignore (new file, 160 lines)

@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

api/poetry.toml (new file, 2 lines)

@ -0,0 +1,2 @@
[virtualenvs]
create = false
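With `create = false`, Poetry installs dependencies into whatever Python environment is currently active instead of creating its own virtualenv, which is the behavior you want inside a container. A hypothetical invocation:

```bash
# Installs the serge package and its dependencies into the container's
# system interpreter; no .venv is created because of the setting above.
cd api && poetry install
```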

api/pyproject.toml (new file, 66 lines)

@ -0,0 +1,66 @@
[tool.poetry]
name = "serge"
description = "Serge API package"
version = "0.1.0"
license = "MIT"
authors = [
"Nathan Sarrazin <contact@nsarrazin.com>"
]
packages = [
{ include = "serge", from = "src" }
]
homepage = "https://serge.chat/"
repository = "https://github.com/nsarrazin/serge"
include = [{path="src"}]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.dependencies]
python=">=3.10,<4.0"
asyncio = "^3.4.3"
packaging = "^23.0"
pydantic = "^1.10.7"
pymongo = "^4.3.3"
python-dotenv = "^1.0.0"
python-multipart = "^0.0.6"
pyyaml = "^6.0"
rfc3986 = "^2.0.0"
sentencepiece = "^0.1.97"
sniffio = "^1.3.0"
sse-starlette = "^1.3.3"
starlette = "^0.26.1"
toml = "^0.10.2"
tqdm = "^4.65.0"
typing-extensions = "^4.5.0"
ujson = "^5.7.0"
urllib3 = "^1.26.15"
uvicorn = "^0.21.1"
uvloop = "^0.17.0"
watchfiles = "^0.19.0"
websockets = "^10.4"
anyio = "^3.6.2"
certifi = "^2022.12.7"
charset-normalizer = "^3.1.0"
click = "^8.1.3"
email-validator = "^1.3.1"
fastapi = "^0.95.0"
filelock = "^3.10.7"
h11 = "^0.14.0"
httpcore = "^0.17.0"
httptools = "^0.5.0"
huggingface-hub = "^0.13.3"
idna = "^3.4"
itsdangerous = "^2.1.2"
jinja2 = "^3.1.2"
markupsafe = "^2.1.2"
motor = "^3.1.1"
orjson = "^3.8.8"
beanie = "^1.17.0"
dnspython = "^2.3.0"
lazy-model = "^0.0.5"
requests = "^2.28.2"


@ -1,44 +0,0 @@
anyio==3.6.2
asyncio==3.4.3
beanie==1.17.0
certifi==2022.12.7
charset-normalizer==3.1.0
click==8.1.3
dnspython==2.3.0
email-validator==1.3.1
fastapi==0.95.0
filelock==3.10.2
h11==0.14.0
httpcore==0.16.3
httptools==0.5.0
httpx==0.23.3
huggingface-hub==0.13.3
idna==3.4
itsdangerous==2.1.2
Jinja2==3.1.2
lazy-model==0.0.5
MarkupSafe==2.1.2
motor==3.1.1
orjson==3.8.8
packaging==23.0
psutil==5.9.4
pydantic==1.10.7
pymongo==4.3.3
python-dotenv==1.0.0
python-multipart==0.0.6
PyYAML==6.0
requests==2.28.2
rfc3986==1.5.0
sentencepiece==0.1.97
sniffio==1.3.0
sse-starlette==1.3.3
starlette==0.26.1
toml==0.10.2
tqdm==4.65.0
typing_extensions==4.5.0
ujson==5.7.0
urllib3==1.26.15
uvicorn==0.21.1
uvloop==0.17.0
watchfiles==0.18.1
websockets==10.4


@ -0,0 +1,8 @@
from .utils.convert import convert_all
import anyio
async def convert_model_files():
await anyio.to_thread.run_sync(
convert_all, "/usr/src/app/weights/", "/usr/src/app/weights/tokenizer.model"
)

api/src/serge/main.py (new file, 105 lines)

@ -0,0 +1,105 @@
import asyncio
import logging
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
from serge.routers.chat import chat_router
from serge.routers.model import model_router
from serge.utils.initiate_database import initiate_database, Settings
from serge.dependencies import convert_model_files
# Configure logging settings
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s:\t%(name)s\t%(message)s",
handlers=[logging.StreamHandler()],
)
# Define a logger for the current module
logger = logging.getLogger(__name__)
settings = Settings()
tags_metadata = [
{
"name": "misc.",
"description": "Miscellaneous endpoints that don't fit anywhere else",
},
{
"name": "chats",
"description": "Used to manage chats",
},
]
description = """
Serge answers your questions poorly using LLaMa/alpaca. 🚀
"""
origins = [
"http://localhost",
"http://api:9124",
"http://localhost:9123",
"http://localhost:9124",
]
app = FastAPI(
title="Serge", version="0.0.1", description=description, tags_metadata=tags_metadata
)
api_app = FastAPI(title="Serge API")
api_app.include_router(chat_router)
api_app.include_router(model_router)
app.mount("/api", api_app)
# handle serving the frontend as static files in production
if settings.NODE_ENV == "production":
@app.middleware("http")
async def add_custom_header(request, call_next):
response = await call_next(request)
if response.status_code == 404:
return FileResponse("static/200.html")
return response
@app.exception_handler(404)
def not_found(request, exc):
return FileResponse("static/200.html")
async def homepage(request):
return FileResponse("static/200.html")
app.add_route("/", homepage)
app.mount("/", StaticFiles(directory="static"))
start_app = app
else:
start_app = api_app
@start_app.on_event("startup")
async def start_database():
WEIGHTS = "/usr/src/app/weights/"
files = os.listdir(WEIGHTS)
files = list(filter(lambda x: x.endswith(".tmp"), files))
for file in files:
os.remove(WEIGHTS + file)
logger.info("initializing database connection")
await initiate_database()
logger.info("initializing models")
asyncio.create_task(convert_model_files())
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


@ -0,0 +1 @@
from .chat import Chat, Question, ChatParameters


@ -4,7 +4,6 @@ from uuid import UUID, uuid4
from pydantic import Field
from datetime import datetime
from enum import Enum
class ChatParameters(Document):
model: str = Field(default="ggml-alpaca-7B-q4_0.bin")


@ -0,0 +1,2 @@
from .chat import chat_router
from .model import model_router


@ -1,189 +1,11 @@
import asyncio
import logging
import os
import psutil
from typing import Annotated
import anyio
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi import APIRouter, HTTPException, Depends
from sse_starlette.sse import EventSourceResponse
from beanie.odm.enums import SortDirection
from sse_starlette.sse import EventSourceResponse
from utils.initiate_database import initiate_database, Settings
from utils.generate import generate, get_full_prompt_from_chat
from utils.convert import convert_all
from models import Question, Chat, ChatParameters
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
# Configure logging settings
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s:\t%(name)s\t%(message)s",
handlers=[
logging.StreamHandler()
]
)
# Define a logger for the current module
logger = logging.getLogger(__name__)
settings = Settings()
tags_metadata = [
{
"name": "misc.",
"description": "Miscellaneous endpoints that don't fit anywhere else",
},
{
"name": "chats",
"description": "Used to manage chats",
},
]
description = """
Serge answers your questions poorly using LLaMa/alpaca. 🚀
"""
app = FastAPI(
title="Serge", version="0.0.1", description=description, tags_metadata=tags_metadata
)
api_app = FastAPI(title="Serge API")
app.mount('/api', api_app)
if settings.NODE_ENV == "production":
@app.middleware("http")
async def add_custom_header(request, call_next):
response = await call_next(request)
if response.status_code == 404:
return FileResponse('static/200.html')
return response
@app.exception_handler(404)
def not_found(request, exc):
return FileResponse('static/200.html')
async def homepage(request):
return FileResponse('static/200.html')
app.route('/', homepage)
app.mount('/', StaticFiles(directory='static'))
if settings.NODE_ENV == "development":
start_app = api_app
else:
start_app = app
origins = [
"http://localhost",
"http://api:9124",
"http://localhost:9123",
"http://localhost:9124",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
MODEL_IS_READY: bool = False
def dep_models_ready() -> list[str]:
"""
FastAPI dependency that checks if models are ready.
Returns a list of available models
"""
if MODEL_IS_READY is False:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={
"message": "models are not ready"
}
)
files = os.listdir("/usr/src/app/weights")
files = list(filter(lambda x: x.endswith(".bin"), files))
return files
async def convert_model_files():
global MODEL_IS_READY
await anyio.to_thread.run_sync(convert_all, "/usr/src/app/weights/", "/usr/src/app/weights/tokenizer.model")
MODEL_IS_READY = True
logger.info("models are ready")
@start_app.on_event("startup")
async def start_database():
logger.info("initializing database connection")
await initiate_database()
logger.info("initializing models")
asyncio.create_task(convert_model_files())
@api_app.get("/models", tags=["misc."])
def list_of_installed_models(
models: Annotated[list[str], Depends(dep_models_ready)]
):
return models
THREADS = len(psutil.Process().cpu_affinity())
@api_app.post("/chat", tags=["chats"])
async def create_new_chat(
model: str = "ggml-alpaca-7B-q4_0.bin",
temperature: float = 0.1,
top_k: int = 50,
top_p: float = 0.95,
max_length: int = 256,
context_window: int = 512,
repeat_last_n: int = 64,
repeat_penalty: float = 1.3,
init_prompt: str = "Below is an instruction that describes a task. Write a response that appropriately completes the request. The response must be accurate, concise and evidence-based whenever possible. A complete answer is always ended by [end of text].",
n_threads: int = THREADS / 2,
):
parameters = await ChatParameters(
model=model,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_length=max_length,
context_window=context_window,
repeat_last_n=repeat_last_n,
repeat_penalty=repeat_penalty,
init_prompt=init_prompt,
n_threads=n_threads,
).create()
chat = await Chat(parameters=parameters).create()
return chat.id
@api_app.get("/chat/{chat_id}", tags=["chats"])
async def get_specific_chat(chat_id: str):
chat = await Chat.get(chat_id)
await chat.fetch_all_links()
return chat
@api_app.delete("/chat/{chat_id}", tags=["chats"])
async def delete_chat(chat_id: str):
chat = await Chat.get(chat_id)
deleted_chat = await chat.delete()
if deleted_chat:
return {"message": f"Deleted chat with id: {chat_id}"}
else:
raise HTTPException(status_code=404, detail="No chat found with the given id.")
+from serge.models.chat import Question, Chat, ChatParameters
+from serge.utils.generate import generate, get_full_prompt_from_chat
async def on_close(chat, prompt, answer=None, error=None):
question = await Question(question=prompt.rstrip(),
@ -207,7 +29,84 @@ def remove_matching_end(a, b):
return b
@api_app.get("/chat/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
chat_router = APIRouter(
prefix="/chat",
tags=["chat"],
)
@chat_router.post("/")
async def create_new_chat(
model: str = "7B",
temperature: float = 0.1,
top_k: int = 50,
top_p: float = 0.95,
max_length: int = 256,
context_window: int = 512,
repeat_last_n: int = 64,
repeat_penalty: float = 1.3,
init_prompt: str = "Below is an instruction that describes a task. Write a response that appropriately completes the request. The response must be accurate, concise and evidence-based whenever possible. A complete answer is always ended by [end of text].",
n_threads: int = 4,
):
parameters = await ChatParameters(
model=model,
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_length=max_length,
context_window=context_window,
repeat_last_n=repeat_last_n,
repeat_penalty=repeat_penalty,
init_prompt=init_prompt,
n_threads=n_threads,
).create()
chat = await Chat(parameters=parameters).create()
return chat.id
@chat_router.get("/")
async def get_all_chats():
res = []
for i in (
await Chat.find_all().sort((Chat.created, SortDirection.DESCENDING)).to_list()
):
await i.fetch_link(Chat.parameters)
await i.fetch_link(Chat.questions)
first_q = i.questions[0].question if i.questions else ""
res.append(
{
"id": i.id,
"created": i.created,
"model": i.parameters.model,
"subtitle": first_q,
}
)
return res
@chat_router.get("/{chat_id}")
async def get_specific_chat(chat_id: str):
chat = await Chat.get(chat_id)
await chat.fetch_all_links()
return chat
@chat_router.delete("/{chat_id}" )
async def delete_chat(chat_id: str):
chat = await Chat.get(chat_id)
deleted_chat = await chat.delete()
if deleted_chat:
return {"message": f"Deleted chat with id: {chat_id}"}
else:
raise HTTPException(status_code=404, detail="No chat found with the given id.")
@chat_router.get("/{chat_id}/question")
async def stream_ask_a_question(chat_id: str, prompt: str):
chat = await Chat.get(chat_id)
@ -238,7 +137,6 @@ async def stream_ask_a_question(chat_id: str, prompt: str):
except Exception as e:
error = e.__str__()
logger.error(error)
-yield({"event" : "error"})
finally:
answer = "".join(chunks)[len(full_prompt)+1:]
@ -248,7 +146,7 @@ async def stream_ask_a_question(chat_id: str, prompt: str):
return EventSourceResponse(event_generator())
-@api_app.post("/chat/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
+@chat_router.post("/{chat_id}/question")
async def ask_a_question(chat_id: str, prompt: str):
chat = await Chat.get(chat_id)
await chat.fetch_link(Chat.parameters)
@ -270,26 +168,4 @@ async def ask_a_question(chat_id: str, prompt: str):
finally:
await on_close(chat, prompt, answer=answer[len(full_prompt)+1:], error=error)
return {"question" : prompt, "answer" : answer[len(full_prompt)+1:]}
@api_app.get("/chats", tags=["chats"])
async def get_all_chats():
res = []
for i in (
await Chat.find_all().sort((Chat.created, SortDirection.DESCENDING)).to_list()
):
await i.fetch_link(Chat.parameters)
await i.fetch_link(Chat.questions)
first_q = i.questions[0].question if i.questions else ""
res.append(
{
"id": i.id,
"created": i.created,
"model": i.parameters.model,
"subtitle": first_q,
}
)
return res
return {"question" : prompt, "answer" : answer[len(full_prompt)+1:]}


@ -0,0 +1,109 @@
from fastapi import APIRouter, HTTPException
from serge.utils.convert import convert_one_file
import huggingface_hub
import os
import urllib.request
model_router = APIRouter(
prefix="/model",
tags=["model"],
)
models_info = {
"7B": [
"nsarrazin/alpaca",
"alpaca-7B-ggml/ggml-model-q4_0.bin",
4.20E9,
],
"7B-native": [
"nsarrazin/alpaca",
"alpaca-native-7B-ggml/ggml-model-q4_0.bin",
4.20E9,
],
"13B": [
"nsarrazin/alpaca",
"alpaca-13B-ggml/ggml-model-q4_0.bin",
8.13E9,
],
"30B": [
"nsarrazin/alpaca",
"alpaca-30B-ggml/ggml-model-q4_0.bin",
20.2E9,
],
}
WEIGHTS = "/usr/src/app/weights/"
@model_router.get("/all")
async def list_of_all_models():
res = []
for model in models_info.keys():
progress = await download_status(model)
res.append({
"name": model,
"size": models_info[model][2],
"available": model+".bin" in await list_of_installed_models(),
"progress" : progress,
})
return res
@model_router.get("/downloadable")
async def list_of_downloadable_models():
files = os.listdir(WEIGHTS)
files = list(filter(lambda x: x.endswith(".bin"), files))
installed_models = [i.removesuffix(".bin") for i in files]
return list(filter(lambda x: x not in installed_models, models_info.keys()))
@model_router.get("/installed")
async def list_of_installed_models():
files = os.listdir(WEIGHTS)
files = list(filter(lambda x: x.endswith(".bin"), files))
return files
@model_router.post("/{model_name}/download")
def download_model(model_name: str):
models = list(models_info.keys())
if model_name not in models:
raise HTTPException(status_code=404, detail="Model not found")
if not os.path.exists(WEIGHTS + "tokenizer.model"):
print("Downloading tokenizer...")
url = huggingface_hub.hf_hub_url("nsarrazin/alpaca", "alpaca-7B-ggml/tokenizer.model", repo_type="model", revision="main")
urllib.request.urlretrieve(url, WEIGHTS + "tokenizer.model")
repo_id, filename, _ = models_info[model_name]
print(f"Downloading {model_name} model from {repo_id}...")
url = huggingface_hub.hf_hub_url(repo_id, filename, repo_type="model", revision="main")
urllib.request.urlretrieve(url, WEIGHTS + f"{model_name}.bin.tmp")
os.rename(WEIGHTS + f"{model_name}.bin.tmp", WEIGHTS + f"{model_name}.bin")
convert_one_file(WEIGHTS + f"{model_name}.bin", WEIGHTS + "tokenizer.model")
return {"message": f"Model {model_name} downloaded"}
@model_router.get("/{model_name}/download/status")
async def download_status(model_name: str):
models = list(models_info.keys())
if model_name not in models:
raise HTTPException(status_code=404, detail="Model not found")
filesize = models_info[model_name][2]
bin_path = WEIGHTS+f"{model_name}.bin.tmp"
if os.path.exists(bin_path):
currentsize = os.path.getsize(bin_path)
return min(round(currentsize / filesize*100, 1), 100)
return None
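And a matching sketch for the model manager (same host, port, and mount assumptions as the chat example above). The status endpoint simply compares the .tmp file against the expected size: with 2.1e9 of the 7B model's 4.20e9 bytes on disk, it reports 50.0.

```bash
# Everything the UI needs: name, size, availability, download progress.
curl "http://localhost:8008/api/model/all"

# Start a download; the handler fetches to 7B.bin.tmp, renames it to 7B.bin,
# then converts the weights to the current llama.cpp format.
curl -X POST "http://localhost:8008/api/model/7B/download"

# Poll progress (a percentage, or null when no download is in flight).
curl "http://localhost:8008/api/model/7B/download/status"
```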


@ -113,12 +113,11 @@ def convert_all(dir_model: str, tokenizer_model: str):
try:
tokenizer = SentencePieceProcessor(tokenizer_model)
-for file in files:
-convert_one_file(file, tokenizer)
except OSError:
print("Missing tokenizer, don't forget to download it!")
+for file in files:
+convert_one_file(file, tokenizer)
if __name__ == "__main__":
args = parse_args()


@ -1,5 +1,5 @@
import subprocess, os
-from models import Chat, ChatParameters
+from serge.models.chat import Chat, ChatParameters
import asyncio
import logging
@ -16,7 +16,7 @@ async def generate(
args = (
"llama",
"--model",
"/usr/src/app/weights/" + params.model,
"/usr/src/app/weights/" + params.model + ".bin",
"--prompt",
prompt,
"--n_predict",


@ -4,7 +4,7 @@ from beanie import init_beanie, Document
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseSettings
-from models import Question, Chat, ChatParameters
+from serge.models.chat import Question, Chat, ChatParameters
class Settings(BaseSettings):


@ -1,56 +0,0 @@
import argparse
import huggingface_hub
import os
from typing import List
from convert import convert_all
models_info = {
"7B": ["Pi3141/alpaca-7B-ggml", "ggml-model-q4_0.bin"],
"13B": ["Pi3141/alpaca-13B-ggml", "ggml-model-q4_0.bin"],
"30B": ["Pi3141/alpaca-30B-ggml", "ggml-model-q4_0.bin"],
"tokenizer": ["decapoda-research/llama-7b-hf", "tokenizer.model"],
}
def parse_args():
parser = argparse.ArgumentParser(
description="Download and convert LLaMA models to the current format"
)
parser.add_argument(
"model",
help="Model name",
nargs="+",
choices=["7B", "13B", "30B", "tokenizer"],
)
return parser.parse_args()
def download_models(models: List[str]):
for model in models:
repo_id, filename = models_info[model]
print(f"Downloading {model} model from {repo_id}...")
huggingface_hub.hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir="/usr/src/app/weights",
local_dir_use_symlinks=False,
cache_dir="/usr/src/app/weights/.cache",
)
if filename == "ggml-model-q4_0.bin":
os.rename(
"/usr/src/app/weights/ggml-model-q4_0.bin", f"/usr/src/app/weights/ggml-alpaca-{model}-q4_0.bin"
)
if __name__ == "__main__":
args = parse_args()
print("Downloading models from HuggingFace")
download_models(args.model)
print("Converting models to the current format")
convert_all("/usr/src/app/weights", "/usr/src/app/weights/tokenizer.model")


@ -4,7 +4,7 @@
mongod &
# Start the API
-cd api && uvicorn main:app --host 0.0.0.0 --port 8008 &
+cd api && uvicorn src.serge.main:app --host 0.0.0.0 --port 8008 &
# Wait for any process to exit
wait -n


@ -1,13 +1,15 @@
#!/bin/bash
./compile.sh
+pip install -e ./api
mongod &
# Start the web server
cd web && npm run dev -- --host 0.0.0.0 --port 8008 &
# Start the API
-cd api && uvicorn main:api_app --reload --host 0.0.0.0 --port 9124 --root-path /api/ &
+cd api && uvicorn src.serge.main:api_app --reload --host 0.0.0.0 --port 9124 --root-path /api/ &
# Wait for any process to exit
wait -n
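The new `pip install -e ./api` line is what makes the `serge.*` imports used throughout this diff resolve during development; the package is installed editable, so code changes are picked up by uvicorn's reloader. A quick, illustrative check:

```bash
python -c 'from serge.main import api_app; print(api_app.title)'  # prints: Serge API
```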


@ -12,7 +12,7 @@
if (response.status == 200) {
toggleDeleteConfirm();
await goto("/");
await invalidate("/api/chats");
await invalidate("/api/chat/");
} else {
console.error("Error " + response.status + ": " + response.statusText);
}
@ -69,7 +69,7 @@
<li>
<a
href={"/chat/" + chat.id}
class="flex items-center p-2 text-base font-normal rounded-lg hover:bg-gray-700"
class="flex items-center p-2 text-base font-normal rounded-lg hover:bg-gray-700 active:bg-gray-800"
>
<div class="flex flex-col">
<div>


@ -8,7 +8,7 @@ type t = {
};
export const load: LayoutLoad = async ({ fetch }) => {
const r = await fetch("/api/chats");
const r = await fetch("/api/chat/");
const chats = (await r.json()) as t[];
return {
chats,


@ -3,7 +3,12 @@
import { goto, invalidate } from "$app/navigation";
export let data: PageData;
-const modelAvailable = data.models.length > 0;
+const models = data.models.filter((el) => el.available);
+console.log(models);
+const modelAvailable = models.length > 0;
+const modelsLabels = models.map((el) => el.name);
let temp = 0.1;
let top_k = 50;
@ -30,7 +35,7 @@
]);
const searchParams = new URLSearchParams(convertedFormEntries);
const r = await fetch("/api/chat?" + searchParams.toString(), {
const r = await fetch("/api/chat/?" + searchParams.toString(), {
method: "POST",
});
@ -38,14 +43,14 @@
if (r.ok) {
const data = await r.json();
await goto("/chat/" + data);
await invalidate("/api/chats");
await invalidate("/api/chat/");
} else {
console.log(r.statusText);
}
}
</script>
-<h1 class="text-3xl font-bold text-center pt-5">Say Hi to Serge!</h1>
+<h1 class="text-3xl font-bold text-center pt-5">Say Hi to Serge 🦙</h1>
<h1 class="text-xl font-light text-center pt-2 pb-5">
An easy way to chat with Alpaca & other LLaMa based models.
</h1>
@ -53,8 +58,15 @@
<form on:submit|preventDefault={onCreateChat} id="form-create-chat" class="p-5">
<div class="w-full pb-20">
<div class="mx-auto w-fit pt-5">
<button class=" mx-auto btn btn-primary ml-5" disabled={!modelAvailable}
>Start a new chat</button
<button
type="submit"
class="btn btn-primary mx-5"
disabled={!modelAvailable}>Start a new chat</button
>
<button
on:click={() => goto("/models")}
type="button"
class="btn btn-outline mx-5">Download Models</button
>
</div>
</div>
@ -162,7 +174,7 @@
<div class="flex flex-col">
<label for="model" class="label-text pb-1"> Model choice </label>
<select name="model" class="select select-bordered w-full max-w-xs">
-{#each data.models as model}
+{#each modelsLabels as model}
<option value={model}>{model}</option>
{/each}
</select>


@ -1,8 +1,15 @@
import type { PageLoad } from "./$types";
+interface ModelStatus {
+name: string;
+size: number;
+available: boolean;
+progress?: number;
+}
export const load: PageLoad = async ({ fetch }) => {
const r = await fetch("api/models");
const models = (await r.json()) as string[];
const r = await fetch("/api/model/all");
const models = (await r.json()) as Array<ModelStatus>;
return {
models,
};
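For reference, a /api/model/all response matching this interface might look like the following; the values are illustrative, but the shape follows list_of_all_models in the model router above:

```ts
const example: ModelStatus[] = [
  { name: "7B", size: 4.2e9, available: false, progress: 37.5 },
  { name: "13B", size: 8.13e9, available: true },
];
```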


@ -0,0 +1,72 @@
<script lang="ts">
import { invalidate } from "$app/navigation";
import { each } from "svelte/internal";
import type { PageData } from "./$types";
export let data: PageData;
let downloading = false;
setInterval(async () => {
if (downloading) {
await invalidate("/api/model/all");
}
}, 2500);
async function onClick(model: string) {
if (downloading) {
return;
}
downloading = true;
const r = await fetch(`/api/model/${model}/download`, {
method: "POST",
});
if (r.ok) {
await invalidate("/api/model/all");
}
downloading = false;
}
</script>
<h1 class="text-3xl font-bold text-center pt-5">⚡ Download a model ⚡</h1>
<h1 class="text-xl font-light text-center pt-2 pb-5">
Make sure you have enough disk space and available RAM to run them.
</h1>
<div class="flex flex-col mx-auto mt-30">
<div class="max-w-4xl mx-auto w-full">
<div class="divider" />
{#each data.models as model}
<div class="flex flex-col content-around my-5">
<h2 class="text-3xl font-semibold mx-auto">
{model.name + " " + (model.available ? "☑️" : "")}
</h2>
<p class="text-xl font-light mx-auto pb-2">
({model.size / 1e9}GB)
</p>
{#if model.progress}
<div class="w-56 mx-auto my-5 justify-center">
<p class="w-full text-center font-light">{model.progress}%</p>
<progress
class="progress progress-primary h-5 w-56 mx-auto"
value={model.progress}
max="100"
/>
</div>
{/if}
<button
on:click={() => onClick(model.name)}
class="btn btn-primary mx-auto"
class:btn-outline={model.available}
disabled={model.available ||
(model.progress && model.progress > 0 ? true : false)}
>
Download
</button>
</div>
<div class="divider" />
{/each}
</div>
</div>


@ -0,0 +1,16 @@
import type { PageLoad } from "./$types";
interface ModelStatus {
name: string;
size: number;
available: boolean;
progress?: number;
}
export const load: PageLoad = async ({ fetch }) => {
const r = await fetch("/api/model/all");
const models = (await r.json()) as Array<ModelStatus>;
return {
models,
};
};