import argparse from llama_cpp import Llama parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin") parser.add_argument("-p", "--prompt", type=str, default="def add(") parser.add_argument("-s", "--suffix", type=str, default="\n return sum\n\n") parser.add_argument("-i", "--spm-infill", action='store_true') args = parser.parse_args() llm = Llama(model_path=args.model, n_gpu_layers=-1, spm_infill=args.spm_infill) output = llm.create_completion( temperature = 0.0, repeat_penalty = 1.0, prompt = args.prompt, suffix = args.suffix, ) # Models sometimes repeat suffix in response, attempt to filter that response = output["choices"][0]["text"] response_stripped = response.rstrip() unwanted_response_suffix = args.suffix.rstrip() unwanted_response_length = len(unwanted_response_suffix) filtered = False if unwanted_response_suffix and response_stripped[-unwanted_response_length:] == unwanted_response_suffix: response = response_stripped[:-unwanted_response_length] filtered = True print(f"Fill-in-Middle completion{' (filtered)' if filtered else ''}:\n\n{args.prompt}\033[32m{response}\033[{'33' if filtered else '0'}m{args.suffix}\033[0m")