From deeac961bb78e53102d33497c210953123497bc4 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Wed, 25 Oct 2023 16:41:18 -0700 Subject: [PATCH] new readline library (#847) --- cmd/cmd.go | 89 ++-------- go.mod | 3 +- go.sum | 11 +- readline/buffer.go | 370 +++++++++++++++++++++++++++++++++++++++++ readline/errors.go | 17 ++ readline/history.go | 152 +++++++++++++++++ readline/readline.go | 254 ++++++++++++++++++++++++++++ readline/term.go | 35 ++++ readline/term_bsd.go | 24 +++ readline/term_linux.go | 26 +++ readline/types.go | 77 +++++++++ 11 files changed, 972 insertions(+), 86 deletions(-) create mode 100644 readline/buffer.go create mode 100644 readline/errors.go create mode 100644 readline/history.go create mode 100644 readline/readline.go create mode 100644 readline/term.go create mode 100644 readline/term_bsd.go create mode 100644 readline/term_linux.go create mode 100644 readline/types.go diff --git a/cmd/cmd.go b/cmd/cmd.go index d334d431..16383d60 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -22,7 +22,6 @@ import ( "github.com/dustin/go-humanize" "github.com/olekukonko/tablewriter" - "github.com/pdevine/readline" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" "golang.org/x/term" @@ -30,30 +29,11 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/progressbar" + "github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/server" "github.com/jmorganca/ollama/version" ) -type Painter struct { - IsMultiLine bool -} - -func (p Painter) Paint(line []rune, _ int) []rune { - termType := os.Getenv("TERM") - if termType == "xterm-256color" && len(line) == 0 { - var prompt string - if p.IsMultiLine { - prompt = "Use \"\"\" to end multi-line input" - } else { - prompt = "Send a message (/? for help)" - } - return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt))) - } - // add a space and a backspace to prevent the cursor from walking up the screen - line = append(line, []rune(" \b")...) - return line -} - func CreateHandler(cmd *cobra.Command, args []string) error { filename, _ := cmd.Flags().GetString("file") filename, err := filepath.Abs(filename) @@ -508,38 +488,11 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error { } func generateInteractive(cmd *cobra.Command, model string) error { - home, err := os.UserHomeDir() - if err != nil { - return err - } - // load the model if err := generate(cmd, model, "", false); err != nil { return err } - completer := readline.NewPrefixCompleter( - readline.PcItem("/help"), - readline.PcItem("/list"), - readline.PcItem("/set", - readline.PcItem("history"), - readline.PcItem("nohistory"), - readline.PcItem("wordwrap"), - readline.PcItem("nowordwrap"), - readline.PcItem("verbose"), - readline.PcItem("quiet"), - ), - readline.PcItem("/show", - readline.PcItem("license"), - readline.PcItem("modelfile"), - readline.PcItem("parameters"), - readline.PcItem("system"), - readline.PcItem("template"), - ), - readline.PcItem("/exit"), - readline.PcItem("/bye"), - ) - usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, " /set Set session variables") @@ -572,16 +525,14 @@ func generateInteractive(cmd *cobra.Command, model string) error { fmt.Fprintln(os.Stderr, "") } - var painter Painter - - config := readline.Config{ - Painter: &painter, - Prompt: ">>> ", - HistoryFile: filepath.Join(home, ".ollama", "history"), - AutoComplete: completer, + prompt := readline.Prompt{ + Prompt: ">>> ", + AltPrompt: "... ", + Placeholder: "Send a message (/? for help)", + AltPlaceholder: `Use """ to end multi-line input`, } - scanner, err := readline.NewEx(&config) + scanner, err := readline.New(prompt) if err != nil { return err } @@ -603,7 +554,6 @@ func generateInteractive(cmd *cobra.Command, model string) error { } var multiLineBuffer string - var isMultiLine bool for { line, err := scanner.Readline() @@ -612,7 +562,7 @@ func generateInteractive(cmd *cobra.Command, model string) error { return nil case errors.Is(err, readline.ErrInterrupt): if line == "" { - fmt.Println("Use Ctrl-D or /bye to exit.") + fmt.Println("\nUse Ctrl-D or /bye to exit.") } continue @@ -623,23 +573,19 @@ func generateInteractive(cmd *cobra.Command, model string) error { line = strings.TrimSpace(line) switch { - case isMultiLine: + case scanner.Prompt.UseAlt: if strings.HasSuffix(line, `"""`) { - isMultiLine = false - painter.IsMultiLine = isMultiLine + scanner.Prompt.UseAlt = false multiLineBuffer += strings.TrimSuffix(line, `"""`) line = multiLineBuffer multiLineBuffer = "" - scanner.SetPrompt(">>> ") } else { multiLineBuffer += line + " " continue } case strings.HasPrefix(line, `"""`): - isMultiLine = true - painter.IsMultiLine = isMultiLine + scanner.Prompt.UseAlt = true multiLineBuffer = strings.TrimPrefix(line, `"""`) + " " - scanner.SetPrompt("... ") continue case strings.HasPrefix(line, "/list"): args := strings.Fields(line) @@ -666,19 +612,6 @@ func generateInteractive(cmd *cobra.Command, model string) error { case "quiet": cmd.Flags().Set("verbose", "false") fmt.Println("Set 'quiet' mode.") - case "mode": - if len(args) > 2 { - switch args[2] { - case "vim": - scanner.SetVimMode(true) - case "emacs", "default": - scanner.SetVimMode(false) - default: - usage() - } - } else { - usage() - } default: fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) } diff --git a/go.mod b/go.mod index 967e2777..1f8860cc 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,14 @@ go 1.20 require ( github.com/dustin/go-humanize v1.0.1 + github.com/emirpasic/gods v1.18.1 github.com/gin-gonic/gin v1.9.1 github.com/mattn/go-runewidth v0.0.14 github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db github.com/olekukonko/tablewriter v0.0.5 - github.com/pdevine/readline v1.5.2 github.com/spf13/cobra v1.7.0 golang.org/x/sync v0.3.0 + gonum.org/v1/gonum v0.14.0 ) require github.com/rivo/uniseg v0.2.0 // indirect diff --git a/go.sum b/go.sum index 9dae5b7f..b12628a1 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,6 @@ github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= -github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= -github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= -github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -15,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= @@ -78,8 +76,6 @@ github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0= github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y= -github.com/pdevine/readline v1.5.2 h1:oz6Y5GdTmhPG+08hhxcAvtHitSANWuA2100Sppb38xI= -github.com/pdevine/readline v1.5.2/go.mod h1:na/LbuE5PYwxI7GyopWdIs3U8HVe89lYlNTFTXH3wOw= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= @@ -131,7 +127,6 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= @@ -145,6 +140,8 @@ golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= +gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= diff --git a/readline/buffer.go b/readline/buffer.go new file mode 100644 index 00000000..1a75f139 --- /dev/null +++ b/readline/buffer.go @@ -0,0 +1,370 @@ +package readline + +import ( + "fmt" + + "github.com/emirpasic/gods/lists/arraylist" + "golang.org/x/term" +) + +type Buffer struct { + Pos int + Buf *arraylist.List + Prompt *Prompt + LineWidth int + Width int + Height int +} + +func NewBuffer(prompt *Prompt) (*Buffer, error) { + width, height, err := term.GetSize(0) + if err != nil { + fmt.Println("Error getting size:", err) + return nil, err + } + + lwidth := width - len(prompt.Prompt) + if prompt.UseAlt { + lwidth = width - len(prompt.AltPrompt) + } + + b := &Buffer{ + Pos: 0, + Buf: arraylist.New(), + Prompt: prompt, + Width: width, + Height: height, + LineWidth: lwidth, + } + + return b, nil +} + +func (b *Buffer) MoveLeft() { + if b.Pos > 0 { + if b.Pos%b.LineWidth == 0 { + fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width)) + } else { + fmt.Printf(CursorLeft) + } + b.Pos -= 1 + } +} + +func (b *Buffer) MoveLeftWord() { + if b.Pos > 0 { + var foundNonspace bool + for { + v, _ := b.Buf.Get(b.Pos - 1) + if v == ' ' { + if foundNonspace { + break + } + } else { + foundNonspace = true + } + b.MoveLeft() + + if b.Pos == 0 { + break + } + } + } +} + +func (b *Buffer) MoveRight() { + if b.Pos < b.Size() { + b.Pos += 1 + if b.Pos%b.LineWidth == 0 { + fmt.Printf(CursorDown + CursorBOL + cursorRightN(b.PromptSize())) + } else { + fmt.Printf(CursorRight) + } + } +} + +func (b *Buffer) MoveRightWord() { + if b.Pos < b.Size() { + for { + b.MoveRight() + v, _ := b.Buf.Get(b.Pos) + if v == ' ' { + break + } + + if b.Pos == b.Size() { + break + } + } + } +} + +func (b *Buffer) MoveToStart() { + if b.Pos > 0 { + currLine := b.Pos / b.LineWidth + if currLine > 0 { + for cnt := 0; cnt < currLine; cnt++ { + fmt.Printf(CursorUp) + } + } + fmt.Printf(CursorBOL + cursorRightN(b.PromptSize())) + b.Pos = 0 + } +} + +func (b *Buffer) MoveToEnd() { + if b.Pos < b.Size() { + currLine := b.Pos / b.LineWidth + totalLines := b.Size() / b.LineWidth + if currLine < totalLines { + for cnt := 0; cnt < totalLines-currLine; cnt++ { + fmt.Printf(CursorDown) + } + remainder := b.Size() % b.LineWidth + fmt.Printf(CursorBOL + cursorRightN(b.PromptSize()+remainder)) + } else { + fmt.Printf(cursorRightN(b.Size() - b.Pos)) + } + + b.Pos = b.Size() + } +} + +func (b *Buffer) Size() int { + return b.Buf.Size() +} + +func min(n, m int) int { + if n > m { + return m + } + return n +} + +func (b *Buffer) PromptSize() int { + if b.Prompt.UseAlt { + return len(b.Prompt.AltPrompt) + } + return len(b.Prompt.Prompt) +} + +func (b *Buffer) Add(r rune) { + if b.Pos == b.Buf.Size() { + fmt.Printf("%c", r) + b.Buf.Add(r) + b.Pos += 1 + if b.Pos > 0 && b.Pos%b.LineWidth == 0 { + fmt.Printf("\n%s", b.Prompt.AltPrompt) + } + } else { + fmt.Printf("%c", r) + b.Buf.Insert(b.Pos, r) + b.Pos += 1 + if b.Pos > 0 && b.Pos%b.LineWidth == 0 { + fmt.Printf("\n%s", b.Prompt.AltPrompt) + } + b.drawRemaining() + } +} + +func (b *Buffer) drawRemaining() { + var place int + remainingText := b.StringN(b.Pos) + if b.Pos > 0 { + place = b.Pos % b.LineWidth + } + fmt.Printf(CursorHide) + + // render the rest of the current line + currLine := remainingText[:min(b.LineWidth-place, len(remainingText))] + if len(currLine) > 0 { + fmt.Printf(ClearToEOL + currLine) + fmt.Printf(cursorLeftN(len(currLine))) + } else { + fmt.Printf(ClearToEOL) + } + + // render the other lines + if len(remainingText) > len(currLine) { + remaining := []rune(remainingText[len(currLine):]) + var totalLines int + for i, c := range remaining { + if i%b.LineWidth == 0 { + fmt.Printf("\n%s", b.Prompt.AltPrompt) + totalLines += 1 + } + fmt.Printf("%c", c) + } + fmt.Printf(ClearToEOL) + fmt.Printf(cursorUpN(totalLines)) + fmt.Printf(CursorBOL + cursorRightN(b.Width-len(currLine))) + } + + fmt.Printf(CursorShow) +} + +func (b *Buffer) Remove() { + if b.Buf.Size() > 0 && b.Pos > 0 { + if b.Pos%b.LineWidth == 0 { + // if the user backspaces over the word boundary, do this magic to clear the line + // and move to the end of the previous line + fmt.Printf(CursorBOL + ClearToEOL) + fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width) + " " + CursorLeft) + } else { + fmt.Printf(CursorLeft + " " + CursorLeft) + } + + var eraseExtraLine bool + if (b.Size()-1)%b.LineWidth == 0 { + eraseExtraLine = true + } + + b.Pos -= 1 + b.Buf.Remove(b.Pos) + + if b.Pos < b.Size() { + b.drawRemaining() + // this erases a line which is left over when backspacing in the middle of a line and there + // are trailing characters which go over the line width boundary + if eraseExtraLine { + remainingLines := (b.Size() - b.Pos) / b.LineWidth + fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL) + place := b.Pos % b.LineWidth + fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.Prompt))) + } + } + } +} + +func (b *Buffer) Delete() { + if b.Size() > 0 && b.Pos < b.Size() { + b.Buf.Remove(b.Pos) + b.drawRemaining() + if b.Size()%b.LineWidth == 0 { + if b.Pos != b.Size() { + remainingLines := (b.Size() - b.Pos) / b.LineWidth + fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL) + place := b.Pos % b.LineWidth + fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.Prompt))) + } + } + } +} + +func (b *Buffer) DeleteBefore() { + if b.Pos > 0 { + for cnt := b.Pos - 1; cnt >= 0; cnt-- { + b.Remove() + } + } +} + +func (b *Buffer) DeleteRemaining() { + if b.Size() > 0 && b.Pos < b.Size() { + charsToDel := b.Size() - b.Pos + for cnt := 0; cnt < charsToDel; cnt++ { + b.Delete() + } + } +} + +func (b *Buffer) DeleteWord() { + if b.Buf.Size() > 0 && b.Pos > 0 { + var foundNonspace bool + for { + v, _ := b.Buf.Get(b.Pos - 1) + if v == ' ' { + if !foundNonspace { + b.Remove() + } else { + break + } + } else { + foundNonspace = true + b.Remove() + } + + if b.Pos == 0 { + break + } + } + } +} + +func (b *Buffer) ClearScreen() { + fmt.Printf(ClearScreen + CursorReset + b.Prompt.Prompt) + if b.IsEmpty() { + ph := b.Prompt.Placeholder + fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault) + } else { + currPos := b.Pos + b.Pos = 0 + b.drawRemaining() + fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.Prompt))) + if currPos > 0 { + targetLine := currPos / b.LineWidth + if targetLine > 0 { + for cnt := 0; cnt < targetLine; cnt++ { + fmt.Printf(CursorDown) + } + } + remainder := currPos % b.LineWidth + if remainder > 0 { + fmt.Printf(cursorRightN(remainder)) + } + if currPos%b.LineWidth == 0 { + fmt.Printf(CursorBOL + b.Prompt.AltPrompt) + } + } + b.Pos = currPos + } +} + +func (b *Buffer) IsEmpty() bool { + return b.Buf.Empty() +} + +func (b *Buffer) Replace(r []rune) { + b.Pos = 0 + b.Buf.Clear() + fmt.Printf(ClearLine + CursorBOL + b.Prompt.Prompt) + for _, c := range r { + b.Add(c) + } +} + +func (b *Buffer) String() string { + return b.StringN(0) +} + +func (b *Buffer) StringN(n int) string { + return b.StringNM(n, 0) +} + +func (b *Buffer) StringNM(n, m int) string { + var s string + if m == 0 { + m = b.Size() + } + for cnt := n; cnt < m; cnt++ { + c, _ := b.Buf.Get(cnt) + s += string(c.(rune)) + } + return s +} + +func cursorLeftN(n int) string { + return fmt.Sprintf(CursorLeftN, n) +} + +func cursorRightN(n int) string { + return fmt.Sprintf(CursorRightN, n) +} + +func cursorUpN(n int) string { + return fmt.Sprintf(CursorUpN, n) +} + +func cursorDownN(n int) string { + return fmt.Sprintf(CursorDownN, n) +} diff --git a/readline/errors.go b/readline/errors.go new file mode 100644 index 00000000..40e40cb7 --- /dev/null +++ b/readline/errors.go @@ -0,0 +1,17 @@ +package readline + +import ( + "errors" +) + +var ( + ErrInterrupt = errors.New("Interrupt") +) + +type InterruptError struct { + Line []rune +} + +func (*InterruptError) Error() string { + return "Interrupted" +} diff --git a/readline/history.go b/readline/history.go new file mode 100644 index 00000000..b16f937d --- /dev/null +++ b/readline/history.go @@ -0,0 +1,152 @@ +package readline + +import ( + "bufio" + "errors" + "io" + "os" + "path/filepath" + "strings" + + "github.com/emirpasic/gods/lists/arraylist" +) + +type History struct { + Buf *arraylist.List + Autosave bool + Pos int + Limit int + Filename string + Enabled bool +} + +func NewHistory() (*History, error) { + h := &History{ + Buf: arraylist.New(), + Limit: 100, //resizeme + Autosave: true, + Enabled: true, + } + + err := h.Init() + if err != nil { + return nil, err + } + + return h, nil +} + +func (h *History) Init() error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + path := filepath.Join(home, ".ollama", "history") + h.Filename = path + + //todo check if the file exists + f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0600) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + defer f.Close() + + r := bufio.NewReader(f) + for { + line, err := r.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + return err + } + + line = strings.TrimSpace(line) + if len(line) == 0 { + continue + } + + h.Add([]rune(line)) + } + + return nil +} + +func (h *History) Add(l []rune) { + h.Buf.Add(l) + h.Pos = h.Size() + h.Compact() + if h.Autosave { + h.Save() + } +} + +func (h *History) Compact() { + s := h.Buf.Size() + if s > h.Limit { + for cnt := 0; cnt < s-h.Limit; cnt++ { + h.Buf.Remove(0) + } + } +} + +func (h *History) Clear() { + h.Buf.Clear() +} + +func (h *History) Prev() []rune { + var line []rune + if h.Pos > 0 { + h.Pos -= 1 + } + v, _ := h.Buf.Get(h.Pos) + line, _ = v.([]rune) + return line +} + +func (h *History) Next() []rune { + var line []rune + if h.Pos < h.Buf.Size() { + h.Pos += 1 + v, _ := h.Buf.Get(h.Pos) + line, _ = v.([]rune) + } + return line +} + +func (h *History) Size() int { + return h.Buf.Size() +} + +func (h *History) Save() error { + if !h.Enabled { + return nil + } + + tmpFile := h.Filename + ".tmp" + + f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0666) + if err != nil { + return err + } + defer f.Close() + + buf := bufio.NewWriter(f) + for cnt := 0; cnt < h.Size(); cnt++ { + v, _ := h.Buf.Get(cnt) + line, _ := v.([]rune) + buf.WriteString(string(line) + "\n") + } + buf.Flush() + f.Close() + + if err = os.Rename(tmpFile, h.Filename); err != nil { + return err + } + + return nil +} diff --git a/readline/readline.go b/readline/readline.go new file mode 100644 index 00000000..43ab400f --- /dev/null +++ b/readline/readline.go @@ -0,0 +1,254 @@ +package readline + +import ( + "bufio" + "fmt" + "io" + "os" + "sync" + "syscall" +) + +type Prompt struct { + Prompt string + AltPrompt string + Placeholder string + AltPlaceholder string + UseAlt bool +} + +type Terminal struct { + m sync.Mutex + wg sync.WaitGroup + outchan chan rune +} + +type Instance struct { + Prompt *Prompt + Terminal *Terminal + History *History +} + +func New(prompt Prompt) (*Instance, error) { + term, err := NewTerminal() + if err != nil { + return nil, err + } + + history, err := NewHistory() + if err != nil { + return nil, err + } + + return &Instance{ + Prompt: &prompt, + Terminal: term, + History: history, + }, nil +} + +func (i *Instance) Readline() (string, error) { + prompt := i.Prompt.Prompt + if i.Prompt.UseAlt { + prompt = i.Prompt.AltPrompt + } + fmt.Printf(prompt) + + termios, err := SetRawMode(syscall.Stdin) + if err != nil { + return "", err + } + defer UnsetRawMode(syscall.Stdin, termios) + + buf, _ := NewBuffer(i.Prompt) + + var esc bool + var escex bool + var metaDel bool + var bracketedPaste bool + var ignoreEnter bool + + var currentLineBuf []rune + + for { + if buf.IsEmpty() { + ph := i.Prompt.Placeholder + if i.Prompt.UseAlt { + ph = i.Prompt.AltPlaceholder + } + fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault) + } + + r := i.Terminal.ReadRune() + if buf.IsEmpty() { + fmt.Printf(ClearToEOL) + } + + if r == 0 { // io.EOF + break + } + + if escex { + escex = false + + switch r { + case KeyUp: + if i.History.Pos > 0 { + if i.History.Pos == i.History.Size() { + currentLineBuf = []rune(buf.String()) + } + buf.Replace(i.History.Prev()) + } + case KeyDown: + if i.History.Pos < i.History.Size() { + buf.Replace(i.History.Next()) + if i.History.Pos == i.History.Size() { + buf.Replace(currentLineBuf) + } + } + case KeyLeft: + buf.MoveLeft() + case KeyRight: + buf.MoveRight() + case CharBracketedPaste: + bracketedPaste = true + case KeyDel: + if buf.Size() > 0 { + buf.Delete() + } + metaDel = true + case MetaStart: + buf.MoveToStart() + case MetaEnd: + buf.MoveToEnd() + default: + // skip any keys we don't know about + continue + } + continue + } else if esc { + esc = false + + switch r { + case 'b': + buf.MoveLeftWord() + case 'f': + buf.MoveRightWord() + case CharEscapeEx: + escex = true + } + continue + } + + switch r { + case CharBracketedPasteStart: + if bracketedPaste { + ignoreEnter = true + } + case CharEsc: + esc = true + case CharInterrupt: + return "", ErrInterrupt + case CharLineStart: + buf.MoveToStart() + case CharLineEnd: + buf.MoveToEnd() + case CharBackward: + buf.MoveLeft() + case CharForward: + buf.MoveRight() + case CharBackspace, CharCtrlH: + buf.Remove() + case CharTab: + // todo: convert back to real tabs + for cnt := 0; cnt < 8; cnt++ { + buf.Add(' ') + } + case CharDelete: + if buf.Size() > 0 { + buf.Delete() + } else { + return "", io.EOF + } + case CharKill: + buf.DeleteRemaining() + case CharCtrlU: + buf.DeleteBefore() + case CharCtrlL: + buf.ClearScreen() + case CharCtrlW: + buf.DeleteWord() + case CharEnter: + if !ignoreEnter { + output := buf.String() + if output != "" { + i.History.Add([]rune(output)) + } + buf.MoveToEnd() + fmt.Println() + return output, nil + } + fallthrough + default: + if metaDel { + metaDel = false + continue + } + if r >= CharSpace || r == CharEnter { + buf.Add(r) + } + } + } + return "", nil +} + +func (i *Instance) Close() error { + return i.Terminal.Close() +} + +func (i *Instance) HistoryEnable() { + i.History.Enabled = true +} + +func (i *Instance) HistoryDisable() { + i.History.Enabled = false +} + +func NewTerminal() (*Terminal, error) { + t := &Terminal{ + outchan: make(chan rune), + } + + go t.ioloop() + + return t, nil +} + +func (t *Terminal) ioloop() { + buf := bufio.NewReader(os.Stdin) + + for { + r, _, err := buf.ReadRune() + if err != nil { + break + } + t.outchan <- r + if r == 0 { // EOF + break + } + } + +} + +func (t *Terminal) ReadRune() rune { + r, ok := <-t.outchan + if !ok { + return rune(0) + } + return r +} + +func (t *Terminal) Close() error { + close(t.outchan) + return nil +} diff --git a/readline/term.go b/readline/term.go new file mode 100644 index 00000000..a2c35938 --- /dev/null +++ b/readline/term.go @@ -0,0 +1,35 @@ +// +build aix darwin dragonfly freebsd linux,!appengine netbsd openbsd os400 solaris +package readline + +import ( + "syscall" +) + +type Termios syscall.Termios + +func SetRawMode(fd int) (*Termios, error) { + termios, err := getTermios(fd) + if err != nil { + return nil, err + } + + newTermios := *termios + newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON + newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN + newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB + newTermios.Cflag |= syscall.CS8 + newTermios.Cc[syscall.VMIN] = 1 + newTermios.Cc[syscall.VTIME] = 0 + + return termios, setTermios(fd, &newTermios) +} + +func UnsetRawMode(fd int, termios *Termios) error { + return setTermios(fd, termios) +} + +// IsTerminal returns true if the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + _, err := getTermios(fd) + return err == nil +} diff --git a/readline/term_bsd.go b/readline/term_bsd.go new file mode 100644 index 00000000..98672075 --- /dev/null +++ b/readline/term_bsd.go @@ -0,0 +1,24 @@ +// go build darwin dragonfly freebsd netbsd openbsd +package readline + +import ( + "syscall" + "unsafe" +) + +func getTermios(fd int) (*Termios, error) { + termios := new(Termios) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return nil, err + } + return termios, nil +} + +func setTermios(fd int, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return err + } + return nil +} diff --git a/readline/term_linux.go b/readline/term_linux.go new file mode 100644 index 00000000..ed3634e2 --- /dev/null +++ b/readline/term_linux.go @@ -0,0 +1,26 @@ +package readline + +import ( + "syscall" + "unsafe" +) + +const tcgets = 0x5401 +const tcsets = 0x5402 + +func getTermios(fd int) (*Termios, error) { + termios := new(Termios) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return nil, err + } + return termios, nil +} + +func setTermios(fd int, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + if err != 0 { + return err + } + return nil +} diff --git a/readline/types.go b/readline/types.go new file mode 100644 index 00000000..32b77007 --- /dev/null +++ b/readline/types.go @@ -0,0 +1,77 @@ +package readline + +const ( + CharLineStart = 1 + CharBackward = 2 + CharInterrupt = 3 + CharDelete = 4 + CharLineEnd = 5 + CharForward = 6 + CharBell = 7 + CharCtrlH = 8 + CharTab = 9 + CharCtrlJ = 10 + CharKill = 11 + CharCtrlL = 12 + CharEnter = 13 + CharNext = 14 + CharPrev = 16 + CharBckSearch = 18 + CharFwdSearch = 19 + CharTranspose = 20 + CharCtrlU = 21 + CharCtrlW = 23 + CharCtrlY = 25 + CharCtrlZ = 26 + CharEsc = 27 + CharSpace = 32 + CharEscapeEx = 91 + CharBackspace = 127 +) + +const ( + KeyDel = 51 + KeyUp = 65 + KeyDown = 66 + KeyRight = 67 + KeyLeft = 68 + MetaEnd = 70 + MetaStart = 72 +) + +const ( + CursorUp = "\033[1A" + CursorDown = "\033[1B" + CursorRight = "\033[1C" + CursorLeft = "\033[1D" + + CursorSave = "\033[s" + CursorRestore = "\033[u" + + CursorUpN = "\033[%dA" + CursorDownN = "\033[%dB" + CursorRightN = "\033[%dC" + CursorLeftN = "\033[%dD" + + CursorEOL = "\033[E" + CursorBOL = "\033[1G" + CursorHide = "\033[?25l" + CursorShow = "\033[?25h" + + ClearToEOL = "\033[K" + ClearLine = "\033[2K" + ClearScreen = "\033[2J" + CursorReset = "\033[0;0f" + + ColorGrey = "\033[38;5;245m" + ColorDefault = "\033[0m" + + StartBracketedPaste = "\033[?2004h" + EndBracketedPaste = "\033[?2004l" +) + +const ( + CharBracketedPaste = 50 + CharBracketedPasteStart = 0 + CharBracketedPasteEnd = 1 +)