Add /v1/internal/logits endpoint (#4650)

This commit is contained in:
oobabooga 2023-11-18 23:19:31 -03:00 committed by GitHub
parent 8f4f4daf8b
commit 0fa1af296c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 9 deletions

View file

@@ -16,6 +16,7 @@ from sse_starlette import EventSourceResponse
import extensions.openai.completions as OAIcompletions
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.errors import ServiceUnavailableError
@@ -38,6 +39,8 @@ from .typing import (
EncodeRequest,
EncodeResponse,
LoadModelRequest,
LogitsRequest,
LogitsResponse,
ModelInfoResponse,
TokenCountResponse,
to_dict
@@ -242,6 +245,16 @@ async def handle_token_count(request_data: EncodeRequest):
return JSONResponse(response)
@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
async def handle_logits(request_data: LogitsRequest):
    '''
    Given a prompt, returns the top 50 most likely logits as a dict.
    The keys are the tokens, and the values are the probabilities.
    '''
    # The logits helper is handed a plain dict rather than the request
    # model, so convert once up front and pass the result along.
    payload = to_dict(request_data)
    return JSONResponse(OAIlogits._get_next_logits(payload))
@app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request):
stop_everything_event()

View file

@@ -126,15 +126,15 @@ class EncodeRequest(BaseModel):
text: str
class DecodeRequest(BaseModel):
    # Integer token ids supplied by the caller.
    # NOTE(review): presumably the sequence to turn back into text -- confirm
    # against the handler that consumes this model.
    tokens: List[int]
class EncodeResponse(BaseModel):
    # Integer token ids returned to the caller.
    tokens: List[int]
    # NOTE(review): presumably len(tokens) -- confirm where the response
    # payload is assembled.
    length: int
class DecodeRequest(BaseModel):
    # Integer token ids supplied by the caller.
    # NOTE(review): presumably the sequence to turn back into text -- confirm
    # against the handler that consumes this model.
    tokens: List[int]
class DecodeResponse(BaseModel):
    # Single string field returned to the caller.
    # NOTE(review): presumably the decoded form of the request's tokens --
    # confirm against the handler.
    text: str
@@ -143,6 +143,24 @@ class TokenCountResponse(BaseModel):
length: int
class LogitsRequestParams(BaseModel):
    # The only field without a default, so every request must provide it.
    prompt: str

    # Off unless the caller opts in.
    # NOTE(review): presumably controls whether sampler settings are applied
    # before the logits are reported -- confirm in extensions.openai.logits.
    use_samplers: bool = False

    # Each field below accepts either a number or an explicit None and falls
    # back to the default shown when the caller omits it.
    frequency_penalty: float | None = 0
    max_tokens: int | None = 16
    presence_penalty: float | None = 0
    temperature: float | None = 1
    top_p: float | None = 1
class LogitsRequest(GenerationOptions, LogitsRequestParams):
    # Declares no fields of its own; the request schema is whatever the two
    # bases contribute. GenerationOptions comes first in the base list, so
    # it precedes LogitsRequestParams in the method resolution order.
    pass
class LogitsResponse(BaseModel):
    # Plain dict with no key or value types constrained here.
    # NOTE(review): the concrete shape is decided by whatever builds the
    # /v1/internal/logits response -- confirm there before relying on it.
    logits: dict
class ModelInfoResponse(BaseModel):
model_name: str
lora_names: List[str]