Merge pull request #73 from jmorganca/generate-eof
fix eof error in generate
This commit is contained in:
commit
5571ed5248
4 changed files with 16 additions and 20 deletions
|
@ -59,7 +59,7 @@ func pull(model string) error {
|
||||||
&api.PullRequest{Model: model},
|
&api.PullRequest{Model: model},
|
||||||
func(progress api.PullProgress) error {
|
func(progress api.PullProgress) error {
|
||||||
if bar == nil {
|
if bar == nil {
|
||||||
if progress.Percent == 100 {
|
if progress.Percent >= 100 {
|
||||||
// already downloaded
|
// already downloaded
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -73,10 +73,9 @@ func pull(model string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunGenerate(_ *cobra.Command, args []string) error {
|
func RunGenerate(_ *cobra.Command, args []string) error {
|
||||||
// join all args into a single prompt
|
|
||||||
prompt := strings.Join(args[1:], " ")
|
|
||||||
if len(args) > 1 {
|
if len(args) > 1 {
|
||||||
return generate(args[0], prompt)
|
// join all args into a single prompt
|
||||||
|
return generate(args[0], strings.Join(args[1:], " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
if term.IsTerminal(int(os.Stdin.Fd())) {
|
if term.IsTerminal(int(os.Stdin.Fd())) {
|
||||||
|
|
|
@ -199,10 +199,10 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
|
||||||
|
|
||||||
token, err := llm.sample(pastTokens, &opts)
|
token, err := llm.sample(pastTokens, &opts)
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
|
||||||
return err
|
|
||||||
case errors.Is(err, io.EOF):
|
case errors.Is(err, io.EOF):
|
||||||
return nil
|
return nil
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(llm.detokenize(token))
|
fn(llm.detokenize(token))
|
||||||
|
|
|
@ -119,25 +119,22 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
|
||||||
}
|
}
|
||||||
defer out.Close()
|
defer out.Close()
|
||||||
|
|
||||||
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||||
|
completed := size
|
||||||
|
|
||||||
totalBytes := size
|
total := remaining + completed
|
||||||
totalSize += size
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
fn(total, completed)
|
||||||
|
if completed >= total {
|
||||||
|
return os.Rename(model.TempFile(), model.FullName())
|
||||||
|
}
|
||||||
|
|
||||||
n , err := io.CopyN(out, resp.Body, 8192)
|
n , err := io.CopyN(out, resp.Body, 8192)
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if n == 0 {
|
completed += n
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
totalBytes += n
|
|
||||||
fn(totalSize, totalBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn(totalSize, totalSize)
|
|
||||||
return os.Rename(model.TempFile(), model.FullName())
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,7 +112,7 @@ func pull(c *gin.Context) {
|
||||||
ch <- api.PullProgress{
|
ch <- api.PullProgress{
|
||||||
Total: total,
|
Total: total,
|
||||||
Completed: completed,
|
Completed: completed,
|
||||||
Percent: float64(total) / float64(completed) * 100,
|
Percent: float64(completed) / float64(total) * 100,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue