Fix issue with Literal and Optional cli arguments not working. Closes #702

This commit is contained in:
Andrei Betlen 2023-09-13 18:06:12 -04:00
parent 6cfc54284b
commit 759405c84b

View file

@ -23,11 +23,27 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import os
import argparse
from typing import Literal, Union
import uvicorn
from llama_cpp.server.app import create_app, Settings
def get_non_none_base_types(annotation):
if not hasattr(annotation, "__args__"):
return annotation
return [arg for arg in annotation.__args__ if arg is not type(None)][0]
def get_base_type(annotation):
if getattr(annotation, '__origin__', None) is Literal:
return type(annotation.__args__[0])
elif getattr(annotation, '__origin__', None) is Union:
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
if non_optional_args:
return get_base_type(non_optional_args[0])
else:
return annotation
if __name__ == "__main__":
parser = argparse.ArgumentParser()
for name, field in Settings.model_fields.items():
@ -37,7 +53,7 @@ if __name__ == "__main__":
parser.add_argument(
f"--{name}",
dest=name,
type=field.annotation if field.annotation is not None else str,
type=get_base_type(field.annotation) if field.annotation is not None else str,
help=description,
)