Refactor text_generation.py, add support for custom generation functions (#1817)

This commit is contained in:
oobabooga 2023-05-05 18:53:03 -03:00 committed by GitHub
parent 876fbb97c0
commit 8aafb1f796
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 289 additions and 195 deletions

View file

@ -18,15 +18,8 @@ wav_idx = 0
user = ElevenLabsUser(params['api_key'])
user_info = None
if not shared.args.no_stream:
print("Please add --no-stream. This extension is not meant to be used with streaming.")
raise ValueError
# Check if the API is valid and refresh the UI accordingly.
def check_valid_api():
global user, user_info, params
user = ElevenLabsUser(params['api_key'])
@ -41,9 +34,8 @@ def check_valid_api():
print('Got an API Key!')
return gr.update(value='Connected')
# Once the API is verified, get the available voices and update the dropdown list
def refresh_voices():
global user, user_info
@ -63,6 +55,11 @@ def remove_surrounded_chars(string):
return re.sub('\*[^\*]*?(\*|$)', '', string)
def state_modifier(state):
state['stream'] = False
return state
def input_modifier(string):
"""
This function is applied to your text inputs before
@ -109,6 +106,7 @@ def ui():
with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
with gr.Row():
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')