# Source code for chainfury.components.qdrant

from uuid import uuid4
from functools import lru_cache
from typing import List, Dict, Tuple, Optional, Union

try:
    from qdrant_client import models, QdrantClient

    QDRANT_CLIENT_INSTALLED = True
except ImportError:
    QDRANT_CLIENT_INSTALLED = False

from chainfury import Secret, memory_registry, logger
from chainfury.components.const import Env, ComponentMissingError

# https://qdrant.tech/documentation/concepts/filtering
# Must : "must" : AND
# Should : "should" : OR
# Must Not: "must_not" : NOT
# Match: =
# Match Any: IN
# Match Except: NOT IN


@lru_cache(maxsize=1)
def _get_qdrant_client(qdrant_url: Secret = Secret(), qdrant_api_key: Secret = Secret()):
    """Create a qdrant client and cache it.

    The result is memoised by ``functools.lru_cache`` with ``maxsize=1``, keyed on the
    two ``Secret`` arguments. Only the most recent argument pair is kept.
    NOTE(review): cache hits depend on how ``Secret`` implements ``__hash__``/``__eq__``.
    Callers that pass different ``Secret`` instances may each build a fresh client.
    Confirm against ``chainfury.Secret``.

    Args:
        qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`.
        qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`.

    Returns:
        qdrant_client.QdrantClient: qdrant client

    Raises:
        ComponentMissingError: if the ``qdrant-client`` package is not installed.
        Exception: if the URL or the API key resolves to an empty value.
    """
    # Fail with a clear message instead of a NameError on `QdrantClient` below.
    if not QDRANT_CLIENT_INSTALLED:
        raise ComponentMissingError("Qdrant client is not installed. Please install it with `pip install qdrant-client`")

    # `Env.*` resolves the explicit value against the environment variable named in the docstring.
    url = Secret(Env.QDRANT_API_URL(qdrant_url.value)).value  # type: ignore
    api_key = Secret(Env.QDRANT_API_KEY(qdrant_api_key.value)).value  # type: ignore
    if not url:
        raise Exception("Qdrant URL is not set. Please pass `qdrant_url` or env var `QDRANT_API_URL=<your_url>`")
    if not api_key:
        raise Exception("Qdrant API KEY is not set. Please pass `qdrant_api_key` or env var `QDRANT_API_KEY=<your_api_key>`")
    logger.info("Creating Qdrant client")
    return QdrantClient(url=url, api_key=api_key)  # type: ignore


