Add superbooga time weighted history retrieval (#2080)

This commit is contained in:
Luis Lopez 2023-05-25 21:22:45 +08:00 committed by GitHub
parent a04266161d
commit ee674afa50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 21 deletions

View file

@ -12,6 +12,8 @@ from .download_urls import download_urls
params = {
'chunk_count': 5,
'chunk_count_initial': 10,
'time_weight': 0,
'chunk_length': 700,
'chunk_separator': '',
'strong_cleanup': False,
@ -20,7 +22,6 @@ params = {
collector = make_collector()
chat_collector = make_collector()
chunk_count = 5
def feed_data_into_collector(corpus, chunk_len, chunk_sep):
@ -83,13 +84,12 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
yield i
def apply_settings(_chunk_count):
global chunk_count
chunk_count = int(_chunk_count)
settings_to_display = {
'chunk_count': chunk_count,
}
def apply_settings(chunk_count, chunk_count_initial, time_weight):
global params
params['chunk_count'] = int(chunk_count)
params['chunk_count_initial'] = int(chunk_count_initial)
params['time_weight'] = time_weight
settings_to_display = {k: params[k] for k in params if k in ['chunk_count', 'chunk_count_initial', 'time_weight']}
yield f"The following settings are now active: {str(settings_to_display)}"
@ -97,7 +97,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
global chat_collector
if state['mode'] == 'instruct':
results = collector.get_sorted(user_input, n_results=chunk_count)
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
user_input += additional_context
else:
@ -108,7 +108,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
return output
if len(shared.history['internal']) > chunk_count and user_input != '':
if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
chunks = []
hist_size = len(shared.history['internal'])
for i in range(hist_size-1):
@ -117,7 +117,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
add_chunks_to_collector(chunks, chat_collector)
query = '\n'.join(shared.history['internal'][-1] + [user_input])
try:
best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count)
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
additional_context = '\n'
for id_ in best_ids:
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
@ -151,7 +151,7 @@ def input_modifier(string):
user_input = match.group(1).strip()
# Get the most similar chunks
results = collector.get_sorted(user_input, n_results=chunk_count)
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
# Make the injection
string = string.replace('<|injection-point|>', '\n'.join(results))
@ -240,6 +240,10 @@ def ui():
with gr.Tab("Generation settings"):
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
time_weight = gr.Slider(0, 1, value=params['time_weight'], label='Time weight', info='Defines the strength of the time weighting. 0 = no time weighting.')
chunk_count_initial = gr.Number(value=params['chunk_count_initial'], label='Initial chunk count', info='The number of closest-matching chunks retrieved for time weight reordering in chat mode. This should be >= chunk count. -1 = All chunks are retrieved. Only used if time_weight > 0.')
update_settings = gr.Button('Apply changes')
chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
@ -250,4 +254,4 @@ def ui():
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)
update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)