Merge branch 'main' into Zerogoki00-opt4-bit
This commit is contained in:
commit
3da73e409f
11 changed files with 159 additions and 50 deletions
|
@ -126,8 +126,9 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
|||
else:
|
||||
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not regenerate:
|
||||
yield shared.history['visible']+[[visible_text, '*Is typing...*']]
|
||||
yield shared.history['visible']+[[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
reply = ''
|
||||
|
@ -168,7 +169,8 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
|
|||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
|
||||
|
||||
reply = ''
|
||||
yield '*Is typing...*'
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
for i in range(chat_generation_attempts):
|
||||
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
|
||||
|
@ -187,8 +189,8 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
|
|||
else:
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
|
||||
yield generate_chat_output(shared.history['visible']+[[last_visible[0], '*Is typing...*']], name1, name2, shared.character)
|
||||
# Yield '*Is typing...*'
|
||||
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
|
||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||
if shared.args.cai_chat:
|
||||
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
||||
|
|
|
@ -11,6 +11,7 @@ is_RWKV = False
|
|||
history = {'internal': [], 'visible': []}
|
||||
character = 'None'
|
||||
stop_everything = False
|
||||
processing_message = '*Is typing...*'
|
||||
|
||||
# UI elements (buttons, sliders, HTML, etc)
|
||||
gradio = {}
|
||||
|
@ -85,12 +86,12 @@ parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory t
|
|||
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
|
||||
parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
|
||||
parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
|
||||
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
|
||||
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
|
||||
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
|
||||
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
||||
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
|
||||
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
|
||||
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
|
||||
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
|
||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch')
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -123,7 +123,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id]
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue