API Refactor & Model Manager (#101)
* API refactoring * delete partially downloaded files on startup * remove unused deps
This commit is contained in:
parent
bad45112c2
commit
b5c423fe59
17
Dockerfile
17
Dockerfile
@ -12,7 +12,7 @@ ENV TZ=Europe/Amsterdam
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
COPY --chmod=0755 compile.sh .
|
||||
COPY --chmod=0755 scripts/compile.sh .
|
||||
|
||||
# Install MongoDB and necessary tools
|
||||
RUN apt update && \
|
||||
@ -23,11 +23,7 @@ RUN apt update && \
|
||||
apt-get install -y mongodb-org && \
|
||||
git clone https://github.com/ggerganov/llama.cpp.git --branch master-d5850c5
|
||||
|
||||
|
||||
# copy & install python reqs
|
||||
COPY ./api/requirements.txt api/requirements.txt
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install --no-cache-dir -r ./api/requirements.txt
|
||||
RUN pip install --upgrade pip
|
||||
|
||||
# Dev environment
|
||||
FROM base as dev
|
||||
@ -38,11 +34,7 @@ COPY --from=node_base /usr/local /usr/local
|
||||
COPY ./web/package*.json ./
|
||||
RUN npm ci
|
||||
|
||||
# Copy the rest of the project files
|
||||
COPY web /usr/src/app/web
|
||||
COPY ./api /usr/src/app/api
|
||||
|
||||
COPY --chmod=0755 dev.sh /usr/src/app/dev.sh
|
||||
COPY --chmod=0755 scripts/dev.sh /usr/src/app/dev.sh
|
||||
CMD ./dev.sh
|
||||
|
||||
# Build frontend
|
||||
@ -64,6 +56,7 @@ WORKDIR /usr/src/app
|
||||
COPY --from=frontend_builder /usr/src/app/web/build /usr/src/app/api/static/
|
||||
COPY ./api /usr/src/app/api
|
||||
|
||||
COPY --chmod=0755 deploy.sh /usr/src/app/deploy.sh
|
||||
RUN pip install ./api
|
||||
|
||||
COPY --chmod=0755 scripts/deploy.sh /usr/src/app/deploy.sh
|
||||
CMD ./deploy.sh
|
||||
|
||||
@ -22,7 +22,6 @@ cd serge
|
||||
docker compose up --build -d
|
||||
docker compose exec serge python3 /usr/src/app/api/utils/download.py tokenizer 7B
|
||||
```
|
||||
Please note that the models occupy the following storage space: 7B requires 4.21 GB, 13B requires 8.14 GB, and 30B requires 20.3 GB.
|
||||
|
||||
#### Windows
|
||||
|
||||
|
||||
160
api/.gitignore
vendored
Normal file
160
api/.gitignore
vendored
Normal file
@ -0,0 +1,160 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
2
api/poetry.toml
Normal file
2
api/poetry.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[virtualenvs]
|
||||
create = false
|
||||
66
api/pyproject.toml
Normal file
66
api/pyproject.toml
Normal file
@ -0,0 +1,66 @@
|
||||
[tool.poetry]
|
||||
name = "serge"
|
||||
description = "Serge API package"
|
||||
version = "0.1.0"
|
||||
license = "MIT"
|
||||
authors = [
|
||||
"Nathan Sarrazin <contact@nsarrazin.com>"
|
||||
]
|
||||
|
||||
packages = [
|
||||
{ include = "serge", from = "src" }
|
||||
]
|
||||
|
||||
homepage = "https://serge.chat/"
|
||||
repository = "https://github.com/nsarrazin/serge"
|
||||
|
||||
include = [{path="src"}]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python=">=3.10,<4.0"
|
||||
asyncio = "^3.4.3"
|
||||
packaging = "^23.0"
|
||||
pydantic = "^1.10.7"
|
||||
pymongo = "^4.3.3"
|
||||
python-dotenv = "^1.0.0"
|
||||
python-multipart = "^0.0.6"
|
||||
pyyaml = "^6.0"
|
||||
rfc3986 = "^2.0.0"
|
||||
sentencepiece = "^0.1.97"
|
||||
sniffio = "^1.3.0"
|
||||
sse-starlette = "^1.3.3"
|
||||
starlette = "^0.26.1"
|
||||
toml = "^0.10.2"
|
||||
tqdm = "^4.65.0"
|
||||
typing-extensions = "^4.5.0"
|
||||
ujson = "^5.7.0"
|
||||
urllib3 = "^1.26.15"
|
||||
uvicorn = "^0.21.1"
|
||||
uvloop = "^0.17.0"
|
||||
watchfiles = "^0.19.0"
|
||||
websockets = "^10.4"
|
||||
anyio = "^3.6.2"
|
||||
certifi = "^2022.12.7"
|
||||
charset-normalizer = "^3.1.0"
|
||||
click = "^8.1.3"
|
||||
email-validator = "^1.3.1"
|
||||
fastapi = "^0.95.0"
|
||||
filelock = "^3.10.7"
|
||||
h11 = "^0.14.0"
|
||||
httpcore = "^0.17.0"
|
||||
httptools = "^0.5.0"
|
||||
huggingface-hub = "^0.13.3"
|
||||
idna = "^3.4"
|
||||
itsdangerous = "^2.1.2"
|
||||
jinja2 = "^3.1.2"
|
||||
markupsafe = "^2.1.2"
|
||||
motor = "^3.1.1"
|
||||
orjson = "^3.8.8"
|
||||
beanie = "^1.17.0"
|
||||
dnspython = "^2.3.0"
|
||||
lazy-model = "^0.0.5"
|
||||
requests = "^2.28.2"
|
||||
@ -1,44 +0,0 @@
|
||||
anyio==3.6.2
|
||||
asyncio==3.4.3
|
||||
beanie==1.17.0
|
||||
certifi==2022.12.7
|
||||
charset-normalizer==3.1.0
|
||||
click==8.1.3
|
||||
dnspython==2.3.0
|
||||
email-validator==1.3.1
|
||||
fastapi==0.95.0
|
||||
filelock==3.10.2
|
||||
h11==0.14.0
|
||||
httpcore==0.16.3
|
||||
httptools==0.5.0
|
||||
httpx==0.23.3
|
||||
huggingface-hub==0.13.3
|
||||
idna==3.4
|
||||
itsdangerous==2.1.2
|
||||
Jinja2==3.1.2
|
||||
lazy-model==0.0.5
|
||||
MarkupSafe==2.1.2
|
||||
motor==3.1.1
|
||||
orjson==3.8.8
|
||||
packaging==23.0
|
||||
psutil==5.9.4
|
||||
pydantic==1.10.7
|
||||
pymongo==4.3.3
|
||||
python-dotenv==1.0.0
|
||||
python-multipart==0.0.6
|
||||
PyYAML==6.0
|
||||
requests==2.28.2
|
||||
rfc3986==1.5.0
|
||||
sentencepiece==0.1.97
|
||||
sniffio==1.3.0
|
||||
sse-starlette==1.3.3
|
||||
starlette==0.26.1
|
||||
toml==0.10.2
|
||||
tqdm==4.65.0
|
||||
typing_extensions==4.5.0
|
||||
ujson==5.7.0
|
||||
urllib3==1.26.15
|
||||
uvicorn==0.21.1
|
||||
uvloop==0.17.0
|
||||
watchfiles==0.18.1
|
||||
websockets==10.4
|
||||
8
api/src/serge/dependencies.py
Normal file
8
api/src/serge/dependencies.py
Normal file
@ -0,0 +1,8 @@
|
||||
from .utils.convert import convert_all
|
||||
import anyio
|
||||
|
||||
|
||||
async def convert_model_files():
    """Run ``convert_all`` in a worker thread via ``anyio.to_thread.run_sync``.

    ``convert_all`` receives the weights directory and the tokenizer model
    path as its two positional arguments.
    """
    weights_dir = "/usr/src/app/weights/"
    tokenizer_path = "/usr/src/app/weights/tokenizer.model"
    await anyio.to_thread.run_sync(convert_all, weights_dir, tokenizer_path)
|
||||
105
api/src/serge/main.py
Normal file
105
api/src/serge/main.py
Normal file
@ -0,0 +1,105 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
from serge.routers.chat import chat_router
|
||||
from serge.routers.model import model_router
|
||||
from serge.utils.initiate_database import initiate_database, Settings
|
||||
from serge.dependencies import convert_model_files
|
||||
|
||||
# Configure logging settings
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(levelname)s:\t%(name)s\t%(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
|
||||
# Define a logger for the current module
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
tags_metadata = [
|
||||
{
|
||||
"name": "misc.",
|
||||
"description": "Miscellaneous endpoints that don't fit anywhere else",
|
||||
},
|
||||
{
|
||||
"name": "chats",
|
||||
"description": "Used to manage chats",
|
||||
},
|
||||
]
|
||||
|
||||
description = """
|
||||
Serge answers your questions poorly using LLaMa/alpaca. 🚀
|
||||
"""
|
||||
|
||||
origins = [
|
||||
"http://localhost",
|
||||
"http://api:9124",
|
||||
"http://localhost:9123",
|
||||
"http://localhost:9124",
|
||||
]
|
||||
|
||||
app = FastAPI(
|
||||
title="Serge", version="0.0.1", description=description, tags_metadata=tags_metadata
|
||||
)
|
||||
|
||||
api_app = FastAPI(title="Serge API")
|
||||
api_app.include_router(chat_router)
|
||||
api_app.include_router(model_router)
|
||||
app.mount("/api", api_app)
|
||||
|
||||
# handle serving the frontend as static files in production
|
||||
if settings.NODE_ENV == "production":
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_custom_header(request, call_next):
|
||||
response = await call_next(request)
|
||||
if response.status_code == 404:
|
||||
return FileResponse("static/200.html")
|
||||
return response
|
||||
|
||||
@app.exception_handler(404)
|
||||
def not_found(request, exc):
|
||||
return FileResponse("static/200.html")
|
||||
|
||||
async def homepage(request):
|
||||
return FileResponse("static/200.html")
|
||||
|
||||
app.route("/", homepage)
|
||||
app.mount("/", StaticFiles(directory="static"))
|
||||
|
||||
start_app = app
|
||||
else:
|
||||
start_app = api_app
|
||||
|
||||
|
||||
# Strong references to fire-and-forget tasks started during startup. The event
# loop only keeps weak references to tasks, so a task nobody else references
# can be garbage-collected before it completes.
_background_tasks = set()


@start_app.on_event("startup")
async def start_database():
    """Startup hook: purge partial downloads, connect to the database, convert models.

    Steps, in order:

    1. Remove every ``*.tmp`` file from the weights directory (the model
       router downloads to ``<name>.bin.tmp`` and renames on success, so any
       ``.tmp`` file left at startup is an interrupted download).
    2. Await ``initiate_database()``.
    3. Schedule ``convert_model_files()`` as a background task so startup does
       not wait for the conversion to finish.
    """
    WEIGHTS = "/usr/src/app/weights/"

    for file in os.listdir(WEIGHTS):
        if file.endswith(".tmp"):
            os.remove(os.path.join(WEIGHTS, file))

    logger.info("initializing database connection")
    await initiate_database()

    logger.info("initializing models")
    task = asyncio.create_task(convert_model_files())
    # Keep the task alive until it is done, then drop the reference.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
1
api/src/serge/models/__init__.py
Normal file
1
api/src/serge/models/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .chat import Chat, Question, ChatParameters
|
||||
@ -4,7 +4,6 @@ from uuid import UUID, uuid4
|
||||
from pydantic import Field
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
class ChatParameters(Document):
|
||||
model: str = Field(default="ggml-alpaca-7B-q4_0.bin")
|
||||
2
api/src/serge/routers/__init__.py
Normal file
2
api/src/serge/routers/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .chat import chat_router
|
||||
from .model import model_router
|
||||
@ -1,189 +1,11 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import psutil
|
||||
from typing import Annotated
|
||||
|
||||
import anyio
|
||||
from fastapi import FastAPI, HTTPException, status, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from beanie.odm.enums import SortDirection
|
||||
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from utils.initiate_database import initiate_database, Settings
|
||||
from utils.generate import generate, get_full_prompt_from_chat
|
||||
from utils.convert import convert_all
|
||||
from models import Question, Chat, ChatParameters
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
|
||||
# Configure logging settings
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(levelname)s:\t%(name)s\t%(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
# Define a logger for the current module
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
tags_metadata = [
|
||||
{
|
||||
"name": "misc.",
|
||||
"description": "Miscellaneous endpoints that don't fit anywhere else",
|
||||
},
|
||||
{
|
||||
"name": "chats",
|
||||
"description": "Used to manage chats",
|
||||
},
|
||||
]
|
||||
|
||||
description = """
|
||||
Serge answers your questions poorly using LLaMa/alpaca. 🚀
|
||||
"""
|
||||
|
||||
app = FastAPI(
|
||||
title="Serge", version="0.0.1", description=description, tags_metadata=tags_metadata
|
||||
)
|
||||
|
||||
api_app = FastAPI(title="Serge API")
|
||||
app.mount('/api', api_app)
|
||||
|
||||
if settings.NODE_ENV == "production":
|
||||
@app.middleware("http")
|
||||
async def add_custom_header(request, call_next):
|
||||
response = await call_next(request)
|
||||
if response.status_code == 404:
|
||||
return FileResponse('static/200.html')
|
||||
return response
|
||||
|
||||
@app.exception_handler(404)
|
||||
def not_found(request, exc):
|
||||
return FileResponse('static/200.html')
|
||||
|
||||
async def homepage(request):
|
||||
return FileResponse('static/200.html')
|
||||
|
||||
app.route('/', homepage)
|
||||
app.mount('/', StaticFiles(directory='static'))
|
||||
|
||||
if settings.NODE_ENV == "development":
|
||||
start_app = api_app
|
||||
else:
|
||||
start_app = app
|
||||
|
||||
origins = [
|
||||
"http://localhost",
|
||||
"http://api:9124",
|
||||
"http://localhost:9123",
|
||||
"http://localhost:9124",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
MODEL_IS_READY: bool = False
|
||||
|
||||
|
||||
def dep_models_ready() -> list[str]:
|
||||
"""
|
||||
FastAPI dependency that checks if models are ready.
|
||||
|
||||
Returns a list of available models
|
||||
"""
|
||||
if MODEL_IS_READY is False:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail={
|
||||
"message": "models are not ready"
|
||||
}
|
||||
)
|
||||
|
||||
files = os.listdir("/usr/src/app/weights")
|
||||
files = list(filter(lambda x: x.endswith(".bin"), files))
|
||||
return files
|
||||
|
||||
|
||||
async def convert_model_files():
|
||||
global MODEL_IS_READY
|
||||
await anyio.to_thread.run_sync(convert_all, "/usr/src/app/weights/", "/usr/src/app/weights/tokenizer.model")
|
||||
MODEL_IS_READY = True
|
||||
logger.info("models are ready")
|
||||
|
||||
|
||||
@start_app.on_event("startup")
|
||||
async def start_database():
|
||||
logger.info("initializing database connection")
|
||||
await initiate_database()
|
||||
|
||||
logger.info("initializing models")
|
||||
asyncio.create_task(convert_model_files())
|
||||
|
||||
|
||||
@api_app.get("/models", tags=["misc."])
|
||||
def list_of_installed_models(
|
||||
models: Annotated[list[str], Depends(dep_models_ready)]
|
||||
):
|
||||
return models
|
||||
|
||||
THREADS = len(psutil.Process().cpu_affinity())
|
||||
|
||||
@api_app.post("/chat", tags=["chats"])
|
||||
async def create_new_chat(
|
||||
model: str = "ggml-alpaca-7B-q4_0.bin",
|
||||
temperature: float = 0.1,
|
||||
top_k: int = 50,
|
||||
top_p: float = 0.95,
|
||||
max_length: int = 256,
|
||||
context_window: int = 512,
|
||||
repeat_last_n: int = 64,
|
||||
repeat_penalty: float = 1.3,
|
||||
init_prompt: str = "Below is an instruction that describes a task. Write a response that appropriately completes the request. The response must be accurate, concise and evidence-based whenever possible. A complete answer is always ended by [end of text].",
|
||||
n_threads: int = THREADS / 2,
|
||||
):
|
||||
parameters = await ChatParameters(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
max_length=max_length,
|
||||
context_window=context_window,
|
||||
repeat_last_n=repeat_last_n,
|
||||
repeat_penalty=repeat_penalty,
|
||||
init_prompt=init_prompt,
|
||||
n_threads=n_threads,
|
||||
).create()
|
||||
|
||||
chat = await Chat(parameters=parameters).create()
|
||||
return chat.id
|
||||
|
||||
|
||||
@api_app.get("/chat/{chat_id}", tags=["chats"])
|
||||
async def get_specific_chat(chat_id: str):
|
||||
chat = await Chat.get(chat_id)
|
||||
await chat.fetch_all_links()
|
||||
|
||||
return chat
|
||||
|
||||
@api_app.delete("/chat/{chat_id}", tags=["chats"])
|
||||
async def delete_chat(chat_id: str):
|
||||
chat = await Chat.get(chat_id)
|
||||
deleted_chat = await chat.delete()
|
||||
|
||||
if deleted_chat:
|
||||
return {"message": f"Deleted chat with id: {chat_id}"}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="No chat found with the given id.")
|
||||
from serge.models.chat import Question, Chat,ChatParameters
|
||||
from serge.utils.generate import generate, get_full_prompt_from_chat
|
||||
|
||||
async def on_close(chat, prompt, answer=None, error=None):
|
||||
question = await Question(question=prompt.rstrip(),
|
||||
@ -207,7 +29,84 @@ def remove_matching_end(a, b):
|
||||
|
||||
return b
|
||||
|
||||
@api_app.get("/chat/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
|
||||
chat_router = APIRouter(
|
||||
prefix="/chat",
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
@chat_router.post("/")
async def create_new_chat(
    model: str = "7B",
    temperature: float = 0.1,
    top_k: int = 50,
    top_p: float = 0.95,
    max_length: int = 256,
    context_window: int = 512,
    repeat_last_n: int = 64,
    repeat_penalty: float = 1.3,
    init_prompt: str = "Below is an instruction that describes a task. Write a response that appropriately completes the request. The response must be accurate, concise and evidence-based whenever possible. A complete answer is always ended by [end of text].",
    n_threads: int = 4,
):
    """Create a chat from the supplied generation parameters and return its id.

    A ``ChatParameters`` document is created first from the arguments, then a
    ``Chat`` document that references it. Only ``chat.id`` is returned.
    """
    draft_parameters = ChatParameters(
        model=model,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_length=max_length,
        context_window=context_window,
        repeat_last_n=repeat_last_n,
        repeat_penalty=repeat_penalty,
        init_prompt=init_prompt,
        n_threads=n_threads,
    )
    parameters = await draft_parameters.create()

    draft_chat = Chat(parameters=parameters)
    chat = await draft_chat.create()
    return chat.id
|
||||
|
||||
|
||||
@chat_router.get("/")
async def get_all_chats():
    """Return a summary of every chat, sorted by ``Chat.created`` descending.

    Each summary is a dict with the chat ``id``, its ``created`` value, the
    ``model`` from its linked parameters, and a ``subtitle`` holding the text
    of the first question (empty string when the chat has no questions).
    """
    newest_first = (
        await Chat.find_all().sort((Chat.created, SortDirection.DESCENDING)).to_list()
    )

    summaries = []
    for chat in newest_first:
        # Resolve the links before reading through them.
        await chat.fetch_link(Chat.parameters)
        await chat.fetch_link(Chat.questions)

        if chat.questions:
            subtitle = chat.questions[0].question
        else:
            subtitle = ""

        summaries.append(
            {
                "id": chat.id,
                "created": chat.created,
                "model": chat.parameters.model,
                "subtitle": subtitle,
            }
        )

    return summaries
|
||||
|
||||
|
||||
@chat_router.get("/{chat_id}")
async def get_specific_chat(chat_id: str):
    """Return the chat with the given id, with all of its links fetched.

    Raises:
        HTTPException: 404 when no chat exists for ``chat_id``. Without this
            guard a missing chat would surface as an ``AttributeError`` on
            ``None`` (HTTP 500).
    """
    chat = await Chat.get(chat_id)
    if chat is None:
        raise HTTPException(status_code=404, detail="No chat found with the given id.")

    await chat.fetch_all_links()
    return chat
|
||||
|
||||
@chat_router.delete("/{chat_id}")
async def delete_chat(chat_id: str):
    """Delete the chat with the given id and return a confirmation message.

    Raises:
        HTTPException: 404 when no chat exists for ``chat_id`` or when the
            delete call reports nothing deleted.
    """
    chat = await Chat.get(chat_id)
    # Chat.get yields None for an unknown id; calling .delete() on it would
    # raise AttributeError instead of reaching the 404 below.
    if chat is None:
        raise HTTPException(status_code=404, detail="No chat found with the given id.")

    deleted_chat = await chat.delete()
    if deleted_chat:
        return {"message": f"Deleted chat with id: {chat_id}"}

    raise HTTPException(status_code=404, detail="No chat found with the given id.")
|
||||
|
||||
|
||||
|
||||
@chat_router.get("/{chat_id}/question")
|
||||
async def stream_ask_a_question(chat_id: str, prompt: str):
|
||||
|
||||
chat = await Chat.get(chat_id)
|
||||
@ -238,7 +137,6 @@ async def stream_ask_a_question(chat_id: str, prompt: str):
|
||||
|
||||
except Exception as e:
|
||||
error = e.__str__()
|
||||
logger.error(error)
|
||||
yield({"event" : "error"})
|
||||
finally:
|
||||
answer = "".join(chunks)[len(full_prompt)+1:]
|
||||
@ -248,7 +146,7 @@ async def stream_ask_a_question(chat_id: str, prompt: str):
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
@api_app.post("/chat/{chat_id}/question", dependencies=[Depends(dep_models_ready)])
|
||||
@chat_router.post("/{chat_id}/question")
|
||||
async def ask_a_question(chat_id: str, prompt: str):
|
||||
chat = await Chat.get(chat_id)
|
||||
await chat.fetch_link(Chat.parameters)
|
||||
@ -270,26 +168,4 @@ async def ask_a_question(chat_id: str, prompt: str):
|
||||
finally:
|
||||
await on_close(chat, prompt, answer=answer[len(full_prompt)+1:], error=error)
|
||||
|
||||
return {"question" : prompt, "answer" : answer[len(full_prompt)+1:]}
|
||||
|
||||
@api_app.get("/chats", tags=["chats"])
|
||||
async def get_all_chats():
|
||||
res = []
|
||||
|
||||
for i in (
|
||||
await Chat.find_all().sort((Chat.created, SortDirection.DESCENDING)).to_list()
|
||||
):
|
||||
await i.fetch_link(Chat.parameters)
|
||||
await i.fetch_link(Chat.questions)
|
||||
|
||||
first_q = i.questions[0].question if i.questions else ""
|
||||
res.append(
|
||||
{
|
||||
"id": i.id,
|
||||
"created": i.created,
|
||||
"model": i.parameters.model,
|
||||
"subtitle": first_q,
|
||||
}
|
||||
)
|
||||
|
||||
return res
|
||||
return {"question" : prompt, "answer" : answer[len(full_prompt)+1:]}
|
||||
109
api/src/serge/routers/model.py
Normal file
109
api/src/serge/routers/model.py
Normal file
@ -0,0 +1,109 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from serge.utils.convert import convert_one_file
|
||||
import huggingface_hub
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
model_router = APIRouter(
|
||||
prefix="/model",
|
||||
tags=["model"],
|
||||
)
|
||||
|
||||
models_info = {
|
||||
"7B": [
|
||||
"nsarrazin/alpaca",
|
||||
"alpaca-7B-ggml/ggml-model-q4_0.bin",
|
||||
4.20E9,
|
||||
],
|
||||
"7B-native": [
|
||||
"nsarrazin/alpaca",
|
||||
"alpaca-native-7B-ggml/ggml-model-q4_0.bin",
|
||||
4.20E9,
|
||||
],
|
||||
"13B": [
|
||||
"nsarrazin/alpaca",
|
||||
"alpaca-13B-ggml/ggml-model-q4_0.bin",
|
||||
8.13E9,
|
||||
],
|
||||
"30B": [
|
||||
"nsarrazin/alpaca",
|
||||
"alpaca-30B-ggml/ggml-model-q4_0.bin",
|
||||
20.2E9,
|
||||
],
|
||||
}
|
||||
|
||||
WEIGHTS = "/usr/src/app/weights/"
|
||||
|
||||
@model_router.get("/all")
async def list_of_all_models():
    """Describe every model known to ``models_info``.

    Returns a list of dicts, one per model, with:

    - ``name``: the ``models_info`` key,
    - ``size``: the third element of the model's ``models_info`` entry,
    - ``available``: whether ``<name>.bin`` is among the installed files,
    - ``progress``: the value returned by ``download_status`` for the model.
    """
    # The installed-file listing does not depend on the model being inspected,
    # so read the directory once instead of once per model.
    installed = await list_of_installed_models()

    res = []
    for model, info in models_info.items():
        progress = await download_status(model)
        res.append(
            {
                "name": model,
                "size": info[2],
                "available": model + ".bin" in installed,
                "progress": progress,
            }
        )

    return res
|
||||
|
||||
@model_router.get("/downloadable")
async def list_of_downloadable_models():
    """Return the ``models_info`` keys that have no ``<name>.bin`` file in WEIGHTS."""
    # str.rstrip(".bin") strips any trailing run of the characters '.', 'b',
    # 'i', 'n' rather than the literal suffix, which would mangle names ending
    # in those letters; removesuffix drops exactly ".bin".
    installed_models = {
        name.removesuffix(".bin")
        for name in os.listdir(WEIGHTS)
        if name.endswith(".bin")
    }

    return [model for model in models_info if model not in installed_models]
|
||||
|
||||
@model_router.get("/installed")
async def list_of_installed_models():
    """Return the names of all entries in WEIGHTS that end with ``.bin``."""
    return [entry for entry in os.listdir(WEIGHTS) if entry.endswith(".bin")]
|
||||
|
||||
|
||||
@model_router.post("/{model_name}/download")
def download_model(model_name: str):
    """Download a model listed in ``models_info`` into WEIGHTS and convert it.

    The tokenizer is downloaded first when ``tokenizer.model`` is not present.
    Both files are fetched to a ``.tmp`` name and renamed only once the
    download has completed, so an interrupted transfer never leaves a file
    under its final name.

    Raises:
        HTTPException: 404 when ``model_name`` is not a key of ``models_info``.
    """
    if model_name not in models_info:
        raise HTTPException(status_code=404, detail="Model not found")

    tokenizer_path = WEIGHTS + "tokenizer.model"
    if not os.path.exists(tokenizer_path):
        print("Downloading tokenizer...")
        url = huggingface_hub.hf_hub_url(
            "nsarrazin/alpaca",
            "alpaca-7B-ggml/tokenizer.model",
            repo_type="model",
            revision="main",
        )
        # Same temp-then-rename scheme as the model file below, so a partial
        # tokenizer is never mistaken for a complete one by the exists() check.
        urllib.request.urlretrieve(url, tokenizer_path + ".tmp")
        os.rename(tokenizer_path + ".tmp", tokenizer_path)

    repo_id, filename, _ = models_info[model_name]
    model_path = WEIGHTS + f"{model_name}.bin"
    partial_path = model_path + ".tmp"

    print(f"Downloading {model_name} model from {repo_id}...")
    url = huggingface_hub.hf_hub_url(
        repo_id, filename, repo_type="model", revision="main"
    )
    urllib.request.urlretrieve(url, partial_path)
    os.rename(partial_path, model_path)

    # The path used to be written as "f{model_name}.bin" (the f prefix inside
    # the quotes), which pointed the converter at a literal, non-existent file.
    # NOTE(review): convert_all builds a SentencePieceProcessor and passes that
    # object to convert_one_file; here a tokenizer *path* is passed instead --
    # confirm convert_one_file accepts a path.
    convert_one_file(model_path, tokenizer_path)

    return {"message": f"Model {model_name} downloaded"}
|
||||
|
||||
|
||||
@model_router.get("/{model_name}/download/status")
async def download_status(model_name: str):
    """Report download progress for a model as a percentage.

    Returns the size of ``<model_name>.bin.tmp`` relative to the size recorded
    in ``models_info``, rounded to one decimal and capped at 100, or ``None``
    when no such temporary file exists.

    Raises:
        HTTPException: 404 when ``model_name`` is not a key of ``models_info``.
    """
    if model_name not in models_info:
        raise HTTPException(status_code=404, detail="Model not found")

    expected_bytes = models_info[model_name][2]
    partial_path = WEIGHTS + f"{model_name}.bin.tmp"

    if not os.path.exists(partial_path):
        return None

    downloaded_bytes = os.path.getsize(partial_path)
    percent = round(downloaded_bytes / expected_bytes * 100, 1)
    return min(percent, 100)
|
||||
0
api/src/serge/utils/__init__.py
Normal file
0
api/src/serge/utils/__init__.py
Normal file
@ -113,12 +113,11 @@ def convert_all(dir_model: str, tokenizer_model: str):
|
||||
|
||||
try:
|
||||
tokenizer = SentencePieceProcessor(tokenizer_model)
|
||||
for file in files:
|
||||
convert_one_file(file, tokenizer)
|
||||
except OSError:
|
||||
print("Missing tokenizer, don't forget to download it!")
|
||||
|
||||
for file in files:
|
||||
convert_one_file(file, tokenizer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
@ -1,5 +1,5 @@
|
||||
import subprocess, os
|
||||
from models import Chat, ChatParameters
|
||||
from serge.models.chat import Chat, ChatParameters
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
@ -16,7 +16,7 @@ async def generate(
|
||||
args = (
|
||||
"llama",
|
||||
"--model",
|
||||
"/usr/src/app/weights/" + params.model,
|
||||
"/usr/src/app/weights/" + params.model + ".bin",
|
||||
"--prompt",
|
||||
prompt,
|
||||
"--n_predict",
|
||||
@ -4,7 +4,7 @@ from beanie import init_beanie, Document
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pydantic import BaseSettings
|
||||
|
||||
from models import Question, Chat, ChatParameters
|
||||
from serge.models.chat import Question, Chat, ChatParameters
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@ -1,56 +0,0 @@
|
||||
import argparse
|
||||
import huggingface_hub
|
||||
import os
|
||||
|
||||
from typing import List
|
||||
from convert import convert_all
|
||||
|
||||
models_info = {
|
||||
"7B": ["Pi3141/alpaca-7B-ggml", "ggml-model-q4_0.bin"],
|
||||
"13B": ["Pi3141/alpaca-13B-ggml", "ggml-model-q4_0.bin"],
|
||||
"30B": ["Pi3141/alpaca-30B-ggml", "ggml-model-q4_0.bin"],
|
||||
"tokenizer": ["decapoda-research/llama-7b-hf", "tokenizer.model"],
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download and convert LLaMA models to the current format"
|
||||
)
|
||||
parser.add_argument(
|
||||
"model",
|
||||
help="Model name",
|
||||
nargs="+",
|
||||
choices=["7B", "13B", "30B", "tokenizer"],
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def download_models(models: List[str]):
|
||||
for model in models:
|
||||
repo_id, filename = models_info[model]
|
||||
print(f"Downloading {model} model from {repo_id}...")
|
||||
|
||||
huggingface_hub.hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
local_dir="/usr/src/app/weights",
|
||||
local_dir_use_symlinks=False,
|
||||
cache_dir="/usr/src/app/weights/.cache",
|
||||
)
|
||||
|
||||
if filename == "ggml-model-q4_0.bin":
|
||||
os.rename(
|
||||
"/usr/src/app/weights/ggml-model-q4_0.bin", f"/usr/src/app/weights/ggml-alpaca-{model}-q4_0.bin"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
print("Downloading models from HuggingFace")
|
||||
download_models(args.model)
|
||||
|
||||
print("Converting models to the current format")
|
||||
convert_all("/usr/src/app/weights", "/usr/src/app/weights/tokenizer.model")
|
||||
@ -4,7 +4,7 @@
|
||||
mongod &
|
||||
|
||||
# Start the API
|
||||
cd api && uvicorn main:app --host 0.0.0.0 --port 8008 &
|
||||
cd api && uvicorn src.serge.main:app --host 0.0.0.0 --port 8008 &
|
||||
|
||||
# Wait for any process to exit
|
||||
wait -n
|
||||
@ -1,13 +1,15 @@
|
||||
#!/bin/bash
|
||||
./compile.sh
|
||||
|
||||
pip install -e ./api
|
||||
|
||||
mongod &
|
||||
|
||||
# Start the web server
|
||||
cd web && npm run dev -- --host 0.0.0.0 --port 8008 &
|
||||
|
||||
# Start the API
|
||||
cd api && uvicorn main:api_app --reload --host 0.0.0.0 --port 9124 --root-path /api/ &
|
||||
cd api && uvicorn src.serge.main:api_app --reload --host 0.0.0.0 --port 9124 --root-path /api/ &
|
||||
|
||||
# Wait for any process to exit
|
||||
wait -n
|
||||
@ -12,7 +12,7 @@
|
||||
if (response.status == 200) {
|
||||
toggleDeleteConfirm();
|
||||
await goto("/");
|
||||
await invalidate("/api/chats");
|
||||
await invalidate("/api/chat/");
|
||||
} else {
|
||||
console.error("Error " + response.status + ": " + response.statusText);
|
||||
}
|
||||
@ -69,7 +69,7 @@
|
||||
<li>
|
||||
<a
|
||||
href={"/chat/" + chat.id}
|
||||
class="flex items-center p-2 text-base font-normal rounded-lg hover:bg-gray-700"
|
||||
class="flex items-center p-2 text-base font-normal rounded-lg hover:bg-gray-700 active:bg-gray-800"
|
||||
>
|
||||
<div class="flex flex-col">
|
||||
<div>
|
||||
|
||||
@ -8,7 +8,7 @@ type t = {
|
||||
};
|
||||
|
||||
export const load: LayoutLoad = async ({ fetch }) => {
|
||||
const r = await fetch("/api/chats");
|
||||
const r = await fetch("/api/chat/");
|
||||
const chats = (await r.json()) as t[];
|
||||
return {
|
||||
chats,
|
||||
|
||||
@ -3,7 +3,12 @@
|
||||
import { goto, invalidate } from "$app/navigation";
|
||||
export let data: PageData;
|
||||
|
||||
const modelAvailable = data.models.length > 0;
|
||||
const models = data.models.filter((el) => el.available);
|
||||
|
||||
console.log(models);
|
||||
|
||||
const modelAvailable = models.length > 0;
|
||||
const modelsLabels = models.map((el) => el.name);
|
||||
|
||||
let temp = 0.1;
|
||||
let top_k = 50;
|
||||
@ -30,7 +35,7 @@
|
||||
]);
|
||||
const searchParams = new URLSearchParams(convertedFormEntries);
|
||||
|
||||
const r = await fetch("/api/chat?" + searchParams.toString(), {
|
||||
const r = await fetch("/api/chat/?" + searchParams.toString(), {
|
||||
method: "POST",
|
||||
});
|
||||
|
||||
@ -38,14 +43,14 @@
|
||||
if (r.ok) {
|
||||
const data = await r.json();
|
||||
await goto("/chat/" + data);
|
||||
await invalidate("/api/chats");
|
||||
await invalidate("/api/chat/");
|
||||
} else {
|
||||
console.log(r.statusText);
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<h1 class="text-3xl font-bold text-center pt-5">Say Hi to Serge!</h1>
|
||||
<h1 class="text-3xl font-bold text-center pt-5">Say Hi to Serge 🦙</h1>
|
||||
<h1 class="text-xl font-light text-center pt-2 pb-5">
|
||||
An easy way to chat with Alpaca & other LLaMa based models.
|
||||
</h1>
|
||||
@ -53,8 +58,15 @@
|
||||
<form on:submit|preventDefault={onCreateChat} id="form-create-chat" class="p-5">
|
||||
<div class="w-full pb-20">
|
||||
<div class="mx-auto w-fit pt-5">
|
||||
<button class=" mx-auto btn btn-primary ml-5" disabled={!modelAvailable}
|
||||
>Start a new chat</button
|
||||
<button
|
||||
type="submit"
|
||||
class="btn btn-primary mx-5"
|
||||
disabled={!modelAvailable}>Start a new chat</button
|
||||
>
|
||||
<button
|
||||
on:click={() => goto("/models")}
|
||||
type="button"
|
||||
class="btn btn-outline mx-5">Download Models</button
|
||||
>
|
||||
</div>
|
||||
</div>
|
||||
@ -162,7 +174,7 @@
|
||||
<div class="flex flex-col">
|
||||
<label for="model" class="label-text pb-1"> Model choice </label>
|
||||
<select name="model" class="select select-bordered w-full max-w-xs">
|
||||
{#each data.models as model}
|
||||
{#each modelsLabels as model}
|
||||
<option value={model}>{model}</option>
|
||||
{/each}
|
||||
</select>
|
||||
|
||||
@ -1,8 +1,15 @@
|
||||
import type { PageLoad } from "./$types";
|
||||
|
||||
/** One entry of the array this page's `load` fetches from `/api/model/all`. */
interface ModelStatus {
  /** Model name; used as the label and value of the model `<select>` options. */
  name: string;

  /** Model size. NOTE(review): presumably a byte count; confirm against the API. */
  size: number;

  /** Only entries with this flag set are offered when creating a chat. */
  available: boolean;

  /** Optional. NOTE(review): presumably a download percentage; confirm against the API. */
  progress?: number;
}
|
||||
|
||||
export const load: PageLoad = async ({ fetch }) => {
|
||||
const r = await fetch("api/models");
|
||||
const models = (await r.json()) as string[];
|
||||
const r = await fetch("/api/model/all");
|
||||
const models = (await r.json()) as Array<ModelStatus>;
|
||||
return {
|
||||
models,
|
||||
};
|
||||
|
||||
72
web/src/routes/models/+page.svelte
Normal file
72
web/src/routes/models/+page.svelte
Normal file
@ -0,0 +1,72 @@
|
||||
<script lang="ts">
|
||||
import { invalidate } from "$app/navigation";
import { onMount } from "svelte";
import { each } from "svelte/internal";
import type { PageData } from "./$types";
|
||||
|
||||
export let data: PageData;
|
||||
|
||||
let downloading = false;
|
||||
|
||||
// Refresh the model list every 2.5 s while a download request is in flight,
// so the progress shown for each model stays current.
//
// The timer is created inside onMount rather than at the top level of the
// component script: onMount callbacks do not run during server-side
// rendering, and the function returned here is called when the component is
// destroyed, so the interval no longer outlives the page.
onMount(() => {
  const pollTimer = setInterval(() => {
    if (downloading) {
      // Fire-and-forget, as before: setInterval ignores the returned promise.
      void invalidate("/api/model/all");
    }
  }, 2500);

  return () => clearInterval(pollTimer);
});
|
||||
|
||||
/**
 * Ask the API to download `model`, then refresh the model list.
 *
 * Only one request may be in flight at a time: calls made while
 * `downloading` is set return immediately without doing anything.
 *
 * @param model Name of the model, interpolated into the download URL.
 */
async function onClick(model: string) {
  if (downloading) {
    return;
  }

  downloading = true;
  try {
    const r = await fetch(`/api/model/${model}/download`, {
      method: "POST",
    });

    if (r.ok) {
      await invalidate("/api/model/all");
    } else {
      console.error("Error " + r.status + ": " + r.statusText);
    }
  } finally {
    // Always release the guard. Without this, a rejected fetch/invalidate
    // would leave `downloading` stuck at true and block every later click.
    downloading = false;
  }
}
|
||||
</script>
|
||||
|
||||
<h1 class="text-3xl font-bold text-center pt-5">⚡ Download a model ⚡</h1>
|
||||
<h1 class="text-xl font-light text-center pt-2 pb-5">
|
||||
Make sure you have enough disk space and available RAM to run them.
|
||||
</h1>
|
||||
|
||||
<div class="flex flex-col mx-auto mt-30">
|
||||
<div class="max-w-4xl mx-auto w-full">
|
||||
<div class="divider" />
|
||||
<!-- One card per model: name (with a check mark once available), size,
     an optional progress bar, and the download button. -->
{#each data.models as model}
  <div class="flex flex-col content-around my-5">
    <h2 class="text-3xl font-semibold mx-auto">
      {model.name + " " + (model.available ? "☑️" : "")}
    </h2>
    <p class="text-xl font-light mx-auto pb-2">
      ({model.size / 1e9}GB)
    </p>
    {#if model.progress}
      <div class="w-56 mx-auto my-5 justify-center">
        <p class="w-full text-center font-light">{model.progress}%</p>
        <progress
          class="progress progress-primary h-5 w-56 mx-auto"
          value={model.progress}
          max="100"
        />
      </div>
    {/if}
    <!-- `class:btn-outline={...}` toggles the `btn-outline` class when the
         model is already available. The button is disabled for available
         models and while a download reports progress. -->
    <button
      on:click={() => onClick(model.name)}
      class="btn btn-primary mx-auto"
      class:btn-outline={model.available}
      disabled={model.available ||
        (model.progress && model.progress > 0 ? true : false)}
    >
      Download
    </button>
  </div>
  <div class="divider" />
{/each}
|
||||
</div>
|
||||
</div>
|
||||
16
web/src/routes/models/+page.ts
Normal file
16
web/src/routes/models/+page.ts
Normal file
@ -0,0 +1,16 @@
|
||||
import type { PageLoad } from "./$types";
|
||||
|
||||
/** One entry of the array the models page loads from `/api/model/all`. */
interface ModelStatus {
  /** Model name; shown as the card heading and used in the download URL. */
  name: string;

  /**
   * Model size. The page renders `size / 1e9` followed by "GB".
   * NOTE(review): presumably a byte count; confirm against the API.
   */
  size: number;

  /** When true the page shows a check mark and disables the download button. */
  available: boolean;

  /** Optional; rendered with a `%` suffix on a progress bar whose max is 100. */
  progress?: number;
}
|
||||
|
||||
/**
 * Load the status of every model for the models page.
 *
 * Requests `/api/model/all` with the `fetch` supplied by the load event and
 * returns the parsed list as `models`.
 *
 * @throws Error when the response status is not OK, or when the body is not
 *   a JSON array. Failing here keeps an error payload from reaching the
 *   page's `{#each data.models}` loop.
 */
export const load: PageLoad = async ({ fetch }) => {
  const r = await fetch("/api/model/all");
  if (!r.ok) {
    throw new Error("Error " + r.status + ": " + r.statusText);
  }

  // Parse as unknown and narrow, rather than asserting the shape blindly.
  const models: unknown = await r.json();
  if (!Array.isArray(models)) {
    throw new Error("Unexpected response from /api/model/all");
  }

  return {
    models: models as Array<ModelStatus>,
  };
};
|
||||
Loading…
x
Reference in New Issue
Block a user