feat(server): Provide ability to dynamically allocate all threads if desired using -1
(#1364)
This commit is contained in:
parent
507c1da066
commit
53ebcc8bb5
1 changed file with 13 additions and 3 deletions
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
from typing import Optional, List, Literal, Union
|
from typing import Optional, List, Literal, Union
|
||||||
from pydantic import Field
|
from pydantic import Field, root_validator
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
|
||||||
n_threads: int = Field(
|
n_threads: int = Field(
|
||||||
default=max(multiprocessing.cpu_count() // 2, 1),
|
default=max(multiprocessing.cpu_count() // 2, 1),
|
||||||
ge=1,
|
ge=1,
|
||||||
description="The number of threads to use.",
|
description="The number of threads to use. Use -1 for max cpu threads",
|
||||||
)
|
)
|
||||||
n_threads_batch: int = Field(
|
n_threads_batch: int = Field(
|
||||||
default=max(multiprocessing.cpu_count(), 1),
|
default=max(multiprocessing.cpu_count(), 1),
|
||||||
ge=0,
|
ge=0,
|
||||||
description="The number of threads to use when batch processing.",
|
description="The number of threads to use when batch processing. Use -1 for max cpu threads",
|
||||||
)
|
)
|
||||||
rope_scaling_type: int = Field(
|
rope_scaling_type: int = Field(
|
||||||
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
||||||
|
@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
|
||||||
default=True, description="Whether to print debug information."
|
default=True, description="Whether to print debug information."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@root_validator(pre=True)  # pre=True: runs before field validation, so the -1 sentinel is replaced before the ge= constraints are checked
def set_dynamic_defaults(cls, values):
    """Resolve the -1 "use all threads" sentinel to the machine's CPU count.

    Applies to both ``n_threads`` and ``n_threads_batch``: a value of -1
    supplied by the user is rewritten to ``multiprocessing.cpu_count()``
    before pydantic's per-field validation runs.

    Parameters:
        values: raw input mapping of field names to user-supplied values.

    Returns:
        The same mapping, with any -1 thread counts replaced.
    """
    max_threads = multiprocessing.cpu_count()
    for field in ("n_threads", "n_threads_batch"):
        if values.get(field, 0) == -1:
            values[field] = max_threads
    return values
|
||||||
|
|
||||||
|
|
||||||
class ServerSettings(BaseSettings):
|
class ServerSettings(BaseSettings):
|
||||||
"""Server settings used to configure the FastAPI and Uvicorn server."""
|
"""Server settings used to configure the FastAPI and Uvicorn server."""
|
||||||
|
|
Loading…
Reference in a new issue