Add superbooga time weighted history retrieval (#2080)

This commit is contained in:
Luis Lopez 2023-05-25 21:22:45 +08:00 committed by GitHub
parent a04266161d
commit ee674afa50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 21 deletions

View file

@@ -47,34 +47,58 @@ class ChromaCollector(Collecter):
self.ids = [f"id{i}" for i in range(len(texts))]
self.collection.add(documents=texts, ids=self.ids)
def get_documents_ids_distances(self, search_strings: list[str], n_results: int):
    """Query the collection and return ``(documents, ids, distances)``.

    ``n_results`` is clamped to the number of stored chunks. The three
    returned lists are parallel: entry ``k`` of each describes the same
    chunk. Ids are converted from their stored ``"id<N>"`` string form to
    the integer ``N`` (the chunk's insertion index).

    Only the first entry of each query result is used.
    NOTE(review): presumably that is the result for the first search
    string -- confirm against the Chroma query API.
    """
    n_results = min(len(self.ids), n_results)
    if n_results == 0:
        # Always return three lists so callers can unpack unconditionally,
        # even when the collection is empty.
        return [], [], []

    result = self.collection.query(
        query_texts=search_strings,
        n_results=n_results,
        include=['documents', 'distances'],
    )
    documents = result['documents'][0]
    # Stored ids look like "id<N>"; drop the two-character prefix.
    ids = [int(raw_id[2:]) for raw_id in result['ids'][0]]
    distances = result['distances'][0]
    return documents, ids, distances
# Get chunks by similarity
def get(self, search_strings: list[str], n_results: int) -> list[str]:
    """Return up to ``n_results`` chunk texts matching ``search_strings``.

    The chunks are returned in the order produced by
    ``get_documents_ids_distances``; ids and distances are discarded.
    """
    chunks, _ids, _distances = self.get_documents_ids_distances(search_strings, n_results)
    return chunks
# Get ids by similarity
def get_ids(self, search_strings: list[str], n_results: int) -> list[int]:
    """Return the ids of up to ``n_results`` chunks matching ``search_strings``.

    The ids are the integer insertion indices produced by
    ``get_documents_ids_distances`` (not the raw ``"id<N>"`` strings), in
    the order that helper returns them.
    """
    _, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
    return ids
# Get chunks by similarity and then sort by insertion order
def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
    """Return the matching chunk texts ordered by ascending chunk id.

    The matches are selected by ``get_documents_ids_distances`` and then
    rearranged so the chunk with the lowest id (earliest insertion) comes
    first.
    """
    documents, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
    paired = list(zip(ids, documents))
    paired.sort()
    return [text for _, text in paired]
# Multiply each distance by a factor within [1 - time_weight, 1] where more recent is lower
def apply_time_weight_to_distances(self, ids: list[int], distances: list[float], time_weight: float = 1.0) -> list[float]:
    """Return a new list of distances scaled down according to chunk recency.

    A chunk's recency is ``id / (len(self.ids) - 1)``: 0 for the first
    chunk inserted, 1 for the last. Each distance is multiplied by
    ``1 - recency * time_weight``, so the oldest chunk keeps its distance
    and the newest is scaled by ``1 - time_weight``.

    With zero or one stored chunk there is no recency scale, and an
    unmodified copy of ``distances`` is returned.
    """
    newest_index = len(self.ids) - 1
    if newest_index <= 0:
        return distances.copy()

    weighted = []
    for chunk_id, distance in zip(ids, distances):
        recency = chunk_id / newest_index
        weighted.append(distance * (1 - recency * time_weight))
    return weighted
# Get ids by similarity (optionally re-ranked by recency) and then sort by insertion order
def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: int = None, time_weight: float = 1.0) -> list[int]:
    """Return the ids of the best-matching chunks in ascending id order.

    Without time weighting (``time_weight <= 0`` or ``n_initial is None``)
    the ``n_results`` most similar chunks are selected. With time
    weighting, ``n_initial`` candidates are fetched first, their distances
    are scaled by ``apply_time_weight_to_distances``, and the ``n_results``
    candidates with the smallest weighted distance are kept.

    Args:
        search_strings: Query texts passed on to the collection query.
        n_results: Maximum number of ids to return.
        n_initial: Size of the candidate pool for time weighting. ``None``
            uses ``n_results``; ``-1`` uses the whole collection.
        time_weight: Strength of the recency weighting; values ``<= 0``
            disable it.

    Returns:
        The selected integer chunk ids, sorted ascending. An empty list
        when the collection holds no chunks.

    Raises:
        ValueError: If an explicit ``n_initial`` is smaller than
            ``n_results`` while time weighting is active.
    """
    do_time_weight = time_weight > 0
    if not (do_time_weight and n_initial is not None):
        n_initial = n_results
    elif n_initial == -1:
        # "Whole collection": never let this fall below n_results, otherwise
        # a collection smaller than n_results would be rejected below even
        # though the query clamps the count to the collection size anyway.
        n_initial = max(len(self.ids), n_results)

    if n_initial < n_results:
        raise ValueError(f"n_initial {n_initial} should be >= n_results {n_results}")

    if not self.ids:
        # Nothing stored, so nothing can match.
        return []

    _, ids, distances = self.get_documents_ids_distances(search_strings, n_initial)
    if do_time_weight:
        distances_w = self.apply_time_weight_to_distances(ids, distances, time_weight=time_weight)
        # Stable sort: candidates with equal weighted distance keep their
        # similarity order.
        ranked = sorted(zip(ids, distances_w), key=lambda pair: pair[1])
        ids = [chunk_id for chunk_id, _ in ranked[:n_results]]

    return sorted(ids)
def clear(self):