Add /v1/internal/model/load endpoint (tentative)

This commit is contained in:
oobabooga 2023-11-07 20:58:06 -08:00
parent 43c53a7820
commit 2358706453
5 changed files with 47 additions and 4 deletions

View file

@@ -1,5 +1,6 @@
import json
import os
import traceback
from threading import Thread
import extensions.openai.completions as OAIcompletions
@@ -31,6 +32,7 @@ from .typing import (
DecodeResponse,
EncodeRequest,
EncodeResponse,
LoadModelRequest,
ModelInfoResponse,
TokenCountResponse,
to_dict
@@ -231,12 +233,22 @@ async def handle_stop_generation(request: Request):
return JSONResponse(content="OK")
@app.get("/v1/internal/model-info", response_model=ModelInfoResponse)
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse)
async def handle_model_info():
    """Return metadata about the currently loaded model.

    Exposed under both the legacy dashed path and the newer
    slash-separated path; both routes share this handler.
    """
    return JSONResponse(content=OAImodels.get_current_model_info())
@app.post("/v1/internal/model/load")
async def handle_load_model(request_data: LoadModelRequest):
    """Load a model on the server from a client-supplied request body.

    On success returns a plain "OK" JSON body. On any failure the full
    traceback is printed server-side and a 400 is returned to the client.

    Raises:
        HTTPException: status 400 when model loading fails for any reason.
    """
    try:
        OAImodels._load_model(to_dict(request_data))
        return JSONResponse(content="OK")
    # Catch Exception, not a bare `except:` — a bare clause would also
    # swallow SystemExit/KeyboardInterrupt and block server shutdown.
    except Exception:
        traceback.print_exc()
        # Must be *raised*, not returned: FastAPI only converts a raised
        # HTTPException into an error response. Returning the object would
        # serialize it as the body of a 200 response, hiding the failure.
        raise HTTPException(status_code=400, detail="Failed to load the model.")
def run_server():
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))