Make /v1/embeddings functional, add request/response types
This commit is contained in:
parent
7ed2143cd6
commit
c5be3f7acb
6 changed files with 40 additions and 26 deletions
|
@ -31,6 +31,8 @@ from .typing import (
|
|||
CompletionResponse,
|
||||
DecodeRequest,
|
||||
DecodeResponse,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EncodeRequest,
|
||||
EncodeResponse,
|
||||
LoadModelRequest,
|
||||
|
@ -41,7 +43,7 @@ from .typing import (
|
|||
|
||||
params = {
|
||||
'embedding_device': 'cpu',
|
||||
'embedding_model': 'all-mpnet-base-v2',
|
||||
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
|
||||
'sd_webui_url': '',
|
||||
'debug': 0
|
||||
}
|
||||
|
@ -196,19 +198,16 @@ async def handle_image_generation(request: Request):
|
|||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/v1/embeddings")
|
||||
async def handle_embeddings(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
|
||||
input = body.get('input', body.get('text', ''))
|
||||
@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
|
||||
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
||||
input = request_data.input
|
||||
if not input:
|
||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
|
||||
response = OAIembeddings.embeddings(input, encoding_format)
|
||||
response = OAIembeddings.embeddings(input, request_data.encoding_format)
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue