diff --git a/api-example.py b/api-example.py index 5138eb8..eff610c 100644 --- a/api-example.py +++ b/api-example.py @@ -36,10 +36,10 @@ params = { 'early_stopping': False, 'seed': -1, 'add_bos_token': True, - 'custom_stopping_strings': [], 'truncation_length': 2048, 'ban_eos_token': False, 'skip_special_tokens': True, + 'stopping_strings': [], } # Input prompt diff --git a/modules/api.py b/modules/api.py index b57cfe8..9de8e25 100644 --- a/modules/api.py +++ b/modules/api.py @@ -29,14 +29,16 @@ def generate_reply_wrapper(string): 'early_stopping': False, 'seed': -1, 'add_bos_token': True, - 'custom_stopping_strings': [], + 'custom_stopping_strings': '', 'truncation_length': 2048, 'ban_eos_token': False, 'skip_special_tokens': True, + 'stopping_strings': [], } params = json.loads(string) generate_params.update(params[1]) - for i in generate_reply(params[0], generate_params): + stopping_strings = generate_params.pop('stopping_strings') + for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings): yield i