def qdrant_write(
    embeddings: List[List[float]],
    collection_name: str,
    qdrant_url: Secret = Secret(""),
    qdrant_api_key: Secret = Secret(""),
    extra_payload: List[Dict[str, str]] = [],
    wait: bool = True,
    create_if_not_present: bool = True,
    distance: str = "cosine",
) -> Tuple[str, Optional[Exception]]:
    """
    Write to the Qdrant DB using the Qdrant client. In order to use this, access via the `memory_registry`:

    Example:
        >>> from chainfury import memory_registry
        >>> mem = memory_registry.get_write("qdrant")
        >>> sentence = "C.P. Cavafy is widely considered the most distinguished Greek poet of the 20th century."
        >>> out, err = mem(
            {
                "items": [sentence],
                "extra_payload": [
                    {"data": sentence},
                ],
                "collection_name": "my_test_collection",
                "embedding_model": "openai-embedding",
                "create_if_not_present": True,
            }
        )
        >>> if err:
                print("TRACE:", out)
            else:
                print(out)

    Args:
        embeddings (List[List[float]]): list of embeddings, one point is written per embedding
        collection_name (str): collection name
        qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`.
        qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`.
        extra_payload (List[Dict[str, str]], optional): one payload per embedding (same length
            as ``embeddings``). Defaults to [] (every point gets an empty payload).
        wait (bool, optional): wait for the response. Defaults to True.
        create_if_not_present (bool, optional): create the collection if the upsert returns
            HTTP 404, then retry once. Defaults to True.
        distance (str, optional): distance metric used when creating the collection; must name a
            member of ``models.Distance`` (case-insensitive). Defaults to "cosine".

    Returns:
        Tuple[str, Optional[Exception]]: ``(status, None)`` on success, where ``status`` is the
        lower-cased upsert status. On failure it is ``(content, error)``: ``content`` is the error's
        ``.content`` attribute when present (HTTP errors), otherwise ``str(error)``.

    Raises:
        ComponentMissingError: if the ``qdrant-client`` package is not installed.
        Exception: if ``embeddings`` is not a non-empty list of lists of floats, or if
            ``extra_payload`` is given with a different length than ``embeddings``.
    """
    # client check
    if not QDRANT_CLIENT_INSTALLED:
        raise ComponentMissingError("Qdrant client is not installed. Please install it with `pip install qdrant-client`")

    # checks (only the first element of the first embedding is type-checked)
    if not (len(embeddings) and len(embeddings[0]) and isinstance(embeddings[0][0], float)):
        raise Exception("Embeddings should be a list of lists of floats")
    if extra_payload and len(extra_payload) != len(embeddings):
        raise Exception("Length of extra_payload should be equal to embeddings")

    client: QdrantClient = _get_qdrant_client(qdrant_url, qdrant_api_key)  # type: ignore

    # Build the batch directly: one random UUID per point, payloads aligned by index.
    # NOTE: the `extra_payload` default `[]` is shared across calls; it is only read here, never mutated.
    payloads = list(extra_payload) if extra_payload else [{} for _ in embeddings]
    batch = models.Batch(  # type: ignore
        ids=[str(uuid4()) for _ in embeddings],
        vectors=embeddings,
        payloads=payloads,
    )

    def _insert() -> Tuple[str, Optional[Exception]]:
        """Upsert the batch once, converting any exception into the (content, error) tuple."""
        try:
            result = client.upsert(
                collection_name=collection_name,
                points=batch,
                wait=wait,
            )
        except Exception as e:
            # HTTP errors from qdrant_client carry `.content`; other failures (e.g. connection
            # errors) may not, so fall back to the message rather than raising AttributeError.
            return getattr(e, "content", str(e)), e
        return result.status.lower(), None

    status, err = _insert()
    # Only HTTP errors have `status_code`; a 404 means the collection does not exist yet.
    if err is not None and getattr(err, "status_code", None) == 404 and create_if_not_present:
        # `create_collection` (not `recreate_collection`) so a collection created concurrently
        # by another writer is never deleted.
        collection = client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(  # type: ignore
                size=len(embeddings[0]),
                distance=getattr(models.Distance, distance.upper()),  # type: ignore
            ),
        )
        logger.info(f"Created collection {collection}")
        status, err = _insert()
    return status, err
# Expose `qdrant_write` through the registry under the "qdrant" name; the vector input is
# read from the `embeddings` argument and the status is the first output element.
memory_registry.register_write(
    component_name="qdrant",
    description="Write to the Qdrant DB using the Qdrant client",
    fn=qdrant_write,
    vector_key="embeddings",
    outputs={"status": 0},
)
def _model_to_dict(model) -> Dict:
    """Convert a qdrant pydantic model to a plain dict under both pydantic v1 and v2."""
    # pydantic v2 removed the `skip_defaults` kwarg from `.dict()` and prefers `model_dump()`;
    # `skip_defaults=False` was the v1 default anyway, so a bare call is equivalent.
    dump = getattr(model, "model_dump", None)
    return dump() if dump is not None else model.dict()


