Use argparse defaults
This commit is contained in:
parent
43e01282b3
commit
3a337cfded
2 changed files with 5 additions and 11 deletions
15
server.py
15
server.py
|
@ -188,14 +188,7 @@ def download_model_wrapper(repo_id):
|
|||
def update_model_parameters(state, initial=False):
|
||||
elements = ui.list_model_elements() # the names of the parameters
|
||||
gpu_memories = []
|
||||
defaults = {
|
||||
'wbits': 0,
|
||||
'groupsize': -1,
|
||||
'cpu_memory': None,
|
||||
'gpu_memory': None,
|
||||
'model_type': None,
|
||||
'pre_layer': 0
|
||||
}
|
||||
|
||||
for i, element in enumerate(elements):
|
||||
if element not in state:
|
||||
continue
|
||||
|
@ -205,14 +198,14 @@ def update_model_parameters(state, initial=False):
|
|||
gpu_memories.append(value)
|
||||
continue
|
||||
|
||||
if initial and eval(f"shared.args.{element}") != defaults[element]:
|
||||
if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
|
||||
continue
|
||||
|
||||
# Setting null defaults
|
||||
if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
|
||||
value = defaults[element]
|
||||
value = vars(shared.args_defaults)[element]
|
||||
elif element in ['cpu_memory'] and value == 0:
|
||||
value = defaults[element]
|
||||
value = vars(shared.args_defaults)[element]
|
||||
|
||||
# Making some simple conversions
|
||||
if element in ['wbits', 'groupsize', 'pre_layer']:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue