Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 275 additions & 0 deletions agentplatform/_genai/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,19 @@ def _GetCorpusOperationParameters_to_vertex(
return to_object


def _GetRagConfigOperationParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["operation_name"]) is not None:
setv(
to_object, ["_url", "operation_name"], getv(from_object, ["operation_name"])
)

return to_object


def _GetRagConfigRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -2300,6 +2313,74 @@ def retrieve_contexts(
self._api_client._verify_response(return_value)
return return_value

def _get_rag_config_operation(
self,
*,
operation_name: str,
config: Optional[types.GetRagConfigOperationConfigOrDict] = None,
) -> types.RagEngineConfigOperation:
parameter_model = types._GetRagConfigOperationParameters(
operation_name=operation_name,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError(
"This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode."
)
else:
request_dict = _GetRagConfigOperationParameters_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "{operation_name}".format_map(request_url_dict)
else:
path = "{operation_name}"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = self._api_client.request("get", path, request_dict, http_options)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.RagEngineConfigOperation._from_response(
response=response_dict,
kwargs=(
{
"config": {
"response_schema": getattr(
parameter_model.config, "response_schema", None
),
"response_json_schema": getattr(
parameter_model.config, "response_json_schema", None
),
"include_all_fields": getattr(
parameter_model.config, "include_all_fields", None
),
}
}
if getattr(parameter_model, "config", None)
else {}
),
)

self._api_client._verify_response(return_value)
return return_value

def create_corpus(
self,
*,
Expand Down Expand Up @@ -2372,6 +2453,67 @@ def delete_file(

return None

def update_corpus(
self,
*,
name: str,
rag_corpus: types.RagCorpusOrDict,
config: Optional[types.UpdateRagCorpusConfigOrDict] = None,
) -> types.RagCorpus:
"""
Updates a Rag Corpus and waits for completion.

Args:
name: The name of the RagCorpus to update, formatted as
`projects/{project}/locations/{location}/ragCorpora/{corpus_id}`.
rag_corpus: The RagCorpus to update.
config: The configuration to use for the RagCorpus update request.

Returns:
The updated RagCorpus.
"""
operation = self._update_corpus(name=name, rag_corpus=rag_corpus, config=config)

operation = _operations_utils.await_operation(
operation_name=operation.name,
get_operation_fn=self._get_corpus_operation,
)

if operation.error:
raise RuntimeError(f"Failed to update RagCorpus: {operation.error}")

return self.get_corpus(name=operation.response.name)

def update_config(
self,
*,
updated_config: types.RagEngineConfigOrDict,
request_config: Optional[types.UpdateRagConfigOrDict] = None,
) -> types.RagEngineConfig:
"""
Updates a RagEngineConfig and waits for completion.

Args:
updated_config: The RagEngineConfig to update.
request_config: The configuration to use for the RagEngineConfig update request.

Returns:
The updated RagEngineConfig.
"""
operation = self._update_config(
updated_config=updated_config, config=request_config
)

operation = _operations_utils.await_operation(
operation_name=operation.name,
get_operation_fn=self._get_rag_config_operation,
)

if operation.error:
raise RuntimeError(f"Failed to update RagEngineConfig: {operation.error}")

return self.get_config()


class AsyncRag(_api_module.BaseModule):

Expand Down Expand Up @@ -3333,6 +3475,76 @@ async def retrieve_contexts(
self._api_client._verify_response(return_value)
return return_value

async def _get_rag_config_operation(
self,
*,
operation_name: str,
config: Optional[types.GetRagConfigOperationConfigOrDict] = None,
) -> types.RagEngineConfigOperation:
parameter_model = types._GetRagConfigOperationParameters(
operation_name=operation_name,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError(
"This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode."
)
else:
request_dict = _GetRagConfigOperationParameters_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "{operation_name}".format_map(request_url_dict)
else:
path = "{operation_name}"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = await self._api_client.async_request(
"get", path, request_dict, http_options
)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.RagEngineConfigOperation._from_response(
response=response_dict,
kwargs=(
{
"config": {
"response_schema": getattr(
parameter_model.config, "response_schema", None
),
"response_json_schema": getattr(
parameter_model.config, "response_json_schema", None
),
"include_all_fields": getattr(
parameter_model.config, "include_all_fields", None
),
}
}
if getattr(parameter_model, "config", None)
else {}
),
)

self._api_client._verify_response(return_value)
return return_value

async def create_corpus(
self,
*,
Expand Down Expand Up @@ -3404,3 +3616,66 @@ async def delete_file(
)

return None

async def update_corpus(
self,
*,
name: str,
rag_corpus: types.RagCorpusOrDict,
config: Optional[types.UpdateRagCorpusConfigOrDict] = None,
) -> types.RagCorpus:
"""
Updates a Rag Corpus and waits for completion asynchronously.

Args:
name: The name of the RagCorpus to update, formatted as
`projects/{project}/locations/{location}/ragCorpora/{corpus_id}`.
rag_corpus: The RagCorpus to update.
config: The configuration to use for the RagCorpus update request.

Returns:
The updated RagCorpus.
"""
operation = await self._update_corpus(
name=name, rag_corpus=rag_corpus, config=config
)

operation = await _operations_utils.await_operation_async(
operation_name=operation.name,
get_operation_fn=self._get_corpus_operation,
)

if operation.error:
raise RuntimeError(f"Failed to update RagCorpus: {operation.error}")

return await self.get_corpus(name=operation.response.name)

async def update_config(
self,
*,
updated_config: types.RagEngineConfigOrDict,
request_config: Optional[types.UpdateRagConfigOrDict] = None,
) -> types.RagEngineConfig:
"""
Updates a RagEngineConfig and waits for completion asynchronously.

Args:
updated_config: The RagEngineConfig to update.
request_config: The configuration to use for the RagEngineConfig update request.

Returns:
The updated RagEngineConfig.
"""
operation = await self._update_config(
updated_config=updated_config, config=request_config
)

operation = await _operations_utils.await_operation_async(
operation_name=operation.name,
get_operation_fn=self._get_rag_config_operation,
)

if operation.error:
raise RuntimeError(f"Failed to update RagEngineConfig: {operation.error}")

return await self.get_config()
14 changes: 14 additions & 0 deletions agentplatform/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from .common import _GetEvaluationSetParameters
from .common import _GetMultimodalDatasetOperationParameters
from .common import _GetMultimodalDatasetParameters
from .common import _GetRagConfigOperationParameters
from .common import _GetRagConfigRequestParameters
from .common import _GetRagCorpusRequestParameters
from .common import _GetRagFileRequestParameters
Expand Down Expand Up @@ -692,6 +693,9 @@
from .common import GetPromptConfigOrDict
from .common import GetRagConfig
from .common import GetRagConfigDict
from .common import GetRagConfigOperationConfig
from .common import GetRagConfigOperationConfigDict
from .common import GetRagConfigOperationConfigOrDict
from .common import GetRagConfigOrDict
from .common import GetRagCorpusConfig
from .common import GetRagCorpusConfigDict
Expand Down Expand Up @@ -1108,6 +1112,9 @@
from .common import RagEmbeddingModelConfigVertexPredictionEndpointOrDict
from .common import RagEngineConfig
from .common import RagEngineConfigDict
from .common import RagEngineConfigOperation
from .common import RagEngineConfigOperationDict
from .common import RagEngineConfigOperationOrDict
from .common import RagEngineConfigOrDict
from .common import RagFile
from .common import RagFileDict
Expand Down Expand Up @@ -2731,6 +2738,12 @@
"RetrieveContextsResponse",
"RetrieveContextsResponseDict",
"RetrieveContextsResponseOrDict",
"GetRagConfigOperationConfig",
"GetRagConfigOperationConfigDict",
"GetRagConfigOperationConfigOrDict",
"RagEngineConfigOperation",
"RagEngineConfigOperationDict",
"RagEngineConfigOperationOrDict",
"GetAgentEngineRuntimeRevisionConfig",
"GetAgentEngineRuntimeRevisionConfigDict",
"GetAgentEngineRuntimeRevisionConfigOrDict",
Expand Down Expand Up @@ -3334,6 +3347,7 @@
"_DeleteRagFileRequestParameters",
"_UpdateRagConfigRequestParameters",
"_RetrieveRagContextsRequestParameters",
"_GetRagConfigOperationParameters",
"_GetAgentEngineRuntimeRevisionRequestParameters",
"_ListAgentEngineRuntimeRevisionsRequestParameters",
"_DeleteAgentEngineRuntimeRevisionRequestParameters",
Expand Down
Loading
Loading