diff --git a/server.py b/server.py index 3e31377..db83b4f 100644 --- a/server.py +++ b/server.py @@ -50,26 +50,20 @@ def get_available_softprompts(): def get_available_loras(): return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) +def unload_model(): + shared.model = shared.tokenizer = None + clear_torch_cache() + def load_model_wrapper(selected_model): if selected_model != shared.model_name: shared.model_name = selected_model - shared.model = shared.tokenizer = None - clear_torch_cache() - shared.model, shared.tokenizer = load_model(shared.model_name) + + unload_model() + if selected_model != '': + shared.model, shared.tokenizer = load_model(shared.model_name) return selected_model -def reload_model(): - unload_model() - shared.model, shared.tokenizer = load_model(shared.model_name) - -def unload_model(): - shared.model = shared.tokenizer = None - if not shared.args.cpu: - gc.collect() - torch.cuda.empty_cache() - print("Model weights unloaded.") - def load_lora_wrapper(selected_lora): add_lora_to_model(selected_lora) default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] @@ -128,9 +122,6 @@ def create_model_and_preset_menus(): with gr.Row(): shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') - with gr.Row(): - shared.gradio['unload_model'] = gr.Button(value='Unload model to free VRAM', elem_id="unload_model") - shared.gradio['reload_model'] = gr.Button(value='Reload the model into VRAM', elem_id="reload_model") def create_settings_menus(default_preset): generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) @@ -185,8 +176,6 @@ def create_settings_menus(default_preset): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) - shared.gradio['unload_model'].click(fn=unload_model,inputs=[],outputs=[]) - shared.gradio['reload_model'].click(fn=reload_model,inputs=[],outputs=[]) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)