Refactor text_generation.py, add support for custom generation functions (#1817)

This commit is contained in:
oobabooga 2023-05-05 18:53:03 -03:00 committed by GitHub
parent 876fbb97c0
commit 8aafb1f796
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 289 additions and 195 deletions

View file

@ -33,6 +33,7 @@ class Handler(BaseHTTPRequestHandler):
prompt = body['prompt']
generate_params = build_parameters(body)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = False
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
@ -66,7 +67,7 @@ class Handler(BaseHTTPRequestHandler):
self.send_error(404)
def _run_server(port: int, share: bool=False):
def _run_server(port: int, share: bool = False):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler)

View file

@ -23,6 +23,7 @@ async def _handle_connection(websocket, path):
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)