Make /v1/embeddings functional, add request/response types

This commit is contained in:
oobabooga 2023-11-10 07:34:27 -08:00
parent 7ed2143cd6
commit c5be3f7acb
6 changed files with 40 additions and 26 deletions

View file

@ -31,6 +31,8 @@ from .typing import (
CompletionResponse,
DecodeRequest,
DecodeResponse,
EmbeddingsRequest,
EmbeddingsResponse,
EncodeRequest,
EncodeResponse,
LoadModelRequest,
@ -41,7 +43,7 @@ from .typing import (
params = {
'embedding_device': 'cpu',
'embedding_model': 'all-mpnet-base-v2',
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
'sd_webui_url': '',
'debug': 0
}
@ -196,19 +198,16 @@ async def handle_image_generation(request: Request):
return JSONResponse(response)
@app.post("/v1/embeddings")
async def handle_embeddings(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
input = body.get('input', body.get('text', ''))
@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
input = request_data.input
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
if type(input) is str:
input = [input]
response = OAIembeddings.embeddings(input, encoding_format)
response = OAIembeddings.embeddings(input, request_data.encoding_format)
return JSONResponse(response)