pass flags to serve to allow setting allowed-origins + host and port

* resolves: https://github.com/jmorganca/ollama/issues/300 and
https://github.com/jmorganca/ollama/issues/282

* example usage:
```
ollama serve --port 9999 --allowed-origins "http://foo.example.com,http://192.0.0.1"
```
This commit is contained in:
cmiller01 2023-08-07 03:34:37 +00:00
parent 06fc48ad66
commit fb593b7bfc
2 changed files with 26 additions and 10 deletions

View file

@ -513,23 +513,33 @@ func generateBatch(cmd *cobra.Command, model string) error {
return nil return nil
} }
func RunServer(_ *cobra.Command, _ []string) error { func RunServer(cmd *cobra.Command, _ []string) error {
host := os.Getenv("OLLAMA_HOST") host, err := cmd.Flags().GetString("host")
if host == "" { if err != nil {
host = "127.0.0.1" return errors.New("host unset")
}
if os.Getenv("OLLAMA_HOST") != "" {
host = os.Getenv("OLLAMA_HOST")
}
port, err := cmd.Flags().GetString("port")
if err != nil {
return errors.New("port unset")
} }
port := os.Getenv("OLLAMA_PORT") if os.Getenv("OLLAMA_PORT") != "" {
if port == "" { port = os.Getenv("OLLAMA_PORT")
port = "11434"
} }
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port)) ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
if err != nil { if err != nil {
return err return err
} }
extraOrigins, err := cmd.Flags().GetStringSlice("allowed-origins")
if err != nil {
return err
}
return server.Serve(ln) return server.Serve(ln, extraOrigins)
} }
func startMacApp(client *api.Client) error { func startMacApp(client *api.Client) error {
@ -621,6 +631,10 @@ func NewCLI() *cobra.Command {
RunE: RunServer, RunE: RunServer,
} }
serveCmd.Flags().String("port", "11434", "Port to listen on, may also use OLLAMA_PORT environment variable")
serveCmd.Flags().String("host", "127.0.0.1", "Host listen address, may also use OLLAMA_HOST environment variable")
serveCmd.Flags().StringSlice("allowed-origins", []string{}, "Additional allowed CORS origins (outside of localhost), specify as comma-separated list")
pullCmd := &cobra.Command{ pullCmd := &cobra.Command{
Use: "pull MODEL", Use: "pull MODEL",
Short: "Pull a model from a registry", Short: "Pull a model from a registry",

View file

@ -301,11 +301,11 @@ func CopyModelHandler(c *gin.Context) {
} }
} }
func Serve(ln net.Listener) error { func Serve(ln net.Listener, extraOrigins []string) error {
config := cors.DefaultConfig() config := cors.DefaultConfig()
config.AllowWildcard = true config.AllowWildcard = true
// only allow http/https from localhost // only allow http/https from localhost
config.AllowOrigins = []string{ allowedOrigins := []string{
"http://localhost", "http://localhost",
"http://localhost:*", "http://localhost:*",
"https://localhost", "https://localhost",
@ -315,6 +315,8 @@ func Serve(ln net.Listener) error {
"https://127.0.0.1", "https://127.0.0.1",
"https://127.0.0.1:*", "https://127.0.0.1:*",
} }
allowedOrigins = append(allowedOrigins, extraOrigins...)
config.AllowOrigins = allowedOrigins
r := gin.Default() r := gin.Default()
r.Use(cors.New(config)) r.Use(cors.New(config))