New universal API with streaming/blocking endpoints (#990)

Previous title: Add api_streaming extension and update api-example-stream to use it

* Merge with latest main

* Add parameter capturing encoder_repetition_penalty

* Change some defaults, minor fixes

* Add --api, --public-api flags

* remove unneeded/broken comment from blocking API startup. The comment is already correctly emitted in try_start_cloudflared by calling the lambda we pass in.

* Update on_start message for blocking_api, it should say 'non-streaming' and not 'streaming'

* Update the API examples

* Change a comment

* Update README

* Remove the gradio API

* Remove unused import

* Minor change

* Remove unused import

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Andy Salerno 2023-04-23 11:52:43 -07:00 committed by GitHub
parent 459e725af9
commit 654933c634
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 346 additions and 286 deletions

View file

@ -1,52 +0,0 @@
import json
import gradio as gr
from modules import shared
from modules.text_generation import generate_reply
# set this to True to rediscover the fn_index using the browser DevTools
VISIBLE = False
def generate_reply_wrapper(string):
# Provide defaults so as to not break the API on the client side when new parameters are added
generate_params = {
'max_new_tokens': 200,
'do_sample': True,
'temperature': 0.5,
'top_p': 1,
'typical_p': 1,
'repetition_penalty': 1.1,
'encoder_repetition_penalty': 1,
'top_k': 0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'seed': -1,
'add_bos_token': True,
'custom_stopping_strings': '',
'truncation_length': 2048,
'ban_eos_token': False,
'skip_special_tokens': True,
'stopping_strings': [],
}
params = json.loads(string)
generate_params.update(params[1])
stopping_strings = generate_params.pop('stopping_strings')
for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings):
yield i
def create_apis():
t1 = gr.Textbox(visible=VISIBLE)
t2 = gr.Textbox(visible=VISIBLE)
dummy = gr.Button(visible=VISIBLE)
input_params = [t1]
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')

View file

@ -14,7 +14,8 @@ def load_extensions():
global state, setup_called
for i, name in enumerate(shared.args.extensions):
if name in available_extensions:
print(f'Loading the extension "{name}"... ', end='')
if name != 'api':
print(f'Loading the extension "{name}"... ', end='')
try:
exec(f"import extensions.{name}.script")
extension = getattr(extensions, name).script
@ -22,9 +23,11 @@ def load_extensions():
setup_called.add(extension)
extension.setup()
state[name] = [True, i]
print('Ok.')
if name != 'api':
print('Ok.')
except:
print('Fail.')
if name != 'api':
print('Fail.')
traceback.print_exc()

View file

@ -150,6 +150,11 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
# API
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
args = parser.parse_args()
args_defaults = parser.parse_args([])
@ -171,6 +176,13 @@ if args.trust_remote_code:
if args.share:
print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")
# Activating the API extension
if args.api or args.public_api:
if args.extensions is None:
args.extensions = ['api']
elif 'api' not in args.extensions:
args.extensions.append('api')
def is_chat():
return args.chat