def qdrant_read(
    embeddings: List[List[float]],
    collection_name: str,
    cutoff_score: float = 0.0,
    top: int = 5,
    limit: int = 0,
    offset: int = 0,
    filters: Dict[str, Dict[str, str]] = {},
    qdrant_url: Secret = Secret(""),
    qdrant_api_key: Secret = Secret(""),
    qdrant_search_hnsw_ef: int = 0,
    qdrant_search_exact: bool = False,
    batch_search: bool = False,
) -> Tuple[Dict[str, List[Dict[str, Union[float, int]]]], Optional[Exception]]:
    """
    Read from the Qdrant DB using the Qdrant client. In order to use this access via the `memory_registry`:

    Example:
        >>> from chainfury import memory_registry
        >>> mem = memory_registry.get_read("qdrant")
        >>> sentence = "Who was the Cafavy?"
        >>> out, err = mem(
            {
                "items": [sentence],
                "collection_name": "my_test_collection",
                "embedding_model": "openai-embedding"
            }
        )
        >>> if err:
                print("TRACE:", out)
            else:
                print(out)

    Note:
        `batch_search` is not implemented yet. There's some issues from the `qdrant_client` library.

    Args:
        embeddings (List[List[float]]): list containing exactly one embedding (the query vector)
        collection_name (str): collection name
        cutoff_score (float, optional): only hits with a score strictly greater than this are
            returned. Defaults to 0.0.
        top (int, optional): number of hits to fetch. Defaults to 5.
        limit (int, optional): alternative to ``top``; the search fetches ``max(limit, top)`` hits.
            Defaults to 0.
        offset (int, optional): offset. Defaults to 0.
        filters (Dict[str, Dict[str, str]], optional): keyword arguments for ``models.Filter``
            (e.g. ``must`` / ``should`` / ``must_not``). Defaults to {} (no filter).
        qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`.
        qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`.
        qdrant_search_hnsw_ef (int, optional): qdrant search beam size, the larger the beam size
            the more accurate the search, if not set uses default value.
        qdrant_search_exact (bool, optional): qdrant search exact. Defaults to False.
        batch_search (bool, optional): batch search. Defaults to False.

    Returns:
        Tuple[Dict[str, List[Dict[str, Union[float, int]]]], Optional[Exception]]:
        ``({"data": hits}, None)`` where each hit is the scored point converted to a dict.

    Raises:
        ComponentMissingError: if the ``qdrant-client`` package is not installed.
        NotImplementedError: if ``batch_search`` is True.
        Exception: on malformed embeddings, more than one embedding, or when both ``top``
            and ``limit`` are 0.
    """
    # client check
    if not QDRANT_CLIENT_INSTALLED:
        raise ComponentMissingError("Qdrant client is not installed. Please install it with `pip install qdrant-client`")

    # checks (only the first element of the first embedding is type-checked)
    if not (len(embeddings) and len(embeddings[0]) and isinstance(embeddings[0][0], float)):
        raise Exception("Embeddings should be a list of lists of floats")
    if batch_search:
        # `client.search_batch` fails when given a list of vectors, so batch mode is disabled.
        raise NotImplementedError("Batch search is not implemented yet")
    if len(embeddings) > 1:
        raise Exception("Batch search is not enabled, but multiple embeddings are passed")
    if not top and not limit:
        raise Exception("Either top or limit should be set")

    client: QdrantClient = _get_qdrant_client(qdrant_url, qdrant_api_key)  # type: ignore

    # Only pass the search options the caller actually set, so the server defaults apply otherwise.
    params_kwargs = {}
    if qdrant_search_hnsw_ef:
        params_kwargs["hnsw_ef"] = qdrant_search_hnsw_ef
    if qdrant_search_exact:
        params_kwargs["exact"] = qdrant_search_exact
    search_params = models.SearchParams(**params_kwargs)  # type: ignore

    query_filter = models.Filter(**filters) if filters else None  # type: ignore

    hits = client.search(
        collection_name=collection_name,
        query_vector=embeddings[0],
        query_filter=query_filter,
        limit=max(limit, top),
        offset=offset,
        search_params=search_params,
    )
    res = [_model_to_dict(hit) for hit in hits if hit.score > cutoff_score]
    return {"data": res}, None
