diff --git a/docs/record_manager.ipynb b/docs/record_manager.ipynb
new file mode 100644
index 0000000..11b0f0b
--- /dev/null
+++ b/docs/record_manager.ipynb
@@ -0,0 +1,189 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Google Firestore (Native Mode)\n",
+    "\n",
+    "> [Firestore](https://cloud.google.com/firestore) is a serverless document-oriented database that scales to meet any demand. Extend your database application to build AI-powered experiences leveraging Firestore's LangChain integrations.\n",
+    "\n",
+    "This notebook goes over how to use [Firestore](https://cloud.google.com/firestore) as a record manager for [LangChain indexing](https://python.langchain.com/v0.1/docs/modules/data_connection/indexing/) of your vector store.\n",
+    "\n",
+    "[](https://colab.research.google.com/github/googleapis/langchain-google-firestore-python/blob/main/docs/record_manager.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Before You Begin\n",
+    "\n",
+    "To run this notebook, you will need to do the following:\n",
+    "\n",
+    "* [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n",
+    "* [Enable the Firestore API](https://console.cloud.google.com/flows/enableapi?apiid=firestore.googleapis.com)\n",
+    "* [Create a Firestore database](https://cloud.google.com/firestore/docs/manage-databases)\n",
+    "\n",
+    "After confirming access to the database in the runtime environment of this notebook, fill in the following values and run the cell before running the example scripts."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# @markdown Please specify a source for demo purposes.\n",
+    "COLLECTION_NAME = \"test\" # @param {type:\"CollectionReference\"|\"string\"}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 🦜🔗 Library Installation\n",
+    "\n",
+    "The integration lives in its own `langchain-google-firestore` package, so we need to install it. For this notebook, we will also install `langchain-google-vertexai` to use Vertex AI embeddings."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install --upgrade --quiet langchain langchain-google-firestore langchain-google-vertexai"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Colab only**: Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# # Automatically restart kernel after installs so that your environment can access the new packages\n",
+    "# import IPython\n",
+    "\n",
+    "# app = IPython.Application.instance()\n",
+    "# app.kernel.do_shutdown(True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### ☁ Set Your Google Cloud Project\n",
+    "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n",
+    "\n",
+    "If you don't know your project ID, try the following:\n",
+    "\n",
+    "* Run `gcloud config list`.\n",
+    "* Run `gcloud projects list`.\n",
+    "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n",
+    "\n",
+    "PROJECT_ID = \"extensions-testing\" # @param {type:\"string\"}\n",
+    "\n",
+    "# Set the project id\n",
+    "!gcloud config set project {PROJECT_ID}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 🔐 Authentication\n",
+    "\n",
+    "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n",
+    "\n",
+    "- If you are using Colab to run this notebook, use the cell below and continue.\n",
+    "- If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from google.colab import auth\n",
+    "\n",
+    "auth.authenticate_user()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Basic Usage"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Initialize FirestoreRecordManager\n",
+    "\n",
+    "`FirestoreRecordManager` keeps track of the documents that have been written to your vector store, so the LangChain `index()` API can skip unchanged documents and clean up stale ones. It works with embeddings from any model, including Vertex AI embeddings."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.indexes import index\n",
+    "from langchain_core.documents import Document\n",
+    "from langchain_google_firestore import FirestoreVectorStore, FirestoreRecordManager\n",
+    "from langchain_google_vertexai import VertexAIEmbeddings\n",
+    "\n",
+    "namespace = f\"firestore/{COLLECTION_NAME}\"\n",
+    "record_manager = FirestoreRecordManager(namespace)\n",
+    "\n",
+    "embedding = VertexAIEmbeddings(model_name=\"textembedding-gecko@latest\")\n",
+    "vectorstore = FirestoreVectorStore(\n",
+    "    collection=COLLECTION_NAME,\n",
+    "    embedding_service=embedding\n",
+    ")\n",
+    "\n",
+    "doc1 = Document(page_content=\"test-doc-1-content\", metadata={\"source\": \"test-doc-1.txt\"})\n",
+    "doc2 = Document(page_content=\"test-doc-2-content\", metadata={\"source\": \"test-doc-2.txt\"})\n",
+    "doc3 = Document(page_content=\"test-doc-3-content\", metadata={\"source\": \"test-doc-3.txt\"})\n",
+    "\n",
+    "results = index(\n",
+    "    [doc1, doc2, doc3],\n",
+    "    record_manager,\n",
+    "    vectorstore,\n",
+    "    cleanup=\"incremental\",\n",
+    "    source_id_key=\"source\",\n",
+    ")\n",
+    "\n",
+    "print(results)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/langchain_google_firestore/__init__.py b/src/langchain_google_firestore/__init__.py
index 3fd863d..0ebf46b 100644
--- a/src/langchain_google_firestore/__init__.py
+++ b/src/langchain_google_firestore/__init__.py
@@ -15,6 +15,7 @@
 from .chat_message_history import FirestoreChatMessageHistory
 from .document_loader import FirestoreLoader, FirestoreSaver
 from .vectorstores import FirestoreVectorStore
+from .record_manager import FirestoreRecordManager
 from .version import __version__
 
 __all__ = [
@@ -22,5 +23,6 @@
     "FirestoreLoader",
     "FirestoreSaver",
     "FirestoreVectorStore",
+    "FirestoreRecordManager",
     "__version__",
 ]
diff --git a/src/langchain_google_firestore/record_manager.py b/src/langchain_google_firestore/record_manager.py
new file mode 100644
index 0000000..13c42c2
--- /dev/null
+++ b/src/langchain_google_firestore/record_manager.py
@@ -0,0 +1,232 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import logging
+from google.cloud import firestore
+from typing import List, Optional, Sequence, Dict
+from langchain_core.indexing import RecordManager
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+class FirestoreRecordManager(RecordManager):
+    def __init__(
+        self,
+        namespace: str,
+        collection_name: str = "record_manager",
+    ) -> None:
+        super().__init__(namespace=namespace)
+        self.collection_name = collection_name
+        self.db = firestore.Client()
+        self.collection = self.db.collection(self.collection_name)
+        logger.info(
+            f"Initialised FirestoreRecordManager with namespace: {namespace}, collection: {collection_name}"
+        )
+
+    def create_schema(self) -> None:
+        logger.info("Skipping schema creation (Firestore is schemaless)")
+        pass
+
+    async def acreate_schema(self) -> None:
+        logger.info("Skipping schema creation (Firestore is schemaless)")
+        pass
+
+    def get_time(self) -> datetime.datetime:
+        return datetime.datetime.now(datetime.timezone.utc)
+
+    async def aget_time(self) -> datetime.datetime:
+        return datetime.datetime.now(datetime.timezone.utc)
+
+    def update(
+        self,
+        keys: Sequence[str],
+        *,
+        group_ids: Optional[Sequence[Optional[str]]] = None,
+        time_at_least: Optional[float] = None,
+    ) -> Dict[str, int]:
+        if group_ids:
+            logger.info(f"Updating all {len(keys)} records")
+        else:
+            logger.info(f"Updating {len(keys)} records")
+            group_ids = [None] * len(keys)
+
+        batch = self.db.batch()
+        current_time = self.get_time()
+        num_updated = 0
+        num_added = 0
+
+        for key, group_id in zip(keys, group_ids):
+            doc_ref = self.collection.document(key)
+            doc = doc_ref.get()
+
+            if doc.exists:
+                num_updated += 1
+                if group_id:
+                    logger.info(f"Refreshing timestamp for record: {key}")
+                else:
+                    logger.info(f"Updating existing record: {key}")
+            else:
+                num_added += 1
+                logger.info(f"Adding new record: {key}")
+
+            batch.set(
+                doc_ref,
+                {
+                    "key": key,
+                    "namespace": self.namespace,
+                    "updated_at": current_time,
+                    "group_id": group_id,
+                },
+                merge=True,
+            )
+
+        batch.commit()
+        logger.info(f"Update complete. Updated: {num_updated}, Added: {num_added}")
+
+        return {"num_updated": num_updated, "num_added": num_added}
+
+    async def aupdate(
+        self,
+        keys: Sequence[str],
+        *,
+        group_ids: Optional[Sequence[Optional[str]]] = None,
+        time_at_least: Optional[float] = None,
+    ) -> Dict[str, int]:
+        logger.info("Calling synchronous update method")
+        return self.update(keys, group_ids=group_ids, time_at_least=time_at_least)
+
+    def exists(self, keys: Sequence[str]) -> List[bool]:
+        logger.info(f"Checking existence of {len(keys)} keys")
+        result = [False] * len(keys)
+        key_to_index = {key: i for i, key in enumerate(keys)}
+
+        # Process keys in batches of 30 for Firestore limit
+        for i in range(0, len(keys), 30):
+            batch = keys[i : i + 30]
+            query = self.collection.where(
+                filter=firestore.FieldFilter("namespace", "==", self.namespace)
+            )
+            query = query.where(filter=firestore.FieldFilter("key", "in", batch))
+            docs = query.get()
+
+            for doc in docs:
+                key = doc.get("key")
+                if key in key_to_index:
+                    result[key_to_index[key]] = True
+
+        logger.info(f"Existence check complete. Found {sum(result)} records")
+        return result
+
+    async def aexists(self, keys: Sequence[str]) -> List[bool]:
+        logger.info("Calling synchronous exists method")
+        return self.exists(keys)
+
+    def list_keys(
+        self,
+        *,
+        before: Optional[datetime.datetime] = None,
+        after: Optional[datetime.datetime] = None,
+        group_ids: Optional[Sequence[str]] = None,
+        limit: Optional[int] = None,
+    ) -> List[str]:
+        logger.info("Listing records with filters")
+
+        all_keys = []
+
+        # If there are group_ids, process them in batches of 30 for Firestore limit
+        if group_ids:
+            for i in range(0, len(group_ids), 30):
+                batch_group_ids = group_ids[i : i + 30]
+                keys = self._list_keys_batch(before, after, batch_group_ids, limit)
+                all_keys.extend(keys)
+                if limit and len(all_keys) >= limit:
+                    all_keys = all_keys[:limit]
+                    break
+        else:
+            all_keys = self._list_keys_batch(before, after, None, limit)
+
+        logger.info(f"Listed {len(all_keys)} records")
+        return all_keys
+
+    def _list_keys_batch(
+        self,
+        before: Optional[datetime.datetime],
+        after: Optional[datetime.datetime],
+        group_ids: Optional[Sequence[str]],
+        limit: Optional[int],
+    ) -> List[str]:
+        query = self.collection.where(
+            filter=firestore.FieldFilter("namespace", "==", self.namespace)
+        )
+
+        if after:
+            query = query.where(filter=firestore.FieldFilter("updated_at", ">", after))
+            logger.debug(f"Filtering records after: {after}")
+        if before:
+            query = query.where(filter=firestore.FieldFilter("updated_at", "<", before))
+            logger.debug(f"Filtering records before: {before}")
+        if group_ids:
+            query = query.where(
+                filter=firestore.FieldFilter("group_id", "in", group_ids)
+            )
+            logger.debug(f"Filtering by group_ids: {group_ids}")
+
+        if limit:
+            query = query.limit(limit)
+            logger.debug(f"Limiting results to: {limit}")
+
+        docs = query.get()
+        keys = [doc.get("key") for doc in docs]
+        logger.info(f"Listed {len(keys)} records")
+        return keys
+
+    async def alist_keys(
+        self,
+        *,
+        before: Optional[datetime.datetime] = None,
+        after: Optional[datetime.datetime] = None,
+        group_ids: Optional[Sequence[str]] = None,
+        limit: Optional[int] = None,
+    ) -> List[str]:
+        logger.info("Calling synchronous list_keys method")
+        return self.list_keys(
+            before=before, after=after, group_ids=group_ids, limit=limit
+        )
+
+    def delete_keys(self, keys: Sequence[str]) -> Dict[str, int]:
+        logger.info(f"Deleting {len(keys)} records")
+        batch = self.db.batch()
+        num_deleted = 0
+
+        for key in keys:
+            doc_ref = self.collection.document(key)
+            doc = doc_ref.get()
+
+            if doc.exists:
+                batch.delete(doc_ref)
+                num_deleted += 1
+                logger.info(f"Deleting record: {key}")
+
+        batch.commit()
+        logger.info(f"Deletion complete. Deleted {num_deleted} keys")
+
+        return {"num_deleted": num_deleted}
+
+    async def adelete_keys(self, keys: Sequence[str]) -> Dict[str, int]:
+        logger.info("Calling synchronous delete_keys method")
+        return self.delete_keys(keys)
diff --git a/tests/test_record_manager.py b/tests/test_record_manager.py
new file mode 100644
index 0000000..e0f1a67
--- /dev/null
+++ b/tests/test_record_manager.py
@@ -0,0 +1,147 @@
+import sys
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+from google.cloud import firestore
+from langchain_google_firestore import FirestoreRecordManager
+
+@pytest.fixture(scope="module")
+def test_collection():
+    python_version = f"{sys.version_info.major}{sys.version_info.minor}"
+    return f"test_record_manager_{python_version}"
+
+@pytest.fixture(scope="module")
+def mock_firestore_client():
+    with patch('google.cloud.firestore.Client', autospec=True) as mock_client:
+        yield mock_client.return_value
+
+@pytest.fixture(autouse=True)
+def cleanup_firestore(mock_firestore_client):
+    mock_firestore_client.reset_mock()
+
+def test_firestore_record_manager_init(test_collection, mock_firestore_client):
+    namespace = "test_namespace"
+    record_manager = FirestoreRecordManager(namespace, test_collection)
+
+    assert record_manager.namespace == namespace
+    assert record_manager.collection_name == test_collection
+    assert record_manager.db == mock_firestore_client
+
+def test_firestore_record_manager_update(test_collection, mock_firestore_client):
+    namespace = "test_namespace"
+    record_manager = FirestoreRecordManager(namespace, test_collection)
+
+    mock_doc = MagicMock()
+    mock_doc.exists = False
+    mock_firestore_client.collection.return_value.document.return_value.get.return_value = mock_doc
+
+    keys = ["key1", "key2"]
+    group_ids = ["group1", "group2"]
+
+    result = record_manager.update(keys, group_ids=group_ids)
+
+    assert result["num_added"] == 2
+    assert result["num_updated"] == 0
+
+    mock_doc.exists = True
+    result = record_manager.update(keys, group_ids=group_ids)
+
+    assert result["num_added"] == 0
+    assert result["num_updated"] == 2
+
+def test_firestore_record_manager_exists(test_collection, mock_firestore_client):
+    namespace = "test_namespace"
+    record_manager = FirestoreRecordManager(namespace, test_collection)
+
+    mock_docs = [
+        MagicMock(get=lambda key: "key1" if key == "key" else None),
+        MagicMock(get=lambda key: "key2" if key == "key" else None)
+    ]
+    mock_firestore_client.collection.return_value.where.return_value.where.return_value.get.return_value = mock_docs
+
+    keys = ["key1", "key2", "key3"]
+
+    result = record_manager.exists(keys)
+
+    assert result == [True, True, False]
+
+def test_firestore_record_manager_list_keys(test_collection, mock_firestore_client):
+    namespace = "test_namespace"
+    record_manager = FirestoreRecordManager(namespace, test_collection)
+
+    mock_docs = [
+        MagicMock(get=lambda key: "key1" if key == "key" else None),
+        MagicMock(get=lambda key: "key2" if key == "key" else None),
+        MagicMock(get=lambda key: "key3" if key == "key" else None),
+    ]
+
+    mock_firestore_client.collection.return_value.where.return_value.get.return_value = mock_docs
+
+    result = record_manager.list_keys()
+    assert set(result) == {"key1", "key2", "key3"}
+
+    mock_firestore_client.collection.return_value.where.return_value.where.return_value.get.return_value = mock_docs[:2]
+
+    result = record_manager.list_keys(group_ids=["group1"])
+    assert set(result) == {"key1", "key2"}
+
+    mock_firestore_client.collection.return_value.where.return_value.limit.return_value.get.return_value = mock_docs[:2]
+
+    result = record_manager.list_keys(limit=2)
+    assert len(result) == 2
+
+def test_firestore_record_manager_delete_keys(test_collection, mock_firestore_client):
+    namespace = "test_namespace"
+    record_manager = FirestoreRecordManager(namespace, test_collection)
+
+    mock_doc1 = Mock(exists=True)
+    mock_doc2 = Mock(exists=True)
+    mock_doc3 = Mock(exists=False)
+
+    mock_collection = mock_firestore_client.collection.return_value
+    mock_document_refs = [Mock(), Mock(), Mock()]
+    mock_collection.document.side_effect = mock_document_refs
+
+    mock_document_refs[0].get.return_value = mock_doc1
+    mock_document_refs[1].get.return_value = mock_doc2
+    mock_document_refs[2].get.return_value = mock_doc3
+
+    mock_batch = Mock()
+    mock_firestore_client.batch.return_value = mock_batch
+
+    keys = ["key1", "key2", "key3"]
+
+    result = record_manager.delete_keys(keys)
+
+    assert mock_batch.delete.call_count == 2
+    mock_batch.delete.assert_any_call(mock_document_refs[0])
+    mock_batch.delete.assert_any_call(mock_document_refs[1])
+
+    assert mock_document_refs[2] not in [call[0][0] for call in mock_batch.delete.call_args_list]
+
+    mock_batch.commit.assert_called_once()
+
+    assert result["num_deleted"] == 2
+
+@pytest.mark.asyncio
+async def test_firestore_record_manager_async_methods(test_collection, mock_firestore_client):
+    namespace = "test_namespace"
+    record_manager = FirestoreRecordManager(namespace, test_collection)
+
+    record_manager.update = MagicMock(return_value={"num_added": 2, "num_updated": 0})
+    record_manager.exists = MagicMock(return_value=[True, True, False])
+    record_manager.list_keys = MagicMock(return_value=["key1", "key2"])
+    record_manager.delete_keys = MagicMock(return_value={"num_deleted": 2})
+
+    keys = ["key1", "key2"]
+
+    result = await record_manager.aupdate(keys)
+    assert result["num_added"] == 2
+
+    exists_result = await record_manager.aexists(keys + ["key3"])
+    assert exists_result == [True, True, False]
+
+    list_result = await record_manager.alist_keys()
+    assert set(list_result) == set(keys)
+
+    delete_result = await record_manager.adelete_keys(keys)
+    assert delete_result["num_deleted"] == 2
\ No newline at end of file
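
A quick end-to-end check of the new record manager (illustrative only, not part of the diff above): after running the indexing cell in docs/record_manager.ipynb, re-run `index()` on the same, unchanged documents and inspect the bookkeeping. The names `index`, `record_manager`, `vectorstore`, and `doc1`/`doc2`/`doc3` are the ones defined in that notebook cell, and the expected counts assume the documents were not modified between runs.

    # With cleanup="incremental" and no changes to the documents, the indexing
    # API should skip all three documents instead of re-embedding them, because
    # FirestoreRecordManager already holds a record for each document hash.
    results = index(
        [doc1, doc2, doc3],
        record_manager,
        vectorstore,
        cleanup="incremental",
        source_id_key="source",
    )
    print(results)  # expected: 3 skipped, 0 added / updated / deleted

    # index() keys each record by document hash and stores the source file as
    # the group_id, so the records for a single source can be looked up directly.
    keys = record_manager.list_keys(group_ids=["test-doc-1.txt"])
    print(record_manager.exists(keys))  # expected: [True]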