Merge branch 'main' into v0.2-wip
commit 0538ba1dab

7 changed files with 295 additions and 69 deletions
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

+## [0.1.74]
+
+### Added
+
+- (server) OpenAI style error responses

## [0.1.73]

### Added
@@ -47,10 +47,10 @@ Otherwise, while installing it will build the llama.cpp x86 version which will b

`llama.cpp` supports multiple BLAS backends for faster processing.
Use the `FORCE_CMAKE=1` environment variable to force the use of `cmake` and install the pip package for the desired BLAS backend.

-To install with OpenBLAS, set the `LLAMA_OPENBLAS=1` environment variable before installing:
+To install with OpenBLAS, set the `LLAMA_BLAS` and `LLAMA_BLAS_VENDOR` environment variables before installing:

```bash
-CMAKE_ARGS="-DLLAMA_OPENBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python
+CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" FORCE_CMAKE=1 pip install llama-cpp-python
```

To install with cuBLAS, set the `LLAMA_CUBLAS=1` environment variable before installing:
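The cuBLAS section referenced above follows the same pattern; a minimal sketch of the equivalent CMake-flag invocation (shown for illustration, using the flag name llama.cpp exposed at the time):

```bash
# Illustrative cuBLAS build, same CMAKE_ARGS pattern as the OpenBLAS command above
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python
```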
@@ -850,7 +850,7 @@ class Llama:

        if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
            raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
            )

        if max_tokens <= 0:
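The new message embeds both the requested token count and the context size, which is what the server-side `RouteErrorHandler` added later in this diff parses back out. A small illustrative sketch of that round trip (the regex is the one registered in `pattern_and_formatters`; the sample message is hypothetical):

```python
import re

# Hypothetical message in the new format raised by Llama above.
error_message = "Requested tokens (3000) exceed context window of 2048"

# Same pattern the server registers for context_length_exceeded handling.
pattern = re.compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")

match = pattern.search(error_message)
if match is not None:
    prompt_tokens = int(match.group(1))   # 3000
    context_window = int(match.group(2))  # 2048
    print(prompt_tokens, context_window)
```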
@@ -958,7 +958,7 @@ class Llama:
                token_end_position += len(self.detokenize([token]))
                # Check if stop sequence is in the token
                if token_end_position >= (
-                    remaining_length - first_stop_position - 1
+                    remaining_length - first_stop_position
                ):
                    break
                logprobs_or_none: Optional[CompletionLogprobs] = None
@@ -175,6 +175,7 @@ llama_progress_callback = ctypes.CFUNCTYPE(None, c_float, c_void_p)
# // context pointer passed to the progress callback
# void * progress_callback_user_data;


# // Keep the booleans together to avoid misalignment during copy-by-value.
+# bool low_vram; // if true, reduce VRAM usage at the cost of performance
# bool f16_kv; // use fp16 for KV cache
@@ -292,6 +293,15 @@ class llama_timings(Structure):
    ]


+# LLAMA_API int llama_max_devices();
+def llama_max_devices() -> int:
+    return _lib.llama_max_devices()
+
+
+_lib.llama_max_devices.argtypes = []
+_lib.llama_max_devices.restype = c_int


# LLAMA_API struct llama_context_params llama_context_default_params();
def llama_context_default_params() -> llama_context_params:
    return _lib.llama_context_default_params()
@@ -748,7 +758,12 @@ def llama_get_vocab(
    return _lib.llama_get_vocab(ctx, strings, scores, capacity)


-_lib.llama_get_vocab.argtypes = [llama_context_p, c_char_p, c_float, c_int]
+_lib.llama_get_vocab.argtypes = [
+    llama_context_p,
+    POINTER(c_char_p),
+    POINTER(c_float),
+    c_int,
+]
_lib.llama_get_vocab.restype = c_int
@@ -766,6 +781,15 @@ def llama_get_vocab_from_model(
    return _lib.llama_get_vocab_from_model(model, strings, scores, capacity)


+_lib.llama_get_vocab_from_model.argtypes = [
+    llama_model_p,
+    POINTER(c_char_p),
+    POINTER(c_float),
+    c_int,
+]
+_lib.llama_get_vocab_from_model.restype = c_int


# Token logits obtained from the last call to llama_eval()
# The logits for the last token are stored in the last row
# Can be mutated in order to change the probabilities of the next token
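These bindings all follow the same ctypes recipe: declare `argtypes` and `restype` on the foreign function, then expose a thin Python wrapper. A self-contained sketch of that recipe against libc's `strlen` (chosen only because it loads without building llama.cpp; nothing below is part of this package, and `find_library` behaves differently per platform):

```python
import ctypes
from ctypes.util import find_library

# Load the C standard library (platform-dependent; illustration only).
libc = ctypes.CDLL(find_library("c"))

# Declare the signature explicitly, mirroring the _lib.* declarations above.
libc.strlen.argtypes = [ctypes.c_char_p]
libc.strlen.restype = ctypes.c_size_t


def strlen(s: bytes) -> int:
    """Thin wrapper in the style of llama_max_devices() above."""
    return libc.strlen(s)


print(strlen(b"hello"))  # 5
```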
@@ -1,8 +1,9 @@
import json
import multiprocessing
+from re import compile, Match, Pattern
from threading import Lock
from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import Callable, Coroutine, Iterator, List, Optional, Tuple, Union, Dict
from typing_extensions import TypedDict, Literal

import llama_cpp
@@ -10,8 +11,10 @@ import llama_cpp
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request
+from fastapi import Depends, FastAPI, APIRouter, Request, Response
from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.routing import APIRoute
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from sse_starlette.sse import EventSourceResponse
@@ -99,7 +102,190 @@ class Settings(BaseSettings):
    )


-router = APIRouter()
+class ErrorResponse(TypedDict):
+    """OpenAI style error response"""
+
+    message: str
+    type: str
+    param: Optional[str]
+    code: Optional[str]
+
+
+class ErrorResponseFormatters:
+    """Collection of formatters for error responses.
+
+    Args:
+        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+            Request body
+        match (Match[str]): Match object from regex pattern
+
+    Returns:
+        Tuple[int, ErrorResponse]: Status code and error response
+    """
+
+    @staticmethod
+    def context_length_exceeded(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match,  # type: Match[str]  # type: ignore
+    ) -> Tuple[int, ErrorResponse]:
+        """Formatter for context length exceeded error"""
+
+        context_window = int(match.group(2))
+        prompt_tokens = int(match.group(1))
+        completion_tokens = request.max_tokens
+        if hasattr(request, "messages"):
+            # Chat completion
+            message = (
+                "This model's maximum context length is {} tokens. "
+                "However, you requested {} tokens "
+                "({} in the messages, {} in the completion). "
+                "Please reduce the length of the messages or completion."
+            )
+        else:
+            # Text completion
+            message = (
+                "This model's maximum context length is {} tokens, "
+                "however you requested {} tokens "
+                "({} in your prompt; {} for the completion). "
+                "Please reduce your prompt; or completion length."
+            )
+        return 400, ErrorResponse(
+            message=message.format(
+                context_window,
+                completion_tokens + prompt_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ),
+            type="invalid_request_error",
+            param="messages",
+            code="context_length_exceeded",
+        )
+
+    @staticmethod
+    def model_not_found(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match  # type: Match[str]  # type: ignore
+    ) -> Tuple[int, ErrorResponse]:
+        """Formatter for model_not_found error"""
+
+        model_path = str(match.group(1))
+        message = f"The model `{model_path}` does not exist"
+        return 400, ErrorResponse(
+            message=message,
+            type="invalid_request_error",
+            param=None,
+            code="model_not_found",
+        )
+
+
+class RouteErrorHandler(APIRoute):
+    """Custom APIRoute that handles application errors and exceptions"""
+
+    # key: regex pattern for original error message from llama_cpp
+    # value: formatter function
+    pattern_and_formatters: Dict[
+        "Pattern",
+        Callable[
+            [
+                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                "Match[str]",
+            ],
+            Tuple[int, ErrorResponse],
+        ],
+    ] = {
+        compile(
+            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+        ): ErrorResponseFormatters.context_length_exceeded,
+        compile(
+            r"Model path does not exist: (.+)"
+        ): ErrorResponseFormatters.model_not_found,
+    }
+
+    def error_message_wrapper(
+        self,
+        error: Exception,
+        body: Optional[
+            Union[
+                "CreateChatCompletionRequest",
+                "CreateCompletionRequest",
+                "CreateEmbeddingRequest",
+            ]
+        ] = None,
+    ) -> Tuple[int, ErrorResponse]:
+        """Wraps error message in OpenAI style error response"""
+
+        if body is not None and isinstance(
+            body,
+            (
+                CreateCompletionRequest,
+                CreateChatCompletionRequest,
+            ),
+        ):
+            # When text completion or chat completion
+            for pattern, callback in self.pattern_and_formatters.items():
+                match = pattern.search(str(error))
+                if match is not None:
+                    return callback(body, match)
+
+        # Wrap other errors as internal server error
+        return 500, ErrorResponse(
+            message=str(error),
+            type="internal_server_error",
+            param=None,
+            code=None,
+        )
+
+    def get_route_handler(
+        self,
+    ) -> Callable[[Request], Coroutine[None, None, Response]]:
+        """Defines custom route handler that catches exceptions and formats
+        in OpenAI style error response"""
+
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except Exception as exc:
+                json_body = await request.json()
+                try:
+                    if "messages" in json_body:
+                        # Chat completion
+                        body: Optional[
+                            Union[
+                                CreateChatCompletionRequest,
+                                CreateCompletionRequest,
+                                CreateEmbeddingRequest,
+                            ]
+                        ] = CreateChatCompletionRequest(**json_body)
+                    elif "prompt" in json_body:
+                        # Text completion
+                        body = CreateCompletionRequest(**json_body)
+                    else:
+                        # Embedding
+                        body = CreateEmbeddingRequest(**json_body)
+                except Exception:
+                    # Invalid request body
+                    body = None
+
+                # Get proper error message from the exception
+                (
+                    status_code,
+                    error_message,
+                ) = self.error_message_wrapper(error=exc, body=body)
+                return JSONResponse(
+                    {"error": error_message},
+                    status_code=status_code,
+                )
+
+        return custom_route_handler
+
+
+router = APIRouter(route_class=RouteErrorHandler)

settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
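With `RouteErrorHandler` set as the router's route class, a failing completion request is returned as an OpenAI-style error body instead of a bare 500. A sketch of the JSON a client would receive when the context-window check fires for a plain text completion (the numbers are illustrative; the shape follows the code above):

```python
# Illustrative only: HTTP 400 body produced by RouteErrorHandler for
# a context_length_exceeded failure on a text completion request.
expected_body = {
    "error": {
        "message": (
            "This model's maximum context length is 2048 tokens, "
            "however you requested 3016 tokens "
            "(3000 in your prompt; 16 for the completion). "
            "Please reduce your prompt; or completion length."
        ),
        "type": "invalid_request_error",
        "param": "messages",
        "code": "context_length_exceeded",
    }
}
```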
@@ -188,12 +374,33 @@ def get_settings():
    yield settings


-model_field = Field(
-    description="The model to use for generating completions.", default=None
-)
+async def get_event_publisher(
+    request: Request,
+    inner_send_chan: MemoryObjectSendStream,
+    iterator: Iterator,
+):
+    async with inner_send_chan:
+        try:
+            async for chunk in iterate_in_threadpool(iterator):
+                await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                if await request.is_disconnected():
+                    raise anyio.get_cancelled_exc_class()()
+                if settings.interrupt_requests and llama_outer_lock.locked():
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                    raise anyio.get_cancelled_exc_class()()
+            await inner_send_chan.send(dict(data="[DONE]"))
+        except anyio.get_cancelled_exc_class() as e:
+            print("disconnected")
+            with anyio.move_on_after(1, shield=True):
+                print(
+                    f"Disconnected from client (via refresh/close) {request.client}"
+                )
+            raise e
+
+
+model_field = Field(description="The model to use for generating completions.", default=None)

max_tokens_field = Field(
-    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+    default=16, ge=1, description="The maximum number of tokens to generate."
)

temperature_field = Field(
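`get_event_publisher` now owns the SSE plumbing for both endpoints: each chunk goes out as a `data:` event and the stream is closed with a `[DONE]` sentinel, matching the OpenAI streaming convention. A hypothetical client-side sketch that consumes such a stream (it assumes the server is running on its default port and uses the third-party `requests` package, which is not a dependency of this project):

```python
import json

import requests  # assumption: any streaming-capable HTTP client would do

url = "http://localhost:8000/v1/completions"  # hypothetical local server
payload = {"prompt": "Hello", "max_tokens": 16, "stream": True}

with requests.post(url, json=payload, stream=True) as response:
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # SSE events are separated by blank lines
        line = raw_line.decode("utf-8")
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # sentinel sent by get_event_publisher
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)
```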
@@ -383,35 +590,31 @@ async def create_completion(
        ]
    )
-    if body.stream:
+
+    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+        llama_cpp.CompletionChunk
+    ]] = await run_in_threadpool(llama, **kwargs)
+
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
-
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                    async for chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                    raise e
-
        return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
    else:
-        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-        return completion
+        return iterator_or_completion


class CreateEmbeddingRequest(BaseModel):
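The key change here is the EAFP priming step: the handler pulls the first chunk with `next()` before any `EventSourceResponse` is created, so errors such as the context-window check raise inside the normal request path, where `RouteErrorHandler` can turn them into proper HTTP errors, rather than after streaming has begun. A standalone sketch of that priming pattern (not code from this repository):

```python
from itertools import chain
from typing import Iterator, Union


def prime_stream(result: Union[str, Iterator[str]]) -> Union[str, Iterator[str]]:
    """Pull the first item eagerly so failures surface before streaming starts."""
    if isinstance(result, str):
        return result  # non-streaming result passes straight through
    first = next(result)  # EAFP: any error raises here, not mid-stream
    return chain([first], result)  # re-attach the item we consumed


def failing_generator() -> Iterator[str]:
    raise ValueError("Requested tokens (3000) exceed context window of 2048")
    yield "never reached"


try:
    prime_stream(failing_generator())
except ValueError as exc:
    print(f"caught before streaming started: {exc}")
```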
@@ -524,38 +727,31 @@ async def create_chat_completion(
        ]
    )
-    if body.stream:
+
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
-
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                    async for chat_chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                    raise e
-
        return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(event_publisher, send_chan),
-        )  # type: ignore
-    else:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-            llama.create_chat_completion, **kwargs  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
        )
-        return completion
+    else:
+        return iterator_or_completion


class ModelData(TypedDict):
@@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "llama_cpp_python"
-version = "0.1.73"
+version = "0.1.74"
description = "Python bindings for the llama.cpp library"
readme = "README.md"
license = { text = "MIT" }
vendor/llama.cpp (vendored)
@@ -1 +1 @@
-Subproject commit d01bccde9f759b24449fdaa16306b406a50eb367
+Subproject commit e782c9e735f93ab4767ffc37462c523b73a17ddc