Add ExLlama support (#2444)
This commit is contained in:
parent
dea43685b0
commit
9f40032d32
12 changed files with 156 additions and 47 deletions
|
@ -59,18 +59,18 @@ class LlamaCppModel:
|
|||
|
||||
return self.model.tokenize(string)
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
|
||||
context = context if type(context) is str else context.decode()
|
||||
def generate(self, prompt, state, callback=None):
|
||||
prompt = prompt if type(prompt) is str else prompt.decode()
|
||||
completion_chunks = self.model.create_completion(
|
||||
prompt=context,
|
||||
max_tokens=token_count,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
repeat_penalty=repetition_penalty,
|
||||
mirostat_mode=int(mirostat_mode),
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
prompt=prompt,
|
||||
max_tokens=state['max_new_tokens'],
|
||||
temperature=state['temperature'],
|
||||
top_p=state['top_p'],
|
||||
top_k=state['top_k'],
|
||||
repeat_penalty=state['repetition_penalty'],
|
||||
mirostat_mode=int(state['mirostat_mode']),
|
||||
mirostat_tau=state['mirostat_tau'],
|
||||
mirostat_eta=state['mirostat_eta'],
|
||||
stream=True
|
||||
)
|
||||
|
||||
|
@ -83,8 +83,8 @@ class LlamaCppModel:
|
|||
|
||||
return output
|
||||
|
||||
def generate_with_streaming(self, **kwargs):
|
||||
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||
def generate_with_streaming(self, *args, **kwargs):
|
||||
with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
|
||||
reply = ''
|
||||
for token in generator:
|
||||
reply += token
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue