feat(server): Provide ability to dynamically allocate all threads if desired using -1 (#1364)

This commit is contained in:
Sean Bailey 2024-04-23 02:35:38 -04:00 committed by GitHub
parent 507c1da066
commit 53ebcc8bb5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import multiprocessing import multiprocessing
from typing import Optional, List, Literal, Union from typing import Optional, List, Literal, Union
from pydantic import Field from pydantic import Field, root_validator
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
import llama_cpp import llama_cpp
@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
n_threads: int = Field( n_threads: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1), default=max(multiprocessing.cpu_count() // 2, 1),
ge=1, ge=1,
description="The number of threads to use.", description="The number of threads to use. Use -1 for max cpu threads",
) )
n_threads_batch: int = Field( n_threads_batch: int = Field(
default=max(multiprocessing.cpu_count(), 1), default=max(multiprocessing.cpu_count(), 1),
ge=0, ge=0,
description="The number of threads to use when batch processing.", description="The number of threads to use when batch processing. Use -1 for max cpu threads",
) )
rope_scaling_type: int = Field( rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
default=True, description="Whether to print debug information." default=True, description="Whether to print debug information."
) )
@root_validator(pre=True) # pre=True to ensure this runs before any other validation
def set_dynamic_defaults(cls, values):
# If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
cpu_count = multiprocessing.cpu_count()
if values.get('n_threads', 0) == -1:
values['n_threads'] = cpu_count
if values.get('n_threads_batch', 0) == -1:
values['n_threads_batch'] = cpu_count
return values
class ServerSettings(BaseSettings): class ServerSettings(BaseSettings):
"""Server settings used to configure the FastAPI and Uvicorn server.""" """Server settings used to configure the FastAPI and Uvicorn server."""