diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 466dc22..dfac9bb 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -262,9 +262,7 @@ class Llama:
         self.n_batch = min(n_ctx, n_batch)  # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
-        self.n_threads_batch = n_threads_batch or max(
-            multiprocessing.cpu_count() // 2, 1
-        )
+        self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
 
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index 9ebdd0d..811c6ca 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -70,7 +70,7 @@ class ModelSettings(BaseSettings):
         description="The number of threads to use.",
     )
     n_threads_batch: int = Field(
-        default=max(multiprocessing.cpu_count() // 2, 1),
+        default=max(multiprocessing.cpu_count(), 1),
         ge=0,
         description="The number of threads to use when batch processing.",
     )
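
For context, a minimal sketch (not part of the patch) of how the changed fallback resolves at runtime; the names `cpus`, `old_default`, `new_default`, and `resolved` are illustrative, and the core count is hypothetical:

import multiprocessing

cpus = multiprocessing.cpu_count()          # e.g. 8 logical cores

# Previous default: half the logical cores, floored at 1.
old_default = max(cpus // 2, 1)             # -> 4

# New default: all available logical cores.
new_default = cpus                          # -> 8

# An explicit constructor argument still takes precedence; only a
# falsy value falls through to the default.
n_threads_batch = None
resolved = n_threads_batch or new_default   # -> 8

One consequence of the `or` fallback worth noting: because `0` is falsy in Python, passing `n_threads_batch=0` also selects the default, even though the server setting's `ge=0` bound accepts it as a value.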