ollama/cmd/cmd.go

127 lines
2.5 KiB
Go
Raw Normal View History

package cmd
import (
	"context"
	"fmt"
	"log"
	"net"
	"os"
	"path"
	"path/filepath"
	"sync"

	"github.com/gosuri/uiprogress"
	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/server"
	"github.com/spf13/cobra"
)
// cacheDir returns the directory where ollama keeps its local state: the
// current user's home directory joined with ".ollama".
//
// It panics if the home directory cannot be determined.
func cacheDir() string {
	home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}

	// filepath (not path) so the separator matches the host OS; path.Join is
	// meant for slash-separated paths such as URLs.
	return filepath.Join(home, ".ollama")
}
// bytesToGB converts a byte count to gigabytes using binary units, i.e.
// 1 GB = 1<<30 = 1,073,741,824 bytes.
func bytesToGB(bytes int) float64 {
	const bytesPerGB = 1 << 30
	gb := float64(bytes) / bytesPerGB
	return gb
}
func run(model string) error {
client, err := NewAPIClient()
if err != nil {
return err
}
pr := api.PullRequest{
Model: model,
}
2023-07-06 18:18:40 +00:00
var bar *uiprogress.Bar
mutex := &sync.Mutex{}
var progressData api.PullProgress
2023-07-06 18:57:11 +00:00
pullCallback := func(progress api.PullProgress) {
2023-07-06 18:18:40 +00:00
mutex.Lock()
progressData = progress
if bar == nil {
2023-07-06 19:00:15 +00:00
uiprogress.Start()
bar = uiprogress.AddBar(int(progress.Total))
2023-07-06 18:18:40 +00:00
bar.PrependFunc(func(b *uiprogress.Bar) string {
return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
})
bar.AppendFunc(func(b *uiprogress.Bar) string {
return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
})
}
bar.Set(int(progress.Completed))
mutex.Unlock()
2023-07-06 16:24:49 +00:00
}
2023-07-06 18:57:11 +00:00
if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
return err
}
fmt.Println("Up to date.")
return nil
2023-07-06 16:24:49 +00:00
}
2023-07-04 04:47:00 +00:00
func serve() error {
2023-07-06 17:56:08 +00:00
ln, err := net.Listen("tcp", "127.0.0.1:11434")
2023-07-04 04:47:00 +00:00
if err != nil {
return err
}
return server.Serve(ln)
}
func NewAPIClient() (*api.Client, error) {
return &api.Client{
2023-07-06 17:56:08 +00:00
URL: "http://localhost:11434",
}, nil
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
rootCmd := &cobra.Command{
2023-07-03 20:32:48 +00:00
Use: "ollama",
Short: "Large language model runner",
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
PersistentPreRun: func(cmd *cobra.Command, args []string) {
// Disable usage printing on errors
cmd.SilenceUsage = true
2023-07-06 15:59:42 +00:00
// create the models directory and it's parent
if err := os.MkdirAll(path.Join(cacheDir(), "models"), 0o700); err != nil {
panic(err)
}
},
}
cobra.EnableCommandSorting = false
runCmd := &cobra.Command{
2023-07-06 15:59:42 +00:00
Use: "run MODEL",
2023-07-04 04:47:00 +00:00
Short: "Run a model",
2023-07-06 15:59:42 +00:00
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
2023-07-06 16:24:49 +00:00
return run(args[0])
},
}
serveCmd := &cobra.Command{
Use: "serve",
Aliases: []string{"start"},
Short: "Start ollama",
RunE: func(cmd *cobra.Command, args []string) error {
2023-07-04 04:47:00 +00:00
return serve()
},
}
rootCmd.AddCommand(
serveCmd,
runCmd,
)
return rootCmd
}