Add /v1/internal/lora endpoints (#4652)

This commit is contained in:
oobabooga 2023-11-19 00:35:22 -03:00 committed by GitHub
parent ef6feedeb2
commit 771e62e476
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 19 deletions

View file

@ -1,8 +1,9 @@
from modules import shared
from modules.logging_colors import logger
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.utils import get_available_models
from modules.utils import get_available_loras, get_available_models
def get_current_model_info():
@ -13,12 +14,17 @@ def get_current_model_info():
def list_models():
    """Return the names of every locally available model.

    The first entry from get_available_models() is skipped —
    presumably a 'None'/placeholder entry; confirm against modules.utils.
    """
    available = get_available_models()
    return {'model_names': available[1:]}
def list_dummy_models():
    """Return an OpenAI-style model list populated with dummy entries.

    Returns:
        dict: {"object": "list", "data": [...]} where each element is the
        model_info_dict() of a well-known OpenAI model name.
    """
    result = {
        "object": "list",
        "data": []
    }

    # NOTE: the diff view interleaved the removed loop header
    # (get_dummy_models() + get_available_models()[1:]) with this one;
    # only the final loop below is kept.
    # These names are expected by many OpenAI API clients, so include
    # them here as dummies even though they are not real local models.
    for model in ['gpt-3.5-turbo', 'text-embedding-ada-002']:
        result["data"].append(model_info_dict(model))

    return result
@ -33,13 +39,6 @@ def model_info_dict(model_name: str) -> dict:
}
def get_dummy_models() -> list:
    """Return placeholder model names that many OpenAI clients expect to exist."""
    # Advertised as dummies: a chat model and an embedding model.
    dummies = ['gpt-3.5-turbo', 'text-embedding-ada-002']
    return dummies
def _load_model(data):
model_name = data["model_name"]
args = data["args"]
@ -67,3 +66,15 @@ def _load_model(data):
logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
elif k == 'instruction_template':
logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
def list_loras():
    """Return the names of every available LoRA.

    The first entry from get_available_loras() is dropped — presumably a
    'None'/placeholder entry, mirroring list_models(); confirm in modules.utils.
    """
    available = get_available_loras()
    return {'lora_names': available[1:]}
def load_loras(lora_names):
    """Attach the given LoRAs to the currently loaded model via add_lora_to_model."""
    add_lora_to_model(lora_names)
def unload_all_loras():
    """Detach every LoRA by re-applying an empty LoRA list."""
    add_lora_to_model([])