diff --git a/server.py b/server.py index 1286dd0..0bd97ff 100644 --- a/server.py +++ b/server.py @@ -209,8 +209,8 @@ def list_model_parameters(): # Model parameters: update the command-line arguments based on the interface values def update_model_parameters(*args): - args = list(args) # the values of the parameters - elements = list_model_parameters() # the names of the parameters + args = list(args) # the values of the parameters + elements = list_model_parameters() # the names of the parameters gpu_memories = [] for i, element in enumerate(elements): @@ -232,8 +232,8 @@ def update_model_parameters(*args): elif element == 'cpu_memory' and args[i] is not None: args[i] = f"{args[i]}MiB" - #print(element, repr(eval(f"shared.args.{element}")), repr(args[i])) - #print(f"shared.args.{element} = args[i]") + # print(element, repr(eval(f"shared.args.{element}")), repr(args[i])) + # print(f"shared.args.{element} = args[i]") exec(f"shared.args.{element} = args[i]") found_positive = False @@ -251,7 +251,7 @@ def create_model_menus(): # Finding the default values for the GPU and CPU memories total_mem = [] for i in range(torch.cuda.device_count()): - total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024*1024))) + total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024))) default_gpu_mem = [] if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0: @@ -259,11 +259,11 @@ def create_model_menus(): if 'mib' in i.lower(): default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i))) else: - default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i))*1000) + default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000) while len(default_gpu_mem) < len(total_mem): default_gpu_mem.append(0) - total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024*1024)) + total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024)) if shared.args.cpu_memory is not None: default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory) else: @@ -441,16 +441,19 @@ else: if extension not in shared.args.extensions: shared.args.extensions.append(extension) -# Default model +# Model defined through --model if shared.args.model is not None: shared.model_name = shared.args.model - shared.model, shared.tokenizer = load_model(shared.model_name) + +# Only one model is available +elif len(available_models) == 1: + shared.model_name = available_models[0] + +# Select the model from a command-line menu elif shared.args.model_menu: if len(available_models) == 0: print('No models are available! Please download at least one.') sys.exit(0) - elif len(available_models) == 1: - i = 0 else: print('The following models are available:\n') for i, model in enumerate(available_models): @@ -459,10 +462,12 @@ elif shared.args.model_menu: i = int(input()) - 1 print() shared.model_name = available_models[i] - shared.model, shared.tokenizer = load_model(shared.model_name) -if shared.args.model is not None and shared.args.lora: - add_lora_to_model(shared.args.lora) +# If any model has been selected, load it +if shared.model_name != 'None': + shared.model, shared.tokenizer = load_model(shared.model_name) + if shared.args.lora: + add_lora_to_model(shared.args.lora) # Default UI settings default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')] @@ -685,14 +690,14 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( - #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") + generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then( + # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") ) gen_events.append(shared.gradio['textbox'].submit( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( - #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") + generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then( + # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") ) shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) @@ -744,20 +749,20 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( - #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") + generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then( + # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") ) gen_events.append(shared.gradio['textbox'].submit( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( - #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") + generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then( + # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") ) gen_events.append(shared.gradio['Continue'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then( - #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") + generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream) # .then( + # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") ) shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)