extensions/api: models api for blocking_api (updated) (#2539)
This commit is contained in:
parent
084b006cfe
commit
7be6fe126b
4 changed files with 256 additions and 2 deletions
|
@ -6,7 +6,19 @@ from extensions.api.util import build_parameters, try_start_cloudflared
|
|||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import encode, generate_reply, stop_everything_event
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.utils import get_available_models
|
||||
from server import get_model_specific_settings, update_model_parameters
|
||||
|
||||
def get_model_info():
|
||||
return {
|
||||
'model_name': shared.model_name,
|
||||
'lora_names': shared.lora_names,
|
||||
# dump
|
||||
'shared.settings': shared.settings,
|
||||
'shared.args': vars(shared.args),
|
||||
}
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
|
@ -91,6 +103,67 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/model':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
# by default return the same as the GET interface
|
||||
result = shared.model_name
|
||||
|
||||
# Actions: info, load, list, unload
|
||||
action = body.get('action', '')
|
||||
|
||||
if action == 'load':
|
||||
model_name = body['model_name']
|
||||
args = body.get('args', {})
|
||||
print('args', args)
|
||||
for k in args:
|
||||
setattr(shared.args, k, args[k])
|
||||
|
||||
shared.model_name = model_name
|
||||
unload_model()
|
||||
|
||||
model_settings = get_model_specific_settings(shared.model_name)
|
||||
shared.settings.update(model_settings)
|
||||
update_model_parameters(model_settings, initial=True)
|
||||
|
||||
if shared.settings['mode'] != 'instruct':
|
||||
shared.settings['instruction_template'] = None
|
||||
|
||||
try:
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
if shared.args.lora:
|
||||
add_lora_to_model(shared.args.lora) # list
|
||||
|
||||
except Exception as e:
|
||||
response = json.dumps({'error': { 'message': repr(e) } })
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
raise e
|
||||
|
||||
shared.args.model = shared.model_name
|
||||
|
||||
result = get_model_info()
|
||||
|
||||
elif action == 'unload':
|
||||
unload_model()
|
||||
shared.model_name = None
|
||||
shared.args.model = None
|
||||
result = get_model_info()
|
||||
|
||||
elif action == 'list':
|
||||
result = get_available_models()
|
||||
|
||||
elif action == 'info':
|
||||
result = get_model_info()
|
||||
|
||||
response = json.dumps({
|
||||
'result': result,
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
|
|
|
@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
|||
|
||||
@staticmethod
|
||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
||||
if hasattr(shared.model.model, 'embed_tokens'):
|
||||
func = shared.model.model.embed_tokens
|
||||
else:
|
||||
func = shared.model.model.model.embed_tokens # AutoGPTQ case
|
||||
|
||||
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
||||
|
||||
@staticmethod
|
||||
def placeholder_embeddings() -> torch.Tensor:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue