parent 7dc87984a2
commit d37a28730d
3 changed files with 14 additions and 1 deletion
@@ -1,6 +1,7 @@
 import ast
 import random
 import re
+import threading
 import time
 import traceback
 
@@ -17,6 +18,15 @@ from modules.logging_colors import logger
 from modules.models import clear_torch_cache, local_rank
 
 
+def generate_reply(*args, **kwargs):
+    shared.generation_lock.acquire()
+    try:
+        for result in _generate_reply(*args, **kwargs):
+            yield result
+    finally:
+        shared.generation_lock.release()
+
+
 def get_max_prompt_length(state):
     max_length = state['truncation_length'] - state['max_new_tokens']
     if shared.soft_prompt:
@@ -154,7 +164,7 @@ def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None
     yield formatted_outputs(reply, shared.model_name)
 
 
-def generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
+def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
     state = apply_extensions('state', state)
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
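The hunks above appear to come from modules/text_generation.py; the other two changed files are not shown on this page, but the new wrapper implies that one of them initializes shared.generation_lock, presumably as a threading.Lock. Below is a minimal, self-contained sketch of the same lock-around-a-generator pattern; the module-level generation_lock and the dummy _generate_reply body are illustrative stand-ins, not code from the commit:

import threading
import time

# Stand-in for shared.generation_lock (assumed to be a threading.Lock
# created in one of the other changed files).
generation_lock = threading.Lock()

def _generate_reply(prompt):
    # Stand-in for the real streaming generator: yields partial results.
    for i in range(3):
        time.sleep(0.1)
        yield f"{prompt} ... token {i}"

def generate_reply(*args, **kwargs):
    # Same shape as the commit's wrapper: hold the lock for the whole
    # stream, and release it in a finally block so it is freed even if
    # generation raises or the consumer abandons the stream early.
    generation_lock.acquire()
    try:
        for result in _generate_reply(*args, **kwargs):
            yield result
    finally:
        generation_lock.release()

if __name__ == "__main__":
    for chunk in generate_reply("hello"):
        print(chunk)

Because generate_reply is itself a generator, the lock is only acquired when the caller requests the first value, and closing the generator early raises GeneratorExit inside the try block, so the finally clause still releases the lock. The net effect of the commit is that only one generation can run at a time, with callers queuing on the lock rather than interleaving output.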