commit 690c563b60
Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

6 changed files with 54 additions and 17 deletions
@@ -207,7 +207,8 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
 messages = [
     {
         "role": "system",
-        "content": "A chat between a curious user and an artificial intelligence assitant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant callse functions with appropriate input when necessary"
+        "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"
+
     },
     {
         "role": "user",
@@ -265,7 +266,8 @@ Then you'll need to use a custom chat handler to load the clip model and process
 >>> llm = Llama(
       model_path="./path/to/llava/llama-model.gguf",
       chat_handler=chat_handler,
-      n_ctx=2048 # n_ctx should be increased to accomodate the image embedding
+      n_ctx=2048, # n_ctx should be increased to accomodate the image embedding
+      logits_all=True,# needed to make llava work
 )
 >>> llm.create_chat_completion(
     messages = [
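For orientation, a minimal self-contained sketch of the README example this hunk amends, with the two changed arguments in place. The `Llava15ChatHandler` name and the clip-model path are assumptions drawn from the surrounding README section, not from this diff:

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Llava15ChatHandler

    # The clip model is loaded by the chat handler, not by Llama itself.
    chat_handler = Llava15ChatHandler(clip_model_path="./path/to/llava/mmproj.bin")
    llm = Llama(
        model_path="./path/to/llava/llama-model.gguf",
        chat_handler=chat_handler,
        n_ctx=2048,       # larger context window to fit the image embedding
        logits_all=True,  # per the new comment: needed to make llava work
    )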
@@ -17,14 +17,18 @@ class suppress_stdout_stderr(object):
         if self.disable:
             return self
 
+        # Check if sys.stdout and sys.stderr have fileno method
+        if not hasattr(self.sys.stdout, 'fileno') or not hasattr(self.sys.stderr, 'fileno'):
+            return self  # Return the instance without making changes
+
         self.outnull_file = self.open(self.os.devnull, "w")
         self.errnull_file = self.open(self.os.devnull, "w")
 
         self.old_stdout_fileno_undup = self.sys.stdout.fileno()
         self.old_stderr_fileno_undup = self.sys.stderr.fileno()
 
-        self.old_stdout_fileno = self.os.dup(self.sys.stdout.fileno())
-        self.old_stderr_fileno = self.os.dup(self.sys.stderr.fileno())
+        self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup)
+        self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup)
 
         self.old_stdout = self.sys.stdout
         self.old_stderr = self.sys.stderr
@@ -40,6 +44,8 @@ class suppress_stdout_stderr(object):
         if self.disable:
             return
 
-        self.sys.stdout = self.old_stdout
-        self.sys.stderr = self.old_stderr
+        # Check if sys.stdout and sys.stderr have fileno method
+        if hasattr(self.sys.stdout, 'fileno') and hasattr(self.sys.stderr, 'fileno'):
+            self.sys.stdout = self.old_stdout
+            self.sys.stderr = self.old_stderr
 
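The motivation for the new `hasattr` guards: hosts such as Jupyter, GUI shells, or test runners with output capture replace `sys.stdout`/`sys.stderr` with objects that have no OS-level file descriptor, so the old unconditional `fileno()` call raised. A small sketch of the failure mode the guard sidesteps (the wrapper class is hypothetical):

    import sys

    class CapturedStream:
        """Hypothetical stand-in for a host that swaps out sys.stdout."""
        def write(self, text):
            pass  # swallows output; note: no fileno() method at all

    sys.stdout = CapturedStream()

    # Old behavior: sys.stdout.fileno() -> AttributeError.
    # New behavior mirrors the guard added in __enter__:
    if not hasattr(sys.stdout, "fileno") or not hasattr(sys.stderr, "fileno"):
        print("suppression skipped, streams left untouched", file=sys.__stderr__)

    sys.stdout = sys.__stdout__  # restore the real stream

The other change in `__enter__`, duplicating from `old_stdout_fileno_undup`/`old_stderr_fileno_undup` instead of calling `fileno()` a second time, simply reuses the values fetched two lines earlier.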
@@ -2280,10 +2280,14 @@ class Llama:
         return self._model.token_nl()
 
     @staticmethod
-    def logits_to_logprobs(logits: List[float]) -> List[float]:
-        exps = [math.exp(float(x)) for x in logits]
-        sum_exps = sum(exps)
-        return [math.log(x / sum_exps) for x in exps]
+    def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
+        maximum = np.max(logits)
+        tmp = np.subtract(logits, maximum, dtype=np.single)
+        np.exp(tmp, out=tmp)
+        normalizer = 1.0 / np.sum(tmp)
+        np.multiply(normalizer, tmp, out=tmp)
+        np.log(tmp, out=tmp)
+        return tmp
 
     @staticmethod
     def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
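This rewrite is the standard numerically stable log-softmax: subtracting the maximum before exponentiating keeps every exponent at or below zero, so `exp` cannot overflow, and the NumPy version operates in place on one buffer instead of building Python lists. A small sketch of the difference (values approximate):

    import numpy as np

    logits = np.array([1000.0, 999.0, 998.0], dtype=np.single)

    # Naive: np.exp(1000.0) overflows to inf, so the result is nan.
    # np.log(np.exp(logits) / np.exp(logits).sum())

    # Stable: shift by the max first, exactly as the new method does.
    tmp = logits - np.max(logits)   # [0, -1, -2]
    np.exp(tmp, out=tmp)
    tmp /= np.sum(tmp)
    np.log(tmp, out=tmp)
    print(tmp)  # approx. [-0.4076, -1.4076, -2.4076]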
@@ -637,6 +637,23 @@ def format_zephyr(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
 
+@register_chat_format("pygmalion")
+def format_pygmalion(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    system_template = """<|system|>{system_message}"""
+    system_message = _get_system_message(messages)
+    system_message = system_template.format(system_message=system_message)
+    _roles = dict(user="<|user|>", assistant="<|model|>")
+    _sep = "\n"
+    _messages = _map_roles(messages, _roles)
+    _messages.append((_roles["assistant"], None))
+    _prompt = _format_chatml(system_message, _messages, _sep)
+    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
+
 @register_chat_format("chatml")
 def format_chatml(
     messages: List[llama_types.ChatCompletionRequestMessage],
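Because the new formatter is registered under "pygmalion", it can also be exercised directly. A minimal usage sketch, assuming the module path `llama_cpp.llama_chat_format`; the message contents are placeholders, and the exact whitespace in the prompt comes from the `_format_chatml` helper, which is outside this diff:

    from llama_cpp.llama_chat_format import format_pygmalion

    resp = format_pygmalion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "How are you?"},
        ]
    )
    # The prompt wraps turns in <|system|>/<|user|> tags, ends with an open
    # <|model|> tag for the reply, and generation stops at resp.stop == "\n".
    print(resp.prompt)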
@@ -96,5 +96,6 @@ if __name__ == "__main__":
     app = create_app(settings=settings)
 
     uvicorn.run(
-        app, host=os.getenv("HOST", settings.host), port=int(os.getenv("PORT", settings.port))
+        app, host=os.getenv("HOST", settings.host), port=int(os.getenv("PORT", settings.port)),
+        ssl_keyfile=settings.ssl_keyfile, ssl_certfile=settings.ssl_certfile
     )
@@ -150,6 +150,13 @@ class Settings(BaseSettings):
     # Server Params
     host: str = Field(default="localhost", description="Listen address")
     port: int = Field(default=8000, description="Listen port")
+    # SSL Params
+    ssl_keyfile: Optional[str] = Field(
+        default=None, description="SSL key file for HTTPS"
+    )
+    ssl_certfile: Optional[str] = Field(
+        default=None, description="SSL certificate file for HTTPS"
+    )
     interrupt_requests: bool = Field(
         default=True,
         description="Whether to interrupt requests when a new request is received.",
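Taken together, the two server hunks thread the new `ssl_keyfile`/`ssl_certfile` settings through to `uvicorn.run`, whose `ssl_keyfile` and `ssl_certfile` keyword arguments are standard uvicorn options. A sketch of launching the server over HTTPS programmatically, assuming the module path `llama_cpp.server.app` for `create_app` and `Settings`; the model and PEM paths are placeholders:

    import uvicorn
    from llama_cpp.server.app import create_app, Settings

    settings = Settings(
        model="./models/model.gguf",  # placeholder model path
        ssl_keyfile="./key.pem",      # PEM private key you provide
        ssl_certfile="./cert.pem",    # matching certificate
    )
    app = create_app(settings=settings)
    uvicorn.run(
        app,
        host=settings.host,
        port=settings.port,
        ssl_keyfile=settings.ssl_keyfile,
        ssl_certfile=settings.ssl_certfile,
    )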