Fix issue with Literal and Optional cli arguments not working. Closes #702
This commit is contained in:
parent
6cfc54284b
commit
759405c84b
1 changed files with 17 additions and 1 deletions
|
@ -23,11 +23,27 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
from typing import Literal, Union
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from llama_cpp.server.app import create_app, Settings
|
from llama_cpp.server.app import create_app, Settings
|
||||||
|
|
||||||
|
def get_non_none_base_types(annotation):
|
||||||
|
if not hasattr(annotation, "__args__"):
|
||||||
|
return annotation
|
||||||
|
return [arg for arg in annotation.__args__ if arg is not type(None)][0]
|
||||||
|
|
||||||
|
def get_base_type(annotation):
|
||||||
|
if getattr(annotation, '__origin__', None) is Literal:
|
||||||
|
return type(annotation.__args__[0])
|
||||||
|
elif getattr(annotation, '__origin__', None) is Union:
|
||||||
|
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
|
||||||
|
if non_optional_args:
|
||||||
|
return get_base_type(non_optional_args[0])
|
||||||
|
else:
|
||||||
|
return annotation
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
for name, field in Settings.model_fields.items():
|
for name, field in Settings.model_fields.items():
|
||||||
|
@ -37,7 +53,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
type=field.annotation if field.annotation is not None else str,
|
type=get_base_type(field.annotation) if field.annotation is not None else str,
|
||||||
help=description,
|
help=description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue