diff --git a/agentplatform/_genai/rag.py b/agentplatform/_genai/rag.py index f00db97a5e..ecc0b93a88 100644 --- a/agentplatform/_genai/rag.py +++ b/agentplatform/_genai/rag.py @@ -135,6 +135,19 @@ def _GetCorpusOperationParameters_to_vertex( return to_object +def _GetRagConfigOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operation_name"], getv(from_object, ["operation_name"]) + ) + + return to_object + + def _GetRagConfigRequestParameters_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -2300,6 +2313,74 @@ def retrieve_contexts( self._api_client._verify_response(return_value) return return_value + def _get_rag_config_operation( + self, + *, + operation_name: str, + config: Optional[types.GetRagConfigOperationConfigOrDict] = None, + ) -> types.RagEngineConfigOperation: + parameter_model = types._GetRagConfigOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode." + ) + else: + request_dict = _GetRagConfigOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operation_name}".format_map(request_url_dict) + else: + path = "{operation_name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RagEngineConfigOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + def create_corpus( self, *, @@ -2372,6 +2453,67 @@ def delete_file( return None + def update_corpus( + self, + *, + name: str, + rag_corpus: types.RagCorpusOrDict, + config: Optional[types.UpdateRagCorpusConfigOrDict] = None, + ) -> types.RagCorpus: + """ + Updates a Rag Corpus and waits for completion. + + Args: + name: The name of the RagCorpus to update, formatted as + `projects/{project}/locations/{location}/ragCorpora/{corpus_id}`. + rag_corpus: The RagCorpus to update. + config: The configuration to use for the RagCorpus update request. + + Returns: + The updated RagCorpus. + """ + operation = self._update_corpus(name=name, rag_corpus=rag_corpus, config=config) + + operation = _operations_utils.await_operation( + operation_name=operation.name, + get_operation_fn=self._get_corpus_operation, + ) + + if operation.error: + raise RuntimeError(f"Failed to update RagCorpus: {operation.error}") + + return self.get_corpus(name=operation.response.name) + + def update_config( + self, + *, + updated_config: types.RagEngineConfigOrDict, + request_config: Optional[types.UpdateRagConfigOrDict] = None, + ) -> types.RagEngineConfig: + """ + Updates a RagEngineConfig and waits for completion. + + Args: + updated_config: The RagEngineConfig to update. + request_config: The configuration to use for the RagEngineConfig update request. + + Returns: + The updated RagEngineConfig. + """ + operation = self._update_config( + updated_config=updated_config, config=request_config + ) + + operation = _operations_utils.await_operation( + operation_name=operation.name, + get_operation_fn=self._get_rag_config_operation, + ) + + if operation.error: + raise RuntimeError(f"Failed to update RagEngineConfig: {operation.error}") + + return self.get_config() + class AsyncRag(_api_module.BaseModule): @@ -3333,6 +3475,76 @@ async def retrieve_contexts( self._api_client._verify_response(return_value) return return_value + async def _get_rag_config_operation( + self, + *, + operation_name: str, + config: Optional[types.GetRagConfigOperationConfigOrDict] = None, + ) -> types.RagEngineConfigOperation: + parameter_model = types._GetRagConfigOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode." + ) + else: + request_dict = _GetRagConfigOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operation_name}".format_map(request_url_dict) + else: + path = "{operation_name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RagEngineConfigOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + async def create_corpus( self, *, @@ -3404,3 +3616,66 @@ async def delete_file( ) return None + + async def update_corpus( + self, + *, + name: str, + rag_corpus: types.RagCorpusOrDict, + config: Optional[types.UpdateRagCorpusConfigOrDict] = None, + ) -> types.RagCorpus: + """ + Updates a Rag Corpus and waits for completion asynchronously. + + Args: + name: The name of the RagCorpus to update, formatted as + `projects/{project}/locations/{location}/ragCorpora/{corpus_id}`. + rag_corpus: The RagCorpus to update. + config: The configuration to use for the RagCorpus update request. + + Returns: + The updated RagCorpus. + """ + operation = await self._update_corpus( + name=name, rag_corpus=rag_corpus, config=config + ) + + operation = await _operations_utils.await_operation_async( + operation_name=operation.name, + get_operation_fn=self._get_corpus_operation, + ) + + if operation.error: + raise RuntimeError(f"Failed to update RagCorpus: {operation.error}") + + return await self.get_corpus(name=operation.response.name) + + async def update_config( + self, + *, + updated_config: types.RagEngineConfigOrDict, + request_config: Optional[types.UpdateRagConfigOrDict] = None, + ) -> types.RagEngineConfig: + """ + Updates a RagEngineConfig and waits for completion asynchronously. + + Args: + updated_config: The RagEngineConfig to update. + request_config: The configuration to use for the RagEngineConfig update request. + + Returns: + The updated RagEngineConfig. + """ + operation = await self._update_config( + updated_config=updated_config, config=request_config + ) + + operation = await _operations_utils.await_operation_async( + operation_name=operation.name, + get_operation_fn=self._get_rag_config_operation, + ) + + if operation.error: + raise RuntimeError(f"Failed to update RagEngineConfig: {operation.error}") + + return await self.get_config() diff --git a/agentplatform/_genai/types/__init__.py b/agentplatform/_genai/types/__init__.py index 8f4787820f..8b1e8b52cf 100644 --- a/agentplatform/_genai/types/__init__.py +++ b/agentplatform/_genai/types/__init__.py @@ -94,6 +94,7 @@ from .common import _GetEvaluationSetParameters from .common import _GetMultimodalDatasetOperationParameters from .common import _GetMultimodalDatasetParameters +from .common import _GetRagConfigOperationParameters from .common import _GetRagConfigRequestParameters from .common import _GetRagCorpusRequestParameters from .common import _GetRagFileRequestParameters @@ -692,6 +693,9 @@ from .common import GetPromptConfigOrDict from .common import GetRagConfig from .common import GetRagConfigDict +from .common import GetRagConfigOperationConfig +from .common import GetRagConfigOperationConfigDict +from .common import GetRagConfigOperationConfigOrDict from .common import GetRagConfigOrDict from .common import GetRagCorpusConfig from .common import GetRagCorpusConfigDict @@ -1108,6 +1112,9 @@ from .common import RagEmbeddingModelConfigVertexPredictionEndpointOrDict from .common import RagEngineConfig from .common import RagEngineConfigDict +from .common import RagEngineConfigOperation +from .common import RagEngineConfigOperationDict +from .common import RagEngineConfigOperationOrDict from .common import RagEngineConfigOrDict from .common import RagFile from .common import RagFileDict @@ -2731,6 +2738,12 @@ "RetrieveContextsResponse", "RetrieveContextsResponseDict", "RetrieveContextsResponseOrDict", + "GetRagConfigOperationConfig", + "GetRagConfigOperationConfigDict", + "GetRagConfigOperationConfigOrDict", + "RagEngineConfigOperation", + "RagEngineConfigOperationDict", + "RagEngineConfigOperationOrDict", "GetAgentEngineRuntimeRevisionConfig", "GetAgentEngineRuntimeRevisionConfigDict", "GetAgentEngineRuntimeRevisionConfigOrDict", @@ -3334,6 +3347,7 @@ "_DeleteRagFileRequestParameters", "_UpdateRagConfigRequestParameters", "_RetrieveRagContextsRequestParameters", + "_GetRagConfigOperationParameters", "_GetAgentEngineRuntimeRevisionRequestParameters", "_ListAgentEngineRuntimeRevisionsRequestParameters", "_DeleteAgentEngineRuntimeRevisionRequestParameters", diff --git a/agentplatform/_genai/types/common.py b/agentplatform/_genai/types/common.py index e5638c98ef..ec37a5c3f8 100644 --- a/agentplatform/_genai/types/common.py +++ b/agentplatform/_genai/types/common.py @@ -14654,6 +14654,94 @@ class RetrieveContextsResponseDict(TypedDict, total=False): ] +class GetRagConfigOperationConfig(_common.BaseModel): + """Config for getting a RAG config operation.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetRagConfigOperationConfigDict(TypedDict, total=False): + """Config for getting a RAG config operation.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetRagConfigOperationConfigOrDict = Union[ + GetRagConfigOperationConfig, GetRagConfigOperationConfigDict +] + + +class _GetRagConfigOperationParameters(_common.BaseModel): + """Parameters for getting a RAG config operation.""" + + operation_name: Optional[str] = Field( + default=None, description="""The server-assigned name for the operation.""" + ) + config: Optional[GetRagConfigOperationConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class _GetRagConfigOperationParametersDict(TypedDict, total=False): + """Parameters for getting a RAG config operation.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetRagConfigOperationConfigDict] + """Used to override the default configuration.""" + + +_GetRagConfigOperationParametersOrDict = Union[ + _GetRagConfigOperationParameters, _GetRagConfigOperationParametersDict +] + + +class RagEngineConfigOperation(_common.BaseModel): + """Operation for getting a RAG config.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class RagEngineConfigOperationDict(TypedDict, total=False): + """Operation for getting a RAG config.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +RagEngineConfigOperationOrDict = Union[ + RagEngineConfigOperation, RagEngineConfigOperationDict +] + + class GetAgentEngineRuntimeRevisionConfig(_common.BaseModel): """Config for getting an Agent Engine Runtime Revision.""" diff --git a/tests/unit/agentplatform/genai/replays/test_rag_update.py b/tests/unit/agentplatform/genai/replays/test_rag_update.py index 572bc56bd1..758b634b88 100644 --- a/tests/unit/agentplatform/genai/replays/test_rag_update.py +++ b/tests/unit/agentplatform/genai/replays/test_rag_update.py @@ -42,6 +42,33 @@ def test_update_rag_corpus_private(client): assert isinstance(corpus_op, types.UpdateRagCorpusOperation) +def test_update_rag_corpus(client): + search_config = types.VertexAiSearchConfig( + serving_config="projects/vertex-sdk-dev/locations/us-central1/collections/default_collection/engines/test-engine/servingConfigs/default_serving_config" + ) + + # Create a corpus to update + corpus = client.rag.create_corpus( + rag_corpus=types.RagCorpus( + display_name="My Test Corpus", + description="My Test Corpus Description", + vertex_ai_search_config=search_config, + ), + ) + + updated_corpus = client.rag.update_corpus( + name=corpus.name, + rag_corpus=types.RagCorpus( + display_name="My Updated Vertex AI Search Test Corpus", + description="My Updated Test Corpus Description", + vertex_ai_search_config=search_config, + ), + ) + + assert updated_corpus.display_name == "My Updated Vertex AI Search Test Corpus" + assert updated_corpus.description == "My Updated Test Corpus Description" + + pytest_plugins = ("pytest_asyncio",) @@ -59,3 +86,31 @@ async def test_update_rag_corpus_private_async(client): ) assert isinstance(corpus_op, types.UpdateRagCorpusOperation) + + +@pytest.mark.asyncio +async def test_update_rag_corpus_async(client): + search_config = types.VertexAiSearchConfig( + serving_config="projects/vertex-sdk-dev/locations/us-central1/collections/default_collection/engines/test-engine/servingConfigs/default_serving_config" + ) + + # Create a corpus to update + corpus = await client.aio.rag.create_corpus( + rag_corpus=types.RagCorpus( + display_name="My Test Corpus", + description="My Test Corpus Description", + vertex_ai_search_config=search_config, + ), + ) + + updated_corpus = await client.aio.rag.update_corpus( + name=corpus.name, + rag_corpus=types.RagCorpus( + display_name="My Updated Vertex AI Search Test Corpus", + description="My Updated Test Corpus Description", + vertex_ai_search_config=search_config, + ), + ) + + assert updated_corpus.display_name == "My Updated Vertex AI Search Test Corpus" + assert updated_corpus.description == "My Updated Test Corpus Description" diff --git a/tests/unit/agentplatform/genai/replays/test_rag_update_config.py b/tests/unit/agentplatform/genai/replays/test_rag_update_config.py index 4d81241440..4afe088cc6 100644 --- a/tests/unit/agentplatform/genai/replays/test_rag_update_config.py +++ b/tests/unit/agentplatform/genai/replays/test_rag_update_config.py @@ -39,6 +39,20 @@ def test_update_rag_config_private(client): assert isinstance(config_op, types.UpdateRagConfigOperation) +def test_config_update(client): + updated_config = client.rag.update_config( + updated_config=types.RagEngineConfig( + name="projects/vertex-sdk-dev/locations/us-central1/ragEngineConfig", + ) + ) + + assert isinstance(updated_config, types.RagEngineConfig) + assert ( + updated_config.name + == "projects/vertex-sdk-dev/locations/us-central1/ragEngineConfig" + ) + + pytest_plugins = ("pytest_asyncio",) @@ -54,3 +68,18 @@ async def test_update_rag_config_private_async(client): ) assert isinstance(config_op, types.UpdateRagConfigOperation) + + +@pytest.mark.asyncio +async def test_config_update_async(client): + updated_config = await client.aio.rag.update_config( + updated_config=types.RagEngineConfig( + name="projects/vertex-sdk-dev/locations/us-central1/ragEngineConfig", + ) + ) + + assert isinstance(updated_config, types.RagEngineConfig) + assert ( + updated_config.name + == "projects/vertex-sdk-dev/locations/us-central1/ragEngineConfig" + )