New universal API with streaming/blocking endpoints (#990)
Previous title: Add api_streaming extension and update api-example-stream to use it * Merge with latest main * Add parameter capturing encoder_repetition_penalty * Change some defaults, minor fixes * Add --api, --public-api flags * remove unneeded/broken comment from blocking API startup. The comment is already correctly emitted in try_start_cloudflared by calling the lambda we pass in. * Update on_start message for blocking_api, it should say 'non-streaming' and not 'streaming' * Update the API examples * Change a comment * Update README * Remove the gradio API * Remove unused import * Minor change * Remove unused import --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
459e725af9
commit
654933c634
12 changed files with 346 additions and 286 deletions
|
@ -1,52 +0,0 @@
|
|||
import json
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
# set this to True to rediscover the fn_index using the browser DevTools
|
||||
VISIBLE = False
|
||||
|
||||
|
||||
def generate_reply_wrapper(string):
|
||||
|
||||
# Provide defaults so as to not break the API on the client side when new parameters are added
|
||||
generate_params = {
|
||||
'max_new_tokens': 200,
|
||||
'do_sample': True,
|
||||
'temperature': 0.5,
|
||||
'top_p': 1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.1,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'custom_stopping_strings': '',
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': [],
|
||||
}
|
||||
params = json.loads(string)
|
||||
generate_params.update(params[1])
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings):
|
||||
yield i
|
||||
|
||||
|
||||
def create_apis():
|
||||
t1 = gr.Textbox(visible=VISIBLE)
|
||||
t2 = gr.Textbox(visible=VISIBLE)
|
||||
dummy = gr.Button(visible=VISIBLE)
|
||||
|
||||
input_params = [t1]
|
||||
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
|
||||
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')
|
|
@ -14,7 +14,8 @@ def load_extensions():
|
|||
global state, setup_called
|
||||
for i, name in enumerate(shared.args.extensions):
|
||||
if name in available_extensions:
|
||||
print(f'Loading the extension "{name}"... ', end='')
|
||||
if name != 'api':
|
||||
print(f'Loading the extension "{name}"... ', end='')
|
||||
try:
|
||||
exec(f"import extensions.{name}.script")
|
||||
extension = getattr(extensions, name).script
|
||||
|
@ -22,9 +23,11 @@ def load_extensions():
|
|||
setup_called.add(extension)
|
||||
extension.setup()
|
||||
state[name] = [True, i]
|
||||
print('Ok.')
|
||||
if name != 'api':
|
||||
print('Ok.')
|
||||
except:
|
||||
print('Fail.')
|
||||
if name != 'api':
|
||||
print('Fail.')
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
|
|
@ -150,6 +150,11 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
|
|||
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
|
||||
# API
|
||||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
args_defaults = parser.parse_args([])
|
||||
|
||||
|
@ -171,6 +176,13 @@ if args.trust_remote_code:
|
|||
if args.share:
|
||||
print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")
|
||||
|
||||
# Activating the API extension
|
||||
if args.api or args.public_api:
|
||||
if args.extensions is None:
|
||||
args.extensions = ['api']
|
||||
elif 'api' not in args.extensions:
|
||||
args.extensions.append('api')
|
||||
|
||||
|
||||
def is_chat():
|
||||
return args.chat
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue