Add superbooga time weighted history retrieval (#2080)
This commit is contained in:
parent
a04266161d
commit
ee674afa50
2 changed files with 49 additions and 21 deletions
|
@ -12,6 +12,8 @@ from .download_urls import download_urls
|
|||
|
||||
params = {
|
||||
'chunk_count': 5,
|
||||
'chunk_count_initial': 10,
|
||||
'time_weight': 0,
|
||||
'chunk_length': 700,
|
||||
'chunk_separator': '',
|
||||
'strong_cleanup': False,
|
||||
|
@ -20,7 +22,6 @@ params = {
|
|||
|
||||
collector = make_collector()
|
||||
chat_collector = make_collector()
|
||||
chunk_count = 5
|
||||
|
||||
|
||||
def feed_data_into_collector(corpus, chunk_len, chunk_sep):
|
||||
|
@ -83,13 +84,12 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
|
|||
yield i
|
||||
|
||||
|
||||
def apply_settings(_chunk_count):
|
||||
global chunk_count
|
||||
chunk_count = int(_chunk_count)
|
||||
settings_to_display = {
|
||||
'chunk_count': chunk_count,
|
||||
}
|
||||
|
||||
def apply_settings(chunk_count, chunk_count_initial, time_weight):
|
||||
global params
|
||||
params['chunk_count'] = int(chunk_count)
|
||||
params['chunk_count_initial'] = int(chunk_count_initial)
|
||||
params['time_weight'] = time_weight
|
||||
settings_to_display = {k: params[k] for k in params if k in ['chunk_count', 'chunk_count_initial', 'time_weight']}
|
||||
yield f"The following settings are now active: {str(settings_to_display)}"
|
||||
|
||||
|
||||
|
@ -97,7 +97,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||
global chat_collector
|
||||
|
||||
if state['mode'] == 'instruct':
|
||||
results = collector.get_sorted(user_input, n_results=chunk_count)
|
||||
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
||||
user_input += additional_context
|
||||
else:
|
||||
|
@ -108,7 +108,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
|
||||
return output
|
||||
|
||||
if len(shared.history['internal']) > chunk_count and user_input != '':
|
||||
if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
|
||||
chunks = []
|
||||
hist_size = len(shared.history['internal'])
|
||||
for i in range(hist_size-1):
|
||||
|
@ -117,7 +117,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||
add_chunks_to_collector(chunks, chat_collector)
|
||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||
try:
|
||||
best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count)
|
||||
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
|
||||
additional_context = '\n'
|
||||
for id_ in best_ids:
|
||||
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
|
@ -151,7 +151,7 @@ def input_modifier(string):
|
|||
user_input = match.group(1).strip()
|
||||
|
||||
# Get the most similar chunks
|
||||
results = collector.get_sorted(user_input, n_results=chunk_count)
|
||||
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||
|
||||
# Make the injection
|
||||
string = string.replace('<|injection-point|>', '\n'.join(results))
|
||||
|
@ -240,6 +240,10 @@ def ui():
|
|||
|
||||
with gr.Tab("Generation settings"):
|
||||
chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
|
||||
gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)')
|
||||
time_weight = gr.Slider(0, 1, value=params['time_weight'], label='Time weight', info='Defines the strength of the time weighting. 0 = no time weighting.')
|
||||
chunk_count_initial = gr.Number(value=params['chunk_count_initial'], label='Initial chunk count', info='The number of closest-matching chunks retrieved for time weight reordering in chat mode. This should be >= chunk count. -1 = All chunks are retrieved. Only used if time_weight > 0.')
|
||||
|
||||
update_settings = gr.Button('Apply changes')
|
||||
|
||||
chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
|
||||
|
@ -250,4 +254,4 @@ def ui():
|
|||
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
|
||||
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
|
||||
update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)
|
||||
update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue