diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 208de8c..ec47c42 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -204,6 +204,31 @@ class _LlamaModel:
             output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output
         )
 
+    # Extra
+    def metadata(self) -> Dict[str, str]:
+        assert self.model is not None
+        metadata: Dict[str, str] = {}
+        buffer_size = 1024
+        buffer = ctypes.create_string_buffer(buffer_size)
+        # zero the buffer
+        buffer.value = b'\0' * buffer_size
+        # iterate over model keys
+        for i in range(llama_cpp.llama_model_meta_count(self.model)):
+            nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
+            if nbytes > buffer_size:
+                buffer_size = nbytes
+                buffer = ctypes.create_string_buffer(buffer_size)
+                nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
+            key = buffer.value.decode("utf-8")
+            nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
+            if nbytes > buffer_size:
+                buffer_size = nbytes
+                buffer = ctypes.create_string_buffer(buffer_size)
+                nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
+            value = buffer.value.decode("utf-8")
+            metadata[key] = value
+        return metadata
+
     @staticmethod
     def default_params():
         """Get the default llama_model_params."""
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 32eb3fe..5c66bcf 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -331,6 +331,16 @@ class Llama:
 
         self._mirostat_mu = ctypes.c_float(2.0 * 5.0)  # TODO: Move this to sampling context
 
+        try:
+            self.metadata = self._model.metadata()
+        except Exception as e:
+            self.metadata = {}
+            if self.verbose:
+                print(f"Failed to load metadata: {e}", file=sys.stderr)
+
+        if self.verbose:
+            print(f"Model metadata: {self.metadata}", file=sys.stderr)
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
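
For reference, a minimal standalone sketch of the grow-on-demand buffer pattern that `metadata()` uses in the first hunk. Here `fake_meta_val` is a hypothetical stand-in for the llama.cpp accessors, assumed to follow the snprintf convention of copying what fits and returning the full length the string requires, so a second call after resizing retrieves the complete value:

```python
import ctypes

DATA = b"tokenizer.chat_template: " + b"x" * 4000  # longer than the first buffer

def fake_meta_val(buf: ctypes.Array, buf_size: int) -> int:
    """Copy what fits (NUL-terminated) and return the full length needed."""
    n = min(len(DATA), buf_size - 1)  # leave room for the NUL terminator
    buf[0:n] = DATA[:n]
    buf[n] = b"\x00"
    return len(DATA)  # required length, snprintf-style

buffer_size = 1024
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = fake_meta_val(buffer, buffer_size)
if nbytes > buffer_size:
    # The first call truncated the value; grow the buffer and call again.
    buffer_size = nbytes + 1  # +1 for the NUL terminator
    buffer = ctypes.create_string_buffer(buffer_size)
    nbytes = fake_meta_val(buffer, buffer_size)
value = buffer.value.decode("utf-8")
assert value == DATA.decode("utf-8")
```

Starting with a fixed 1024-byte buffer keeps the common case to one C call per key; only oversized values pay for a reallocation and a second call.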
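And a usage sketch of the new attribute wired up in the second hunk, assuming a GGUF model at a hypothetical path `./models/model.gguf`; which keys appear depends entirely on how the model file was converted:

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf", verbose=False)

# llm.metadata is populated in __init__ from the model's GGUF key/value store.
for key, value in llm.metadata.items():
    print(f"{key} = {value}")

# Standard GGUF keys such as "general.architecture" are common but not
# guaranteed, so prefer .get() over indexing.
print(llm.metadata.get("general.architecture", "<missing>"))
```

Note that the `try`/`except` in `__init__` makes metadata loading non-fatal: a model whose metadata cannot be read still loads, with `llm.metadata` left as an empty dict.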