diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index 934aecd..eab5a8a 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import multiprocessing
 
 from typing import Optional, List, Literal, Union
-from pydantic import Field
+from pydantic import Field, root_validator
 from pydantic_settings import BaseSettings
 
 import llama_cpp
@@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
     n_threads: int = Field(
         default=max(multiprocessing.cpu_count() // 2, 1),
         ge=1,
-        description="The number of threads to use.",
+        description="The number of threads to use. Use -1 for all available CPU threads.",
     )
     n_threads_batch: int = Field(
         default=max(multiprocessing.cpu_count(), 1),
         ge=0,
-        description="The number of threads to use when batch processing.",
+        description="The number of threads to use when batch processing. Use -1 for all available CPU threads.",
     )
     rope_scaling_type: int = Field(
         default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
@@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
         default=True, description="Whether to print debug information."
     )
 
+    @root_validator(pre=True)  # pre=True so this runs before the ge=1 / ge=0 field constraints
+    def set_dynamic_defaults(cls, values):
+        # If n_threads or n_threads_batch is -1, replace it with multiprocessing.cpu_count()
+        cpu_count = multiprocessing.cpu_count()
+        if values.get("n_threads", 0) == -1:
+            values["n_threads"] = cpu_count
+        if values.get("n_threads_batch", 0) == -1:
+            values["n_threads_batch"] = cpu_count
+        return values
+
 
 class ServerSettings(BaseSettings):
     """Server settings used to configure the FastAPI and Uvicorn server."""
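For a quick sanity check of the new behavior, the validator can be exercised by instantiating `ModelSettings` directly. A minimal sketch, assuming the `llama_cpp.server.settings` module path from the diff; the model path is a placeholder, since the required `model` field is untouched by this change and constructing the settings object does not load a model:

```python
import multiprocessing

from llama_cpp.server.settings import ModelSettings

# "models/7B/llama-model.gguf" is a placeholder path; any string satisfies
# the required `model` field because no model is actually loaded here.
settings = ModelSettings(
    model="models/7B/llama-model.gguf",
    n_threads=-1,
    n_threads_batch=-1,
)

# Because the validator runs with pre=True, -1 is rewritten to the host CPU
# count before the ge=1 / ge=0 field constraints are checked.
assert settings.n_threads == multiprocessing.cpu_count()
assert settings.n_threads_batch == multiprocessing.cpu_count()
```

Since the server's CLI arguments are generated from these settings fields, the same behavior should surface as `--n_threads -1` and `--n_threads_batch -1` when launching `python -m llama_cpp.server`.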
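One caveat worth flagging for review: with `pydantic_settings` in play, this codebase is on pydantic v2, where `root_validator` is deprecated (`pre=True` still works through the v1-compat shim but emits a `DeprecationWarning`). A v2-native equivalent would use `model_validator(mode="before")`. The sketch below shows that variant against `ThreadSettings`, a hypothetical trimmed-down stand-in for `ModelSettings`:

```python
import multiprocessing
from typing import Any, Dict

from pydantic import Field, model_validator
from pydantic_settings import BaseSettings


class ThreadSettings(BaseSettings):
    # Hypothetical stand-in carrying only the two thread fields.
    n_threads: int = Field(
        default=max(multiprocessing.cpu_count() // 2, 1), ge=1
    )
    n_threads_batch: int = Field(
        default=max(multiprocessing.cpu_count(), 1), ge=0
    )

    @model_validator(mode="before")
    @classmethod
    def set_dynamic_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Same behavior as the root_validator in the diff: -1 means "all CPUs".
        # mode="before" runs ahead of the ge=1 / ge=0 constraints, just as
        # pre=True does in the v1 API.
        cpu_count = multiprocessing.cpu_count()
        if isinstance(values, dict):
            if values.get("n_threads", 0) == -1:
                values["n_threads"] = cpu_count
            if values.get("n_threads_batch", 0) == -1:
                values["n_threads_batch"] = cpu_count
        return values
```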