# Expose `qdrant_read` through the registry under the "qdrant" name; the query vector is
# read from the `embeddings` argument and the results are the first output element.
memory_registry.register_read(
    component_name="qdrant",
    description="Function to read from the Qdrant DB using the Qdrant client",
    fn=qdrant_read,
    vector_key="embeddings",
    outputs={"items": 0},
)

# helper functions
def recreate_collection(collection_name: str, embedding_dim: int, distance: str = "cosine") -> bool:
    """
    Deletes and recreates a collection.

    Note:
        This will delete all the data in the collection, use with caution.
        The new collection is created with indexing disabled (``indexing_threshold=0``),
        which suits bulk uploads; call ``enable_indexing`` afterwards to turn it back on.

    Args:
        collection_name (str): collection name
        embedding_dim (int): embedding dimension
        distance (str, optional): distance metric; must name a member of ``models.Distance``
            (case-insensitive). Defaults to "cosine", the previously hard-coded metric.

    Returns:
        bool: success

    Raises:
        ComponentMissingError: if the ``qdrant-client`` package is not installed.
    """
    if not QDRANT_CLIENT_INSTALLED:
        raise ComponentMissingError("Qdrant client is not installed. Please install it with `pip install qdrant-client`")
    # Uses the client built from the default Secrets (i.e. the env vars).
    client: QdrantClient = _get_qdrant_client()  # type: ignore
    return client.recreate_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(  # type: ignore
            size=embedding_dim,
            distance=getattr(models.Distance, distance.upper()),  # type: ignore
        ),
        optimizers_config=models.OptimizersConfigDiff(  # type: ignore
            indexing_threshold=0,
        ),
    )
def enable_indexing(collection_name: str, indexing_threshold: int = 20000) -> bool:
    """
    Enable indexing for a collection, use this in conjunction with `disable_indexing`.
    Read more `here <https://qdrant.tech/documentation/tutorials/bulk-upload/#upload-directly-to-disk>`_.

    Example:
        >>> from chainfury.components.qdrant import enable_indexing, disable_indexing, qdrant_write
        >>> disable_indexing("my_collection")
        >>> qdrant_write([[1.0, 2.0, 3.0] for _ in range(100)], "my_collection")
        >>> enable_indexing("my_collection")

    Args:
        collection_name (str): collection name
        indexing_threshold (int, optional): indexing threshold. Defaults to 20000.

    Returns:
        bool: success

    Raises:
        ComponentMissingError: if the ``qdrant-client`` package is not installed.
    """
    if not QDRANT_CLIENT_INSTALLED:
        raise ComponentMissingError("Qdrant client is not installed. Please install it with `pip install qdrant-client`")
    # Uses the client built from the default Secrets (i.e. the env vars).
    client: QdrantClient = _get_qdrant_client()  # type: ignore
    return client.update_collection(
        collection_name=collection_name,
        optimizer_config=models.OptimizersConfigDiff(  # type: ignore
            indexing_threshold=indexing_threshold,
        ),
    )
def disable_indexing(collection_name: str) -> bool:
    """
    Disable indexing for a collection, use this in conjunction with `enable_indexing`.
    Read more `here <https://qdrant.tech/documentation/tutorials/bulk-upload/#upload-directly-to-disk>`_.

    Indexing is disabled by setting the collection's optimizer ``indexing_threshold`` to 0.

    Args:
        collection_name (str): collection name

    Returns:
        bool: success

    Raises:
        ComponentMissingError: if the ``qdrant-client`` package is not installed.
    """
    if not QDRANT_CLIENT_INSTALLED:
        raise ComponentMissingError("Qdrant client is not installed. Please install it with `pip install qdrant-client`")
    # Uses the client built from the default Secrets (i.e. the env vars).
    client: QdrantClient = _get_qdrant_client()  # type: ignore
    return client.update_collection(
        collection_name=collection_name,
        optimizer_config=models.OptimizersConfigDiff(  # type: ignore
            indexing_threshold=0,
        ),
    )