extensions/api: models api for blocking_api (updated) (#2539)

This commit is contained in:
matatonic 2023-06-08 10:34:36 -04:00 committed by GitHub
parent 084b006cfe
commit 7be6fe126b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 256 additions and 2 deletions

View file

@ -6,7 +6,19 @@ from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import encode, generate_reply, stop_everything_event
from modules.models import load_model, unload_model
from modules.LoRA import add_lora_to_model
from modules.utils import get_available_models
from server import get_model_specific_settings, update_model_parameters
def get_model_info():
    """Snapshot the server's current model state for API responses.

    Returns a dict containing the active model and LoRA names, plus a
    raw dump of the global settings and parsed CLI arguments (useful
    for debugging clients).
    """
    info = {
        'model_name': shared.model_name,
        'lora_names': shared.lora_names,
    }
    # dump
    info['shared.settings'] = shared.settings
    info['shared.args'] = vars(shared.args)
    return info
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
@ -91,6 +103,67 @@ class Handler(BaseHTTPRequestHandler):
self.wfile.write(response.encode('utf-8'))
# POST /api/v1/model — model-management endpoint.
# Dispatches on body['action'] ('info' | 'load' | 'list' | 'unload');
# always replies with a JSON object of the form {'result': ...}.
elif self.path == '/api/v1/model':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
# by default return the same as the GET interface
result = shared.model_name
# Actions: info, load, list, unload
action = body.get('action', '')
if action == 'load':
# Switch models: apply any request-supplied CLI-style overrides to
# shared.args, tear down the current model, then load the new one.
model_name = body['model_name']
args = body.get('args', {})
print('args', args)
for k in args:
setattr(shared.args, k, args[k])
shared.model_name = model_name
unload_model()
# Pick up per-model settings before loading the replacement model.
model_settings = get_model_specific_settings(shared.model_name)
shared.settings.update(model_settings)
update_model_parameters(model_settings, initial=True)
if shared.settings['mode'] != 'instruct':
shared.settings['instruction_template'] = None
try:
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
add_lora_to_model(shared.args.lora) # list
except Exception as e:
# NOTE(review): the 200 status line was already sent above, so this
# error payload goes out under a success code; re-raising after the
# write surfaces the failure in the server log as well.
response = json.dumps({'error': { 'message': repr(e) } })
self.wfile.write(response.encode('utf-8'))
raise e
shared.args.model = shared.model_name
result = get_model_info()
elif action == 'unload':
# Drop the loaded model and clear the model-tracking globals.
unload_model()
shared.model_name = None
shared.args.model = None
result = get_model_info()
elif action == 'list':
result = get_available_models()
elif action == 'info':
result = get_model_info()
response = json.dumps({
'result': result,
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
self.send_response(200)
self.send_header('Content-Type', 'application/json')

View file

@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
@staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
    """Embed token ids using the loaded model's token-embedding layer.

    The embedding layer sits at a different attribute depth depending on
    the backend: plain transformers models expose ``model.embed_tokens``
    directly, while the AutoGPTQ wrapper nests the model one level deeper.
    The result is moved onto the model's device/dtype before returning.
    """
    # Fix: the extraction left the pre-change unconditional return in place,
    # which made the hasattr dispatch below unreachable dead code.
    if hasattr(shared.model.model, 'embed_tokens'):
        func = shared.model.model.embed_tokens
    else:
        func = shared.model.model.model.embed_tokens  # AutoGPTQ case
    return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
@staticmethod
def placeholder_embeddings() -> torch.Tensor: