Revert "Bump llama-cpp-python to 0.2.18 (#4611)"

This reverts commit 923c8e25fb.
This commit is contained in:
oobabooga 2023-11-17 05:14:25 -08:00
parent e0a7cc5e0f
commit 9d6f79db74
17 changed files with 174 additions and 92 deletions

View file

@ -1,7 +1,6 @@
import re
from functools import partial
import llama_cpp
import numpy as np
import torch
@ -10,6 +9,23 @@ from modules.callbacks import Iteratorize
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
try:
import llama_cpp
except:
llama_cpp = None
try:
import llama_cpp_cuda
except:
llama_cpp_cuda = None
def llama_cpp_lib():
if (shared.args.cpu and llama_cpp is not None) or llama_cpp_cuda is None:
return llama_cpp
else:
return llama_cpp_cuda
def ban_eos_logits_processor(eos_token, input_ids, logits):
logits[eos_token] = -float('inf')
@ -34,6 +50,10 @@ class LlamaCppModel:
@classmethod
def from_pretrained(self, path):
Llama = llama_cpp_lib().Llama
LlamaCache = llama_cpp_lib().LlamaCache
result = self()
cache_capacity = 0
if shared.args.cache_capacity is not None:
@ -54,6 +74,7 @@ class LlamaCppModel:
params = {
'model_path': str(path),
'n_ctx': shared.args.n_ctx,
'seed': int(shared.args.llama_cpp_seed),
'n_threads': shared.args.threads or None,
'n_threads_batch': shared.args.threads_batch or None,
'n_batch': shared.args.n_batch,
@ -67,9 +88,9 @@ class LlamaCppModel:
'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
}
result.model = llama_cpp.Llama(**params)
result.model = Llama(**params)
if cache_capacity > 0:
result.model.set_cache(llama_cpp.LlamaCache(capacity_bytes=cache_capacity))
result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
# This is ugly, but the model and the tokenizer are the same object in this library.
return result, result
@ -93,13 +114,13 @@ class LlamaCppModel:
if string != self.grammar_string:
self.grammar_string = string
if string.strip() != '':
self.grammar = llama_cpp.LlamaGrammar.from_string(string)
self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string)
else:
self.grammar = None
def generate(self, prompt, state, callback=None):
LogitsProcessorList = llama_cpp.LogitsProcessorList
LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
prompt = prompt if type(prompt) is str else prompt.decode()
@ -123,16 +144,15 @@ class LlamaCppModel:
max_tokens=state['max_new_tokens'],
temperature=state['temperature'],
top_p=state['top_p'],
frequency_penalty=state['frequency_penalty'],
presence_penalty=state['presence_penalty'],
repeat_penalty=state['repetition_penalty'],
top_k=state['top_k'],
stream=True,
seed=int(state['seed']) if state['seed'] != -1 else None,
repeat_penalty=state['repetition_penalty'],
presence_penalty=state['presence_penalty'],
frequency_penalty=state['frequency_penalty'],
tfs_z=state['tfs'],
mirostat_mode=int(state['mirostat_mode']),
mirostat_tau=state['mirostat_tau'],
mirostat_eta=state['mirostat_eta'],
stream=True,
logits_processor=logit_processors,
grammar=self.grammar
)