Add /v1/internal/lora endpoints (#4652)

This commit is contained in:
oobabooga 2023-11-19 00:35:22 -03:00 committed by GitHub
parent ef6feedeb2
commit 771e62e476
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 19 deletions

View file

@ -38,10 +38,13 @@ from .typing import (
EmbeddingsResponse,
EncodeRequest,
EncodeResponse,
LoadLorasRequest,
LoadModelRequest,
LogitsRequest,
LogitsResponse,
LoraListResponse,
ModelInfoResponse,
ModelListResponse,
TokenCountResponse,
to_dict
)
@ -141,7 +144,7 @@ async def handle_models(request: Request):
is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
if is_list:
response = OAImodels.list_models()
response = OAImodels.list_dummy_models()
else:
model_name = path[len('/v1/models/'):]
response = OAImodels.model_info_dict(model_name)
@ -267,6 +270,12 @@ async def handle_model_info():
return JSONResponse(content=payload)
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
async def handle_list_models():
payload = OAImodels.list_models()
return JSONResponse(content=payload)
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
async def handle_load_model(request_data: LoadModelRequest):
'''
@ -307,6 +316,27 @@ async def handle_load_model(request_data: LoadModelRequest):
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model():
unload_model()
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
async def handle_list_loras():
response = OAImodels.list_loras()
return JSONResponse(content=response)
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
async def handle_load_loras(request_data: LoadLorasRequest):
try:
OAImodels.load_loras(request_data.lora_names)
return JSONResponse(content="OK")
except:
traceback.print_exc()
return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
async def handle_unload_loras():
OAImodels.unload_all_loras()
return JSONResponse(content="OK")