Merge pull request #73 from jmorganca/generate-eof

fix eof error in generate
This commit is contained in:
Michael Yang 2023-07-12 11:09:23 -07:00 committed by GitHub
commit 5571ed5248
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 20 deletions

View file

@ -59,7 +59,7 @@ func pull(model string) error {
&api.PullRequest{Model: model}, &api.PullRequest{Model: model},
func(progress api.PullProgress) error { func(progress api.PullProgress) error {
if bar == nil { if bar == nil {
if progress.Percent == 100 { if progress.Percent >= 100 {
// already downloaded // already downloaded
return nil return nil
} }
@ -73,10 +73,9 @@ func pull(model string) error {
} }
func RunGenerate(_ *cobra.Command, args []string) error { func RunGenerate(_ *cobra.Command, args []string) error {
// join all args into a single prompt
prompt := strings.Join(args[1:], " ")
if len(args) > 1 { if len(args) > 1 {
return generate(args[0], prompt) // join all args into a single prompt
return generate(args[0], strings.Join(args[1:], " "))
} }
if term.IsTerminal(int(os.Stdin.Fd())) { if term.IsTerminal(int(os.Stdin.Fd())) {

View file

@ -199,10 +199,10 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
token, err := llm.sample(pastTokens, &opts) token, err := llm.sample(pastTokens, &opts)
switch { switch {
case err != nil:
return err
case errors.Is(err, io.EOF): case errors.Is(err, io.EOF):
return nil return nil
case err != nil:
return err
} }
fn(llm.detokenize(token)) fn(llm.detokenize(token))

View file

@ -119,25 +119,22 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
} }
defer out.Close() defer out.Close()
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
completed := size
totalBytes := size total := remaining + completed
totalSize += size
for { for {
n, err := io.CopyN(out, resp.Body, 8192) fn(total, completed)
if completed >= total {
return os.Rename(model.TempFile(), model.FullName())
}
n , err := io.CopyN(out, resp.Body, 8192)
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
return err return err
} }
if n == 0 { completed += n
break
}
totalBytes += n
fn(totalSize, totalBytes)
} }
fn(totalSize, totalSize)
return os.Rename(model.TempFile(), model.FullName())
} }

View file

@ -112,7 +112,7 @@ func pull(c *gin.Context) {
ch <- api.PullProgress{ ch <- api.PullProgress{
Total: total, Total: total,
Completed: completed, Completed: completed,
Percent: float64(total) / float64(completed) * 100, Percent: float64(completed) / float64(total) * 100,
} }
} }