parent
ab6acddcc5
commit
9c53517d2c
1 changed files with 9 additions and 2 deletions
|
@ -42,11 +42,17 @@ class ChromaCollector(Collecter):
|
||||||
self.ids = []
|
self.ids = []
|
||||||
|
|
||||||
def add(self, texts: list[str]):
|
def add(self, texts: list[str]):
|
||||||
|
if len(texts) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
self.ids = [f"id{i}" for i in range(len(texts))]
|
self.ids = [f"id{i}" for i in range(len(texts))]
|
||||||
self.collection.add(documents=texts, ids=self.ids)
|
self.collection.add(documents=texts, ids=self.ids)
|
||||||
|
|
||||||
def get_documents_and_ids(self, search_strings: list[str], n_results: int):
|
def get_documents_and_ids(self, search_strings: list[str], n_results: int):
|
||||||
n_results = min(len(self.ids), n_results)
|
n_results = min(len(self.ids), n_results)
|
||||||
|
if n_results == 0:
|
||||||
|
return [], []
|
||||||
|
|
||||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
|
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
|
||||||
documents = result['documents'][0]
|
documents = result['documents'][0]
|
||||||
ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
|
ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
|
||||||
|
@ -74,6 +80,7 @@ class ChromaCollector(Collecter):
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.collection.delete(ids=self.ids)
|
self.collection.delete(ids=self.ids)
|
||||||
|
self.ids = []
|
||||||
|
|
||||||
|
|
||||||
class SentenceTransformerEmbedder(Embedder):
|
class SentenceTransformerEmbedder(Embedder):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue