diff --git a/.gitignore b/.gitignore
index bec1719..7315bd4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,13 @@
# specific
output/
corpus/
+.history/
+dataset/
+broken_db/
+examples/Oracle_SQLPGQ_Instance/
+examples/generated_corpus/oracle_sqlpgq_*.json
+examples/generated_corpus/cypher_to_oracle_sqlpgq*.json
+test_oracle_sqlpgq_query.json
# Byte-compiled / optimized / DLL files
__pycache__/
@@ -168,4 +175,4 @@ cython_debug/
#.idea/
# poetry
-poetry.lock
\ No newline at end of file
+poetry.lock
diff --git a/README.md b/README.md
index 01b6d47..d00c05c 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ Awesome-Text2GQL is an AI-assisted framework for Text2GQL dataset construction.
### Generated Benchmark Dataset
-The [Text2GQL-Bench](https://arxiv.org/abs/2602.11745)'s dataset is generated by Awesome-Text2GQL framework. It contains 178,184 (Question, Query) pairs spanning 13 domains. The dataset is available at [Text2GQL-Bench_dataset](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/Text2GraphQueryBenchmark/Text2GQL-Bench_dataset.zip). To run Text2GQL test, please refer to our [Text2GraphQuery-Driver](https://github.com/TuGraph-family/text2graphquery-driver/tree/main).
+The [Text2GQL-Bench](https://arxiv.org/abs/2602.11745)'s dataset is generated by Awesome-Text2GQL framework. It contains 178,184 (Question, Query) pairs spanning 13 domains. The dataset is available at [Text2GQL-Bench_dataset](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/Text2GraphQueryBenchmark/Text2GQL-Bench_dataset.zip). The dataset including the Oracle SQL/PGQ translated queries is available at [Dataset-with-SQL/PGQ](https://objectstorage.us-ashburn-1.oraclecloud.com/p/8dIkuVGsfnRQlP3ifxVDjQP0pmidpadEY18ltEbkPC4PrZyLTxjJdqDjbtWIEYUW/n/ogcs/b/Text2GQL-Bench_dataset/o/Text2GQL-Bench_dataset.zip); it includes 19,633 of the 22,407 existing queries. To run Text2GQL test, please refer to our [Text2GraphQuery-Driver](https://github.com/TuGraph-family/text2graphquery-driver/tree/main).
## Demo: TuGraph-DB ChatBot
@@ -195,6 +195,15 @@ After all, run:
When the script finishes, the generated corpus will be saved to examples/generated_corpus/{graph_name}_template_corpus.json.
+#### Oracle SQL Property Graphs (SQL/PGQ)
+
+Awesome-Text2GQL includes Oracle SQL/PGQ support for schema conversion, graph setup, query translation, corpus generation, validation, and benchmark dataset preparation.
+
+For detailed workflows, see:
+
+- [Oracle SQL/PGQ data generation workflow](./doc/en-us/development/oracle_sqlpgq_data_generation_workflow.md): convert framework/TuGraph-style schemas into Oracle SQL/PGQ artifacts, create local Oracle property graphs, generate deterministic and LLM-based corpora, validate generated queries, and combine corpus outputs.
+- [Dataset preparation utilities](./dataset_prep/README.md): translate benchmark Cypher/GQL-like records to Oracle SQL/PGQ, optionally validate them against Oracle, analyze failures, compare Oracle SQL/PGQ results with Neo4j, and export validated datasets.
+
#### Cypher2GQL
`python ./examples/cypher2gql.py`
diff --git a/app/core/clauses/match_clause.py b/app/core/clauses/match_clause.py
index 6ea30fb..88a4197 100644
--- a/app/core/clauses/match_clause.py
+++ b/app/core/clauses/match_clause.py
@@ -24,14 +24,16 @@ class EdgePattern:
class PathPattern:
node_pattern_list: List[NodePattern]
edge_pattern_list: List[EdgePattern]
+ path_variable: str = ""
class MatchClause(Clause):
- def __init__(self, path_pattern: PathPattern):
+ def __init__(self, path_pattern: PathPattern, optional: bool = False):
self.path_pattern = path_pattern
+ self.optional = optional
def to_string(self) -> str:
- match_string = "MATCH "
+ match_string = "OPTIONAL MATCH " if self.optional else "MATCH "
path_degree = len(self.path_pattern.edge_pattern_list)
# add first node
node_pattern = self.path_pattern.node_pattern_list[0]
@@ -51,7 +53,7 @@ def to_string(self) -> str:
return match_string
def to_string_cypher(self) -> str:
- match_string = "MATCH "
+ match_string = "OPTIONAL MATCH " if self.optional else "MATCH "
path_degree = len(self.path_pattern.edge_pattern_list)
# add first node
node_pattern = self.path_pattern.node_pattern_list[0]
@@ -82,7 +84,7 @@ def to_string_cypher(self) -> str:
return match_string
def to_string_gql(self) -> str:
- match_string = "MATCH "
+ match_string = "OPTIONAL MATCH " if self.optional else "MATCH "
path_degree = len(self.path_pattern.edge_pattern_list)
# add first node
node_pattern = self.path_pattern.node_pattern_list[0]
diff --git a/app/core/clauses/return_clause.py b/app/core/clauses/return_clause.py
index 5a77211..634231c 100644
--- a/app/core/clauses/return_clause.py
+++ b/app/core/clauses/return_clause.py
@@ -10,6 +10,7 @@ class ReturnItem:
property: str
alias: str
function_name: str = ""
+ expression: str = ""
@dataclass
@@ -18,6 +19,7 @@ class SortItem:
property: str
order: str
function_name: str = ""
+ expression: str = ""
@dataclass
diff --git a/app/core/clauses/where_clause.py b/app/core/clauses/where_clause.py
index db01918..84ee1cf 100644
--- a/app/core/clauses/where_clause.py
+++ b/app/core/clauses/where_clause.py
@@ -10,6 +10,7 @@ class CompareExpression:
property: tuple[str, Dict]
comparison_type: str
comparison_value: str
+ raw_expression: str = ""
class WhereClause(Clause):
diff --git a/app/core/generalizer/query_generalizer.py b/app/core/generalizer/query_generalizer.py
index 23581cd..77b372a 100644
--- a/app/core/generalizer/query_generalizer.py
+++ b/app/core/generalizer/query_generalizer.py
@@ -6,14 +6,29 @@
from app.core.clauses.where_clause import CompareExpression, WhereClause
from app.core.schema.schema_graph import SchemaGraph
from app.core.schema.schema_parser import SchemaParser
+from app.impl.oracle_sqlpgq.schema.schema_parser import OracleSqlPgqSchemaParser
from app.impl.tugraph_cypher.schema.schema_parser import TuGraphSchemaParser
class QueryGeneralizer:
- def __init__(self, db_id, instance_path):
+ SCHEMA_PARSERS = {
+ "tugraph_cypher": TuGraphSchemaParser,
+ "tugraph": TuGraphSchemaParser,
+ "oracle_sqlpgq": OracleSqlPgqSchemaParser,
+ "oracle": OracleSqlPgqSchemaParser,
+ }
+
+ def __init__(self, db_id, instance_path, backend: str = "tugraph_cypher"):
self.db_id = db_id
self.instance_path = instance_path
- self.schema_parser: SchemaParser = TuGraphSchemaParser(db_id, instance_path)
+ self.backend = backend
+ parser_class = self.SCHEMA_PARSERS.get(backend)
+ if parser_class is None:
+ supported = ", ".join(sorted(self.SCHEMA_PARSERS))
+ raise ValueError(
+ f"Unsupported schema backend '{backend}'. Supported backends: {supported}"
+ )
+ self.schema_parser: SchemaParser = parser_class(db_id, instance_path)
self.schema_graph: SchemaGraph = self.schema_parser.get_schema_graph()
def generalize(self, query_pattern: List[Clause]) -> List[str]:
@@ -54,6 +69,11 @@ def generalize_from_llm(self, query_template: str) -> List[str]:
def generalize_from_cypher(self, query_template: str) -> List[str]:
# TODO: use original awesome-text2gql to generalize new query.
+ if self.backend not in {"tugraph_cypher", "tugraph"}:
+ raise NotImplementedError(
+ "generalize_from_cypher is backed by the TuGraph Cypher generalizer. "
+ "Use get_query_pattern + generalize + an Oracle translator for oracle_sqlpgq."
+ )
from app.impl.tugraph_cypher.generalizer.graph_query_generalizer import (
GraphQueryGeneralizer as CypherGeneralizer,
)
diff --git a/app/core/generator/corpus_generator.py b/app/core/generator/corpus_generator.py
index 6ca45be..f69ea60 100644
--- a/app/core/generator/corpus_generator.py
+++ b/app/core/generator/corpus_generator.py
@@ -8,11 +8,46 @@
class CorpusGenerator:
- def __init__(self, llm_client: LlmClient):
+ def __init__(
+ self,
+ llm_client: LlmClient,
+ query_language: str = "cypher",
+ graph_name: str | None = None,
+ ):
self.llm_client = llm_client
+ self.query_language = query_language.lower()
+ self.graph_name = graph_name
+
+ def _system_prompt(self) -> str:
+ if self.query_language in {"oracle_sqlpgq", "sqlpgq", "sql/pgq"}:
+ return corpus.SQLPGQ_SYSTEM_PROMPT
+ return corpus.SYSTEM_PROMPT
+
+ def _instruction_template(self) -> str:
+ if self.query_language in {"oracle_sqlpgq", "sqlpgq", "sql/pgq"}:
+ return corpus.SQLPGQ_INSTRUCTION_TEMPLATE
+ return corpus.INSTRUCTION_TEMPLATE
+
+ def _translation_prompt_template(self) -> str:
+ if self.query_language in {"oracle_sqlpgq", "sqlpgq", "sql/pgq"}:
+ return corpus.SQLPGQ_TRANSLATION_PROMPT_TEMPLATE
+ return corpus.TRANSLATION_PROMPT_TEMPLATE
+
+ def _query_template_instruction(self) -> str:
+ if self.query_language in {"oracle_sqlpgq", "sqlpgq", "sql/pgq"}:
+ return corpus.SQLPGQ_QUERY_TEMPLATE_INSTRUCTION
+ return corpus.QUERY_TEMPLATE_INSTRUCTION
+
+ def _query_archetypes(self) -> List[str]:
+ if self.query_language in {"oracle_sqlpgq", "sqlpgq", "sql/pgq"}:
+ return corpus.SQLPGQ_QUERY_ARCHETYPES
+ return corpus.QUERY_ARCHETYPES
def _extract_json_from_response(self, response: str, expect_list: bool = True):
"""Extract JSON from LLM response."""
+ if not response:
+ print(" [Warning] Empty LLM response.")
+ return [] if expect_list else {}
try:
start_char, end_char = ("[", "]") if expect_list else ("{", "}")
json_start = response.find(start_char)
@@ -40,7 +75,7 @@ def generate_questions_batch(
all_questions = set()
# Randomly select a query intent archetype to guide generation
- archetype = random.choice(corpus.QUERY_ARCHETYPES)
+ archetype = random.choice(self._query_archetypes())
print(f"Brainstorming questions with intent: '{archetype.split(':')[0]}'")
instruction = corpus.EXPLORATION_PROMPT_TEMPLATE.format(
@@ -50,7 +85,7 @@ def generate_questions_batch(
num_to_generate=questions_per_call,
)
message = [
- {"role": "system", "content": corpus.SYSTEM_PROMPT},
+ {"role": "system", "content": self._system_prompt()},
{"role": "user", "content": instruction},
]
@@ -69,16 +104,17 @@ def generate_translation_batch(
self, schema_json: str, questions: List[str], error_context: Dict[str, str] = None
) -> List[Dict[str, Any]]:
"""
- Translate a list of questions into Cypher queries.
+ Translate a list of questions into the configured graph query language.
Supports retries by providing an error_context.
"""
- instruction = corpus.TRANSLATION_PROMPT_TEMPLATE.format(
+ instruction = self._translation_prompt_template().format(
schema_json=schema_json,
question=questions[0], # Assuming one question per call for clarity
+ graph_name=self.graph_name or "GRAPH_NAME",
error_context=error_context if error_context else "",
)
message = [
- {"role": "system", "content": corpus.SYSTEM_PROMPT},
+ {"role": "system", "content": self._system_prompt()},
{"role": "user", "content": instruction},
]
@@ -193,13 +229,14 @@ def run_generation_loop(
selected_contexts = random_examples
# 1. Build Prompt
- instruction = corpus.INSTRUCTION_TEMPLATE.format(
+ instruction = self._instruction_template().format(
schema_json=schema_json,
examples_json=json.dumps(selected_contexts, indent=2, ensure_ascii=False),
num_per_iteration=num_per_iteration,
+ graph_name=self.graph_name or "GRAPH_NAME",
)
message = [
- {"role": "system", "content": corpus.SYSTEM_PROMPT},
+ {"role": "system", "content": self._system_prompt()},
{"role": "user", "content": instruction},
]
@@ -274,7 +311,7 @@ def generate_template_based_corpus(
# 3. Construct the Prompt
# We directly provide the "raw" data and ask the LLM to do three things:
# extract information, fill the template, and generate questions.
- instraction = corpus.QUERY_TEMPLATE_INSTRUCTION.format(
+ instraction = self._query_template_instruction().format(
raw_data_str=raw_data_str,
current_batch_size=current_batch_size,
selected_templates=selected_templates,
@@ -283,7 +320,7 @@ def generate_template_based_corpus(
message = [
{
"role": "system",
- "content": "You are a helpful assistant that generates Cypher datasets.",
+ "content": self._system_prompt(),
},
{"role": "user", "content": instraction},
]
diff --git a/app/core/llm/llm_client.py b/app/core/llm/llm_client.py
index f23528e..da1e679 100644
--- a/app/core/llm/llm_client.py
+++ b/app/core/llm/llm_client.py
@@ -1,4 +1,5 @@
from http import HTTPStatus
+import json
import os
import random
import time
@@ -77,20 +78,117 @@ def call_with_messages_local(self, messages):
def call_with_messages_online_for_openai(self, messages):
try:
- openai_client = OpenAI(
- api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")
- )
+ client_kwargs = {
+ "api_key": os.getenv("OPENAI_API_KEY"),
+ "base_url": os.getenv("OPENAI_BASE_URL"),
+ }
+ default_headers = self._openai_default_headers()
+ if default_headers:
+ client_kwargs["default_headers"] = default_headers
+
+ openai_client = OpenAI(**client_kwargs)
response = openai_client.chat.completions.create(
model=self.model, messages=messages, temperature=0
)
- return response.choices[0].message.content
+ return self._extract_openai_chat_content(response)
except openai.RateLimitError:
print("there are too many request,ready to retry in 1 second")
time.sleep(1)
print("begin to retry")
return self.call_with_messages_online_for_openai(messages)
- except OpenAIError:
- print("Failed!", messages[1]["content"])
+ except OpenAIError as exc:
+ print(
+ "OpenAI-compatible LLM call failed: "
+ f"{type(exc).__name__}: {exc}"
+ )
+ return ""
+
+ def _extract_openai_chat_content(self, response) -> str:
+ if isinstance(response, str):
+ stream_content = self._extract_openai_sse_content(response)
+ if stream_content:
+ return stream_content
+ try:
+ response = json.loads(response)
+ except json.JSONDecodeError:
+ preview = response[:500].replace("\n", "\\n")
+ print(
+ "OpenAI-compatible LLM returned a raw string instead of a "
+ f"chat completion object: {preview}"
+ )
+ return response
+
+ if isinstance(response, dict):
+ choices = response.get("choices") or []
+ if choices:
+ message = choices[0].get("message") or {}
+ content = message.get("content")
+ if content is not None:
+ return content
+ preview = json.dumps(response, default=str)[:500]
+ print(f"OpenAI-compatible LLM returned unexpected JSON: {preview}")
+ return ""
+
+ try:
+ return response.choices[0].message.content or ""
+ except (AttributeError, IndexError, KeyError, TypeError) as exc:
+ preview = repr(response)[:500]
+ print(
+ "OpenAI-compatible LLM returned an unexpected response type "
+ f"({type(response).__name__}): {preview}; parse error: {exc}"
+ )
+ return ""
+
+ def _extract_openai_sse_content(self, response_text: str) -> str:
+ if not response_text.lstrip().startswith("data:"):
+ return ""
+
+ content_parts: list[str] = []
+ for line in response_text.splitlines():
+ line = line.strip()
+ if not line.startswith("data:"):
+ continue
+
+ payload = line.removeprefix("data:").strip()
+ if not payload or payload == "[DONE]":
+ continue
+
+ try:
+ event = json.loads(payload)
+ except json.JSONDecodeError:
+ continue
+
+ for choice in event.get("choices", []):
+ delta = choice.get("delta") or {}
+ if delta.get("content"):
+ content_parts.append(delta["content"])
+
+ message = choice.get("message") or {}
+ if message.get("content"):
+ content_parts.append(message["content"])
+
+ if choice.get("text"):
+ content_parts.append(choice["text"])
+
+ if not content_parts:
+ preview = response_text[:500].replace("\n", "\\n")
+ print(f"OpenAI-compatible LLM returned an empty SSE stream: {preview}")
+ return ""
+ return "".join(content_parts)
+
+ def _openai_default_headers(self) -> dict[str, str]:
+ raw_headers = os.getenv("OPENAI_EXTRA_HEADERS") or os.getenv("OPENAI_HTTP_HEADERS")
+ if not raw_headers:
+ return {}
+ try:
+ parsed = json.loads(raw_headers)
+ except json.JSONDecodeError as exc:
+ print(f"Invalid OPENAI_EXTRA_HEADERS JSON: {exc}")
+ return {}
+ if not isinstance(parsed, dict):
+ print("OPENAI_EXTRA_HEADERS must be a JSON object.")
+ return {}
+ return {str(key): str(value) for key, value in parsed.items()}
def call_with_messages_online_for_dashscope(self, messages):
response = Generation.call(
diff --git a/app/core/prompt/corpus.py b/app/core/prompt/corpus.py
index dbc0919..a13b069 100644
--- a/app/core/prompt/corpus.py
+++ b/app/core/prompt/corpus.py
@@ -8,6 +8,55 @@
Your output must be in strict JSON format, use English, as a list containing multiple objects.
""" # noqa: E501
+SQLPGQ_SYSTEM_PROMPT = """
+You are an Oracle SQL/PGQ expert generating executable training data.
+SQL/PGQ extends SQL with property-graph pattern matching through GRAPH_TABLE.
+
+Core rules:
+- Use only the supplied graph name, labels, edge labels, and properties. Do not invent schema.
+- Use property graph comments or validated examples as semantic guidance when present.
+- Return exactly the JSON shape requested by the user prompt. Do not add explanations.
+- Every graph access must use GRAPH_TABLE. The final SQL may be a direct GRAPH_TABLE query,
+ a CTE/subquery around GRAPH_TABLE, a join between GRAPH_TABLE and relational tables, or a
+ UNION ALL that combines graph results with normal SQL results.
+- MATCH vertices with (v IS "LABEL") or (v); match edges with -[e IS "EDGE"]->, <-[e IS "EDGE"]-, or -[e IS "EDGE"]-.
+- Declare variables for every vertex and edge that is referenced in WHERE, COLUMNS, SELECT, ORDER BY, GROUP BY, or HAVING.
+- Never use Cypher or PGQL syntax: no (v:LABEL), no {prop: value}, no [:EDGE], no RETURN.
+- Use double quotes for graph, label, and property identifiers; use single quotes only for string literals.
+- Access properties inside GRAPH_TABLE as v."PROPERTY" or e."PROPERTY".
+- Add graph WHERE predicates only when the question asks for a filter or the filter is required. Do not invent arbitrary literals.
+- COLUMNS must not be empty. Alias every projected expression. Outside GRAPH_TABLE, refer to COLUMNS aliases, not graph variables.
+- Project graph elements as VERTEX_ID(v) or EDGE_ID(e), not as raw vertex or edge variables.
+- Put outer SQL operations such as ORDER BY, GROUP BY, HAVING, OFFSET, and FETCH outside GRAPH_TABLE.
+- For aggregation, project needed values in COLUMNS and aggregate in outer SQL unless using path aggregates inside a quantified path.
+- Use VERTEX_EQUAL(v1, v2) to compare vertices; do not compare vertices with v1 = v2.
+- Avoid reserved words as variable names; use suffixes such as start_vertex or end_vertex.
+- Keep parentheses balanced and avoid unsupported quantifiers such as *? or +?.
+
+Advanced SQL/PGQ patterns to generate when they fit the schema and examples:
+- Mixed SQL plus SQL/PGQ: WITH clauses, inline views, joins to base tables, EXISTS/NOT EXISTS,
+ UNION ALL branches with compatible column counts and types, CASE expressions, analytic
+ functions such as ROW_NUMBER(), and outer GROUP BY/HAVING/ORDER BY/FETCH.
+- Bounded recursive paths: use supported quantifiers such as *, +, ?, {n}, {n,}, and {n,m}
+ only where Oracle SQL/PGQ accepts them. Prefer bounded forms such as {1,3} for validation.
+- ONE ROW PER MATCH, ONE ROW PER VERTEX (v), and ONE ROW PER STEP (src, edge, dst) when
+ path unnesting is requested. Iterator variables must be unique, must not appear in MATCH or
+ graph WHERE, and may be referenced only in COLUMNS.
+- MATCHNUM(), PATH_NAME(), and ELEMENT_NUMBER(iterator) may be used in COLUMNS with
+ ONE ROW PER queries.
+- LISTAGG, JSON_ARRAYAGG, MIN, MAX, AVG, SUM, and COUNT may be used in COLUMNS for path
+ aggregates when the variable is grouped by a quantified path; otherwise aggregate outside
+ GRAPH_TABLE.
+- VERTEX_ID, EDGE_ID, VERTEX_EQUAL, EDGE_EQUAL, and IS SOURCE/DESTINATION OF predicates are
+ valid SQL/PGQ tools when the question asks for element identity or edge direction checks.
+
+Shortest-path caution:
+- For Oracle database SQL property graph queries, do not generate PGQL-only path search goals
+ such as KEEP SHORTEST, ANY SHORTEST, ALL SHORTEST, ANY CHEAPEST, COST, TOTAL_COST, or
+ MATCH path_variable = (...) unless a validated example in the prompt uses that exact syntax.
+ Prefer a bounded path query with ORDER BY/FETCH or ONE ROW PER STEP for live DB validation.
+""" # noqa: E501
+
INSTRUCTION_TEMPLATE = """
# Command
@@ -46,6 +95,78 @@
]
""" # noqa: E501
+SQLPGQ_INSTRUCTION_TEMPLATE = """
+# Command
+Generate {num_per_iteration} new "question-query" data pairs based on the following information.
+
+# 1. Graph Schema
+This is the Schema definition of the Oracle SQL property graph:
+```json
+{schema_json}
+```
+
+Graph name to use in every GRAPH_TABLE call:
+```text
+{graph_name}
+```
+
+# 2. Verified Query Examples (Context)
+Here are some verified "question-query-result" examples that execute successfully.
+```json
+{examples_json}
+```
+
+# 3. Your Task
+Generate {num_per_iteration} new, meaningful Oracle SQL/PGQ "question-query" data pairs.
+
+Basic direct GRAPH_TABLE shape:
+```sql
+SELECT *
+FROM GRAPH_TABLE (
+ "{graph_name}"
+ MATCH (n IS "LABEL")-[e IS "EDGE"]->(m IS "LABEL")
+ WHERE n."property" = 'value'
+ COLUMNS (m."property" AS property_alias)
+) gt
+```
+
+Guidelines:
+- Use "{graph_name}" as the graph name in every query.
+- Use only schema labels, edge labels, and property names.
+- Use concrete literal values from verified examples/results when filtering.
+- If no known literal value is available, generate a broader query without that literal.
+- Do not add a WHERE clause unless the question requires a filter.
+- Always give every projected expression in COLUMNS an AS alias.
+- Use COLUMNS aliases in outer SELECT, JOIN, ORDER BY, GROUP BY, HAVING, OFFSET, and FETCH clauses.
+- For counts, averages, and grouping, project required values in COLUMNS and aggregate outside GRAPH_TABLE unless the query is a path aggregate.
+- Project identifiers with VERTEX_ID(v) or EDGE_ID(e) when the question asks for vertices, edges, or IDs.
+- Generate a diverse mix. Include at least one advanced SQL/PGQ pattern when the schema supports it:
+ CTEs around GRAPH_TABLE, joins to base tables, UNION ALL with normal SQL, analytic functions,
+ outer GROUP BY/HAVING, bounded path traversal, one-row-per-step path expansion, or element IDs.
+- For UNION ALL, every branch must return the same number of columns with compatible data types.
+- For joins to base tables, join using primary-key values projected from GRAPH_TABLE COLUMNS.
+- For ONE ROW PER STEP, use unique iterator variables and project them only in COLUMNS; if unsure,
+ prefer a bounded multi-hop query that can be validated by Oracle.
+- Do not generate KEEP SHORTEST, ANY SHORTEST, ALL SHORTEST, COST, or PGQL path-variable syntax
+ unless the verified examples include a working query using that exact form.
+- Do not output Cypher, PGQL, placeholders, comments, explanations, or result fields.
+- Return a strict JSON list of objects with "question" and "query" keys only. Do not include result fields.
+
+Valid example:
+```json
+[
+ {{
+ "question": "Which movies belong to the Science Fiction genre?",
+ "query": "SELECT * FROM GRAPH_TABLE (\"{graph_name}\" MATCH (m IS \"MOVIE\")-[b IS \"BELONGS_TO\"]->(g IS \"GENRE\") WHERE g.\"name\" = 'Science Fiction' COLUMNS (m.\"title\" AS movie_title, g.\"name\" AS genre_name)) gt"
+ }},
+ {{
+ "question": "Which graph-derived movies can also be checked against the base movie table?",
+ "query": "WITH graph_movies AS (SELECT gt.movie_id, gt.movie_title FROM GRAPH_TABLE (\"{graph_name}\" MATCH (m IS \"MOVIE\")-[b IS \"BELONGS_TO\"]->(g IS \"GENRE\") COLUMNS (m.\"MOVIE_id\" AS movie_id, m.\"title\" AS movie_title)) gt) SELECT gm.movie_title FROM graph_movies gm JOIN \"MOVIE\" m ON m.\"MOVIE_id\" = gm.movie_id"
+ }}
+]
+```
+""" # noqa: E501
+
ENHANCEMENT_PROMPT_TEMPLATE = """
# Command
Your task as a senior Cypher expert is to create more complex and insightful new "question-query" pairs based on existing queries.
@@ -102,6 +223,17 @@
"Path Analysis and Traversal: Focus on analysis of paths themselves, such as finding the shortest path or all possible paths. Example: 'Find the shortest path between the type A node named [instance A] and the type B node named [instance B].'", # noqa: E501
]
+SQLPGQ_QUERY_ARCHETYPES = [
+ "Direct Graph Pattern Query: Answer a question with one GRAPH_TABLE MATCH pattern, projected aliases, and optional graph WHERE filters.", # noqa: E501
+ "Mixed SQL and SQL/PGQ Join: Use GRAPH_TABLE in a CTE or inline view, then join projected graph IDs or properties to normal relational tables.", # noqa: E501
+ "UNION ALL Hybrid Query: Combine a normal SQL branch with a GRAPH_TABLE branch using compatible output columns and data types.", # noqa: E501
+ "Outer SQL Aggregation: Project graph values in COLUMNS, then use COUNT, AVG, SUM, MIN, MAX, GROUP BY, HAVING, ORDER BY, or FETCH outside GRAPH_TABLE.", # noqa: E501
+ "Analytic SQL Over Graph Results: Use ROW_NUMBER, RANK, DENSE_RANK, or partitioned aggregates over a GRAPH_TABLE result set.", # noqa: E501
+ "Bounded Path Traversal: Use multi-hop or bounded quantified path patterns to ask reachability or chain questions while avoiding unsupported shortest-path goals.", # noqa: E501
+ "One Row Per Path Expansion: Ask for path steps or path elements and generate ONE ROW PER STEP or ONE ROW PER VERTEX queries when supported by the validated examples.", # noqa: E501
+ "Element Identity Query: Project VERTEX_ID or EDGE_ID, or compare graph elements with VERTEX_EQUAL/EDGE_EQUAL when identity matters.", # noqa: E501
+]
+
QUERY_TEMPLATE_INSTRUCTION = """
@@ -135,9 +267,38 @@
""" # noqa: E501
+SQLPGQ_QUERY_TEMPLATE_INSTRUCTION = """
+ You are an Oracle SQL/PGQ query generator.
+
+ I have run exploration queries on an Oracle SQL property graph and got the following RAW RESULT DATA.
+
+ --- RAW DATA START ---
+ {raw_data_str}
+ --- RAW DATA END ---
+
+ I also have a list of Oracle SQL/PGQ templates.
+ Generate {current_batch_size} new (Question, Query) pairs by filling these templates using REAL DATA extracted from the RAW DATA above.
+
+ --- TEMPLATES ---
+ {selected_templates}
+
+ --- CRITICAL RULES ---
+ 1. Use Oracle SQL/PGQ only: every graph access must use GRAPH_TABLE (... MATCH ... COLUMNS (...)).
+ 2. Use IS label syntax: (v IS "LABEL") and -[e IS "EDGE"]->.
+ 3. Never output Cypher syntax such as (v:LABEL), {{prop: value}}, [:EDGE], or RETURN.
+ 4. Use only labels, edge labels, properties, and literal values supported by the raw data/templates.
+ 5. Put graph filters in GRAPH_TABLE WHERE only when needed.
+ 6. Alias every COLUMNS expression and use those aliases in outer SQL.
+ 7. Mixed SQL is allowed: CTEs, joins with base tables, UNION ALL, analytic functions,
+ GROUP BY/HAVING, ORDER BY/FETCH, and bounded path patterns are valid when templates show them.
+ 8. Do not generate KEEP SHORTEST or PGQL path-variable syntax unless templates show a validated example.
+ 9. Output MUST be a strict JSON list of objects: [{{"question": "...", "query": "..."}}]
+ """ # noqa: E501
+
+
EXPLORATION_PROMPT_TEMPLATE = """
# Command
-Your task is to brainstorm and generate diverse natural language questions. Focus on the breadth and depth of questions, without considering how to write Cypher queries for now.
+Your task is to brainstorm and generate diverse natural language questions. Focus on the breadth and depth of questions, without writing the graph query yet.
# 1. Graph Schema
```json
@@ -197,3 +358,64 @@
"query": "MATCH (m:Movie) WHERE m.title = 'some movie' RETURN m"
}}
""" # noqa: E501
+
+SQLPGQ_TRANSLATION_PROMPT_TEMPLATE = """
+Command
+Your task as an Oracle SQL/PGQ expert is to accurately translate the given natural language question into an executable Oracle SQL property graph query.
+
+1. Graph Schema
+```JSON
+{schema_json}
+```
+
+Graph name to use in GRAPH_TABLE:
+```text
+{graph_name}
+```
+
+2. Question to be Translated
+```json
+{question}
+```
+
+3. Important Rules
+- Use Oracle SQL/PGQ only, not Cypher and not PGQL.
+- Every graph access must use GRAPH_TABLE ("{graph_name}" MATCH ... COLUMNS (...)).
+- The final query may be direct GRAPH_TABLE SQL, or SQL wrapped with CTEs/subqueries, joins to
+ relational tables, UNION ALL, GROUP BY/HAVING, analytic functions, ORDER BY, OFFSET, or FETCH.
+- Use labels, edge labels, and property names exactly as defined by the schema.
+- Use "{graph_name}" as the graph name.
+- Put graph-pattern predicates inside GRAPH_TABLE WHERE only when required by the question.
+- Do not invent filter literals or placeholder predicates.
+- Put projected values inside COLUMNS, and always give every projected value an AS alias.
+- Use COLUMNS aliases in outer SELECT, JOIN, ORDER BY, GROUP BY, HAVING, OFFSET, and FETCH clauses.
+- For aggregation, project the required value in COLUMNS and aggregate outside GRAPH_TABLE unless using a path aggregate.
+- If the question asks for vertices or edges, project VERTEX_ID(v) or EDGE_ID(e).
+- Never use Cypher forms such as (p:PERSON), {{NAME: 'Tom Hanks'}}, [:ACTED_IN], RETURN, or m.TITLE.
+- Use Oracle SQL/PGQ forms such as (p IS "PERSON"), [a IS "ACTED_IN"], p."NAME", m."TITLE", and a."ROLE".
+- COLUMNS must look like COLUMNS (m."TITLE" AS movie_title, a."ROLE" AS role), not COLUMNS (m."TITLE", a."ROLE").
+- For UNION ALL, each branch must return the same number of columns with compatible data types.
+- For mixed SQL and SQL/PGQ joins, project graph IDs/properties in COLUMNS and join outside GRAPH_TABLE.
+- For ONE ROW PER STEP, use unique iterator variables and reference them only in COLUMNS.
+- Do not generate KEEP SHORTEST, ANY SHORTEST, ALL SHORTEST, COST, or PGQL path-variable syntax unless the error context or verified examples prove the target database accepts that exact form.
+- Ensure every variable referenced in WHERE/COLUMNS is declared in MATCH.
+- Ensure all parentheses are balanced.
+
+Correct Oracle SQL/PGQ example:
+{{
+"query": "SELECT * FROM GRAPH_TABLE (\"{graph_name}\" MATCH (p IS \"PERSON\")-[a IS \"ACTED_IN\"]->(m IS \"MOVIE\") WHERE p.\"NAME\" = 'Tom Hanks' COLUMNS (m.\"TITLE\" AS movie_title, a.\"ROLE\" AS role)) gt"
+}}
+
+Invalid Cypher-style example:
+{{
+"query": "SELECT m.TITLE, e.ROLE FROM GRAPH_TABLE (MATCH (p:PERSON {{NAME: 'Tom Hanks'}})-[:ACTED_IN]->(m:MOVIE)) gt"
+}}
+
+{error_context}
+
+4. Output Format
+Return a JSON object containing only the "query" key.
+{{
+"query": "SELECT ... FROM GRAPH_TABLE (...) gt"
+}}
+""" # noqa: E501
diff --git a/app/core/validator/validator.py b/app/core/validator/validator.py
index db4e144..1108834 100644
--- a/app/core/validator/validator.py
+++ b/app/core/validator/validator.py
@@ -1,39 +1,92 @@
+from importlib import import_module
import logging
from typing import Any, Dict, List
from app.core.validator.db_client import DB_Client, QueryResult, QueryStatus
-from app.impl.tugraph_cypher.db_client.tugraph_db_client import TuGraphDBClient
logger = logging.getLogger("CorpusValidator")
class CorpusValidator:
- def __init__(self, tu_client_params: dict):
- # Store parameters instead of the client object itself
+ CLIENTS = {
+ "tugraph_cypher": (
+ "app.impl.tugraph_cypher.db_client.tugraph_db_client",
+ "TuGraphDBClient",
+ ),
+ "tugraph": (
+ "app.impl.tugraph_cypher.db_client.tugraph_db_client",
+ "TuGraphDBClient",
+ ),
+ "oracle_sqlpgq": (
+ "app.impl.oracle_sqlpgq.db_client.oracle_db_client",
+ "OracleDBClient",
+ ),
+ "oracle": (
+ "app.impl.oracle_sqlpgq.db_client.oracle_db_client",
+ "OracleDBClient",
+ ),
+ }
+
+ def __init__(
+ self,
+ tu_client_params: dict | None = None,
+ backend: str = "tugraph_cypher",
+ db_client_params: dict | None = None,
+ db_client: DB_Client | None = None,
+ ):
"""
- Initialize validator and for instantiating TuGraph database client implementation.
+ Initialize validator and instantiate the selected database client implementation.
+
+ Args:
+ tu_client_params: Backward-compatible TuGraph parameters.
+ backend: One of tugraph_cypher (alias: tugraph) or oracle_sqlpgq (alias: oracle).
+ db_client_params: Backend-specific connection parameters.
+ db_client: Optional prebuilt client, useful for tests.
"""
- self._tu_client_params = tu_client_params
- self._client: DB_Client | None = None
+ self.backend = backend
+ self._client_params = db_client_params if db_client_params is not None else tu_client_params
+ self._client_params = self._client_params or {}
+ self._client: DB_Client | None = db_client
- # Create connection during initialization and store the instance
- self._client = TuGraphDBClient(self._tu_client_params)
+ if self._client is None:
+ client_class = self._resolve_client_class(backend)
+ self._client = client_class(self._client_params)
- # Immediately check if connection was successful
- if not self._client or not self._client.client:
- logger.error("Database client failed to initialize or connect.")
+ if not self._is_client_ready(self._client):
+ logger.error(
+ f"Database client for backend '{backend}' failed to initialize or connect."
+ )
def _get_client(self) -> DB_Client | None:
"""Return the created client instance."""
return self._client
+ def _resolve_client_class(self, backend: str):
+ target = self.CLIENTS.get(backend)
+ if target is None:
+ supported = ", ".join(sorted(self.CLIENTS))
+ raise ValueError(
+ f"Unsupported backend '{backend}'. Supported backends: {supported}"
+ )
+ module_path, class_name = target
+ module = import_module(module_path)
+ return getattr(module, class_name)
+
+ def _is_client_ready(self, client: DB_Client | None) -> bool:
+ if client is None:
+ return False
+ for attr in ("client", "connection", "driver"):
+ if hasattr(client, attr):
+ return getattr(client, attr) is not None
+ return True
+
def execute_with_results(self, pairs: List[Dict[str, str]]) -> List[Dict[str, Any]]:
"""
Validate all pairs and get query result,
filter out pairs that fail execution or have empty results.
"""
client = self._get_client()
- if not client or not client.client:
+ if not self._is_client_ready(client):
logger.error("Database connection is not ready. Skipping validation.")
raise Exception("Database connection is not ready.")
diff --git a/app/impl/oracle_sqlpgq/__init__.py b/app/impl/oracle_sqlpgq/__init__.py
new file mode 100644
index 0000000..e94768d
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/__init__.py
@@ -0,0 +1,2 @@
+"""Oracle SQL/PGQ backend support."""
+
diff --git a/app/impl/oracle_sqlpgq/ast_visitor/__init__.py b/app/impl/oracle_sqlpgq/ast_visitor/__init__.py
new file mode 100644
index 0000000..5770767
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/ast_visitor/__init__.py
@@ -0,0 +1,2 @@
+"""Oracle SQL/PGQ AST visitor utilities."""
+
diff --git a/app/impl/oracle_sqlpgq/ast_visitor/oracle_sqlpgq_ast_visitor.py b/app/impl/oracle_sqlpgq/ast_visitor/oracle_sqlpgq_ast_visitor.py
new file mode 100644
index 0000000..ace2d04
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/ast_visitor/oracle_sqlpgq_ast_visitor.py
@@ -0,0 +1,293 @@
+import re
+from typing import List, Tuple
+
+from app.core.ast_visitor.ast_visitor import AstVisitor
+from app.core.clauses.clause import Clause
+from app.core.clauses.match_clause import EdgePattern, MatchClause, NodePattern, PathPattern
+from app.core.clauses.return_clause import ReturnBody, ReturnClause, ReturnItem, SortItem
+from app.core.clauses.where_clause import CompareExpression, WhereClause
+from app.impl.oracle_sqlpgq.utils.sqlpgq import validate_graph_table_query
+
+
+class OracleSqlPgqAstVisitor(AstVisitor):
+ """Parse the Oracle SQL/PGQ subset emitted by OracleSqlPgqQueryTranslator."""
+
+ NODE_RE = re.compile(r"\((?P<body>[^()]*)\)", re.DOTALL)
+ EDGE_RE = re.compile(
+ r"(?P<left><)?-\[(?P<body>[^\]]*)\]-(?P<right>>)?(?P<hop>\{[^}]+\})?",
+ re.DOTALL,
+ )
+ COMPARE_RE = re.compile(
+ r"(?P<var>[A-Za-z_][A-Za-z0-9_]*)"
+ r"(?:\.(?:\"(?P<quoted_prop>[^\"]+)\"|(?P<prop>[A-Za-z_][A-Za-z0-9_]*)))?"
+ r"\s*(?P<op>=|<>|<=|>=|<|>)\s*(?P<value>.+)",
+ re.DOTALL,
+ )
+
+ def get_query_pattern(self, query: str) -> Tuple[bool, List[Clause]]:
+ if not validate_graph_table_query(query):
+ return False, []
+ try:
+ body, graph_table_end = self._extract_graph_table_span(query)
+ match_text, where_text, columns_text = self._extract_sections(body)
+ clauses: List[Clause] = []
+ for path_text in self._split_top_level(match_text):
+ clauses.append(MatchClause(self._parse_path(path_text)))
+ if where_text:
+ clauses.append(WhereClause(self._parse_where(where_text)))
+ return_items = self._parse_columns(columns_text)
+ if return_items:
+ sort_items, skip, limit = self._parse_outer_modifiers(query[graph_table_end:])
+ clauses.append(ReturnClause(ReturnBody(return_items, sort_items, skip, limit)))
+ return True, clauses
+ except Exception:
+ return False, []
+
+ def _extract_graph_table_body(self, query: str) -> str:
+ return self._extract_graph_table_span(query)[0]
+
+ def _extract_graph_table_span(self, query: str) -> Tuple[str, int]:
+ marker = re.search(r"GRAPH_TABLE\s*\(", query, re.IGNORECASE)
+ if marker is None:
+ raise ValueError("GRAPH_TABLE not found.")
+ start = marker.end() - 1
+ depth = 0
+ in_single = False
+ in_double = False
+ for index in range(start, len(query)):
+ char = query[index]
+ if char == "'" and not in_double:
+ in_single = not in_single
+ elif char == '"' and not in_single:
+ in_double = not in_double
+ elif not in_single and not in_double:
+ if char == "(":
+ depth += 1
+ elif char == ")":
+ depth -= 1
+ if depth == 0:
+ return query[start + 1 : index].strip(), index + 1
+ raise ValueError("GRAPH_TABLE body is not balanced.")
+
+ def _extract_sections(self, body: str) -> Tuple[str, str, str]:
+ upper = body.upper()
+ match_idx = upper.index("MATCH ")
+ columns_idx = upper.rindex("COLUMNS")
+ before_columns = body[match_idx + len("MATCH ") : columns_idx].strip()
+ where_idx = self._find_keyword(before_columns, "WHERE")
+ if where_idx == -1:
+ match_text = before_columns
+ where_text = ""
+ else:
+ match_text = before_columns[:where_idx].strip()
+ where_text = before_columns[where_idx + len("WHERE") :].strip()
+ columns_start = body.index("(", columns_idx)
+ columns_text = body[columns_start + 1 : body.rindex(")")].strip()
+ return match_text, where_text, columns_text
+
+ def _find_keyword(self, text: str, keyword: str) -> int:
+ pattern = re.compile(rf"\b{keyword}\b", re.IGNORECASE)
+ for match in pattern.finditer(text):
+ prefix = text[: match.start()]
+ if prefix.count("(") == prefix.count(")") and prefix.count("[") == prefix.count("]"):
+ return match.start()
+ return -1
+
+ def _parse_path(self, text: str) -> PathPattern:
+ text = text.strip()
+ position = 0
+ node_patterns: List[NodePattern] = []
+ edge_patterns: List[EdgePattern] = []
+
+ first_node = self.NODE_RE.match(text, position)
+ if first_node is None:
+ raise ValueError("Path must start with a node pattern.")
+ node_patterns.append(self._parse_node(first_node.group("body")))
+ position = first_node.end()
+
+ while position < len(text):
+ edge_match = self.EDGE_RE.match(text, position)
+ if edge_match is None:
+ break
+ edge_patterns.append(self._parse_edge(edge_match))
+ position = edge_match.end()
+ node_match = self.NODE_RE.match(text, position)
+ if node_match is None:
+ raise ValueError("Edge pattern must be followed by a node pattern.")
+ node_patterns.append(self._parse_node(node_match.group("body")))
+ position = node_match.end()
+
+ return PathPattern(node_patterns, edge_patterns)
+
+ def _parse_node(self, body: str) -> NodePattern:
+ variable, label = self._parse_variable_and_label(body)
+ return NodePattern(variable, label, [])
+
+ def _parse_edge(self, edge_match: re.Match) -> EdgePattern:
+ variable, label = self._parse_variable_and_label(edge_match.group("body"))
+ if edge_match.group("left"):
+ direction = "left"
+ elif edge_match.group("right"):
+ direction = "right"
+ else:
+ direction = "bidirection"
+ return EdgePattern(variable, label, [], direction, self._parse_hop(edge_match.group("hop")))
+
+ def _parse_variable_and_label(self, body: str) -> Tuple[str, str]:
+ body = body.strip()
+ if " WHERE " in body.upper():
+ body = re.split(r"\bWHERE\b", body, flags=re.IGNORECASE)[0].strip()
+ label_pattern = (
+ r"(?P<var>[A-Za-z_][A-Za-z0-9_]*)"
+ r"(?:\s+IS\s+(?:\"(?P<qlabel>[^\"]+)\"|"
+ r"(?P<label>[A-Za-z_][A-Za-z0-9_]*)))?"
+ )
+ label_match = re.match(
+ label_pattern,
+ body,
+ re.IGNORECASE,
+ )
+ if label_match is None:
+ return "", ""
+ label = label_match.group("qlabel") or label_match.group("label") or ""
+ return label_match.group("var"), label
+
+ def _parse_hop(self, text: str | None) -> Tuple[int, int]:
+ if not text:
+ return (-1, -1)
+ values = text.strip("{}").split(",")
+ if len(values) == 1:
+ value = int(values[0])
+ return (value, value)
+ lower = int(values[0]) if values[0] else -1
+ upper = int(values[1]) if values[1] else -1
+ return (lower, upper)
+
+ def _parse_where(self, text: str) -> List[CompareExpression]:
+ expressions = []
+ for part in re.split(r"\s+AND\s+", text, flags=re.IGNORECASE):
+ match = self.COMPARE_RE.match(part.strip())
+ if match is None:
+ continue
+ expressions.append(
+ CompareExpression(
+ symbolic_name=match.group("var"),
+ property=match.group("quoted_prop") or match.group("prop") or "",
+ comparison_type={
+ "=": "equal",
+ "<>": "neq",
+ "<": "less",
+ ">": "greater",
+ "<=": "leq",
+ ">=": "geq",
+ }[match.group("op")],
+ comparison_value=match.group("value").strip(),
+ )
+ )
+ return expressions
+
+ def _parse_columns(self, text: str) -> List[ReturnItem]:
+ items = []
+ for item in self._split_top_level(text):
+ match = re.match(r"(?P<expr>.+?)\s+AS\s+(?P<alias>[A-Za-z_][A-Za-z0-9_]*)$", item)
+ if match is None:
+ continue
+ expr = match.group("expr").strip()
+ alias = match.group("alias")
+ function_name = ""
+ function_match = re.match(r"(?P<func>[A-Za-z_][A-Za-z0-9_]*)\((?P<inner>.+)\)", expr)
+ if function_match:
+ function_name = function_match.group("func")
+ expr = function_match.group("inner").strip()
+ prop_pattern = (
+ r"(?P<var>[A-Za-z_][A-Za-z0-9_]*)"
+ r"(?:\.(?:\"(?P<qprop>[^\"]+)\"|"
+ r"(?P<prop>[A-Za-z_][A-Za-z0-9_]*)))?"
+ )
+ prop_match = re.match(
+ prop_pattern,
+ expr,
+ )
+ if prop_match:
+ items.append(
+ ReturnItem(
+ symbolic_name=prop_match.group("var"),
+ property=prop_match.group("qprop") or prop_match.group("prop") or "",
+ alias=alias,
+ function_name=function_name,
+ )
+ )
+ return items
+
+ def _parse_outer_modifiers(self, text: str) -> Tuple[List[SortItem], int, int]:
+ sort_items: List[SortItem] = []
+ skip = -1
+ limit = -1
+
+ order_match = re.search(
+ r"\bORDER\s+BY\s+(?P<body>.*?)(?=\bOFFSET\b|\bFETCH\b|$)",
+ text,
+ re.IGNORECASE | re.DOTALL,
+ )
+ if order_match:
+ for item in self._split_top_level(order_match.group("body").strip()):
+ sort_match = re.match(
+ r"(?P<alias>[A-Za-z_][A-Za-z0-9_]*)(?:\s+(?P<order>ASC|DESC))?$",
+ item.strip(),
+ re.IGNORECASE,
+ )
+ if sort_match:
+ sort_items.append(
+ SortItem(
+ symbolic_name=sort_match.group("alias"),
+ property="",
+ order=(sort_match.group("order") or "").upper(),
+ )
+ )
+
+ offset_match = re.search(r"\bOFFSET\s+(?P<value>\d+)\s+ROWS\b", text, re.IGNORECASE)
+ if offset_match:
+ skip = int(offset_match.group("value"))
+
+ fetch_match = re.search(
+ r"\bFETCH\s+FIRST\s+(?P<value>\d+)\s+ROWS\s+ONLY\b",
+ text,
+ re.IGNORECASE,
+ )
+ if fetch_match:
+ limit = int(fetch_match.group("value"))
+
+ return sort_items, skip, limit
+
+ def _split_top_level(self, text: str) -> List[str]:
+ parts = []
+ current = []
+ paren_depth = 0
+ bracket_depth = 0
+ in_single = False
+ in_double = False
+ for char in text:
+ if char == "'" and not in_double:
+ in_single = not in_single
+ elif char == '"' and not in_single:
+ in_double = not in_double
+ elif not in_single and not in_double:
+ if char == "(":
+ paren_depth += 1
+ elif char == ")":
+ paren_depth -= 1
+ elif char == "[":
+ bracket_depth += 1
+ elif char == "]":
+ bracket_depth -= 1
+ elif char == "," and paren_depth == 0 and bracket_depth == 0:
+ part = "".join(current).strip()
+ if part:
+ parts.append(part)
+ current = []
+ continue
+ current.append(char)
+ tail = "".join(current).strip()
+ if tail:
+ parts.append(tail)
+ return parts
diff --git a/app/impl/oracle_sqlpgq/db_client/__init__.py b/app/impl/oracle_sqlpgq/db_client/__init__.py
new file mode 100644
index 0000000..e6c3cf5
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/db_client/__init__.py
@@ -0,0 +1,2 @@
+"""Oracle SQL/PGQ database clients."""
+
diff --git a/app/impl/oracle_sqlpgq/db_client/oracle_db_client.py b/app/impl/oracle_sqlpgq/db_client/oracle_db_client.py
new file mode 100644
index 0000000..711e444
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/db_client/oracle_db_client.py
@@ -0,0 +1,125 @@
+from typing import Any, Dict, Iterable, List
+
+try:
+ import oracledb
+except ImportError: # pragma: no cover - exercised only when dependency is absent.
+ oracledb = None
+
+from app.core.validator.db_client import DB_Client, QueryResult, QueryStatus
+from app.impl.oracle_sqlpgq.utils.sqlpgq import split_sql_statements
+
+
+class OracleDBClient(DB_Client):
+ """Oracle Database client for SQL/PGQ validation and corpus execution."""
+
+ def __init__(self, db_client_params: Dict[str, Any]):
+ self.db_client_params = db_client_params
+ self.connection = self.create_client(db_client_params)
+ self.client = self.connection
+
+ def create_client(self, db_client_params: Dict[str, Any]):
+ if oracledb is None:
+ print("Failed to create OracleDBClient: install the 'oracledb' package first.")
+ return None
+
+ try:
+ connection = oracledb.connect(
+ user=db_client_params.get("user") or db_client_params.get("username"),
+ password=db_client_params.get("password"),
+ dsn=db_client_params.get("dsn"),
+ config_dir=db_client_params.get("config_dir"),
+ wallet_location=db_client_params.get("wallet_location"),
+ wallet_password=db_client_params.get("wallet_password"),
+ )
+ connection.autocommit = db_client_params.get("autocommit", True)
+ with connection.cursor() as cursor:
+ cursor.execute("SELECT 1 FROM dual")
+ cursor.fetchone()
+ print("Successfully created OracleDBClient.")
+ return connection
+ except Exception as exc:
+ print(f"Failed to create OracleDBClient: {exc}")
+ return None
+
+ def execute_query(
+ self,
+ query: str,
+ fetch_limit: int = 0,
+ call_timeout_ms: int = 0,
+ ) -> QueryResult:
+ if not self.connection:
+ return QueryResult(
+ status_code=QueryStatus.SERVER_ERROR,
+ error="Oracle connection is not initialized or connection failed.",
+ )
+
+ previous_call_timeout = getattr(self.connection, "call_timeout", 0)
+ try:
+ if call_timeout_ms > 0:
+ self.connection.call_timeout = call_timeout_ms
+ with self.connection.cursor() as cursor:
+ cursor.execute(query)
+ if cursor.description:
+ columns = [col[0] for col in cursor.description]
+ rows = cursor.fetchmany(fetch_limit) if fetch_limit > 0 else cursor.fetchall()
+ data = [dict(zip(columns, row, strict=False)) for row in rows]
+ if not data:
+ return QueryResult(status_code=QueryStatus.NO_RECORD, data=[])
+ return QueryResult(status_code=QueryStatus.SUCCESS, data=data)
+ return QueryResult(
+ status_code=QueryStatus.SUCCESS,
+ data={"rowcount": cursor.rowcount},
+ )
+ except Exception as exc:
+ error = str(exc)
+ status_code = (
+ QueryStatus.CLIENT_ERROR
+ if self._looks_like_client_error(error)
+ else QueryStatus.SERVER_ERROR
+ )
+ return QueryResult(status_code=status_code, error=error)
+ finally:
+ if call_timeout_ms > 0 and self.connection:
+ self.connection.call_timeout = previous_call_timeout
+
+ def execute_script(self, script: str) -> List[QueryResult]:
+ return [self.execute_query(statement) for statement in split_sql_statements(script)]
+
+ def executemany(self, statement: str, rows: Iterable[Iterable[Any]]) -> QueryResult:
+ if not self.connection:
+ return QueryResult(
+ status_code=QueryStatus.SERVER_ERROR,
+ error="Oracle connection is not initialized or connection failed.",
+ )
+ try:
+ with self.connection.cursor() as cursor:
+ cursor.executemany(statement, rows)
+ return QueryResult(
+ status_code=QueryStatus.SUCCESS,
+ data={"rowcount": cursor.rowcount},
+ )
+ except Exception as exc:
+ error = str(exc)
+ status_code = (
+ QueryStatus.CLIENT_ERROR
+ if self._looks_like_client_error(error)
+ else QueryStatus.SERVER_ERROR
+ )
+ return QueryResult(status_code=status_code, error=error)
+
+ def close(self) -> None:
+ if self.connection:
+ self.connection.close()
+ self.connection = None
+ self.client = None
+ print("Oracle connection closed.")
+
+ def _looks_like_client_error(self, error: str) -> bool:
+ lower = error.lower()
+ return any(token in lower for token in ["syntax", "ora-009", "ora-040", "invalid"])
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
diff --git a/app/impl/oracle_sqlpgq/generator/__init__.py b/app/impl/oracle_sqlpgq/generator/__init__.py
new file mode 100644
index 0000000..f93605c
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/generator/__init__.py
@@ -0,0 +1,2 @@
+"""Oracle SQL/PGQ corpus generation helpers."""
+
diff --git a/app/impl/oracle_sqlpgq/generator/corpus_combiner.py b/app/impl/oracle_sqlpgq/generator/corpus_combiner.py
new file mode 100644
index 0000000..6d0ac32
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/generator/corpus_combiner.py
@@ -0,0 +1,174 @@
+import json
+from pathlib import Path
+from typing import Any
+
+from app.core.validator.db_client import QueryStatus
+from app.core.validator.validator import CorpusValidator
+
+
+class OracleSqlPgqCorpusCombiner:
+ """Combine Oracle SQL/PGQ corpus files into a normalized dataset."""
+
+ def __init__(self, backend: str = "oracle_sqlpgq"):
+ self.backend = backend
+
+ def combine_files(
+ self,
+ input_paths: list[str | Path],
+ *,
+ split: bool = False,
+ validator: CorpusValidator | None = None,
+ result_preview_chars: int = 500,
+ ) -> list[dict[str, Any]]:
+ records: list[dict[str, Any]] = []
+ for path in input_paths:
+ source_path = Path(path)
+ if not source_path.exists():
+ continue
+ with open(source_path, encoding="utf-8") as file:
+ payload = json.load(file)
+ records.extend(self._normalize_records(payload, source_path))
+ records = self._deduplicate(records)
+ if validator is not None:
+ self.validate_records(
+ records,
+ validator=validator,
+ result_preview_chars=result_preview_chars,
+ )
+ if split:
+ self._assign_splits(records)
+ return records
+
+ def write_combined(
+ self,
+ input_paths: list[str | Path],
+ output_path: str | Path,
+ *,
+ split: bool = False,
+ validator: CorpusValidator | None = None,
+ result_preview_chars: int = 500,
+ ) -> list[dict[str, Any]]:
+ records = self.combine_files(
+ input_paths,
+ split=split,
+ validator=validator,
+ result_preview_chars=result_preview_chars,
+ )
+ output = Path(output_path)
+ output.parent.mkdir(parents=True, exist_ok=True)
+ with open(output, "w", encoding="utf-8") as file:
+ json.dump(records, file, indent=2, ensure_ascii=False)
+ return records
+
+ def validate_records(
+ self,
+ records: list[dict[str, Any]],
+ *,
+ validator: CorpusValidator,
+ result_preview_chars: int = 500,
+ ) -> dict[str, int]:
+ client = validator._get_client()
+ if client is None:
+ raise RuntimeError("CorpusValidator has no database client.")
+
+ summary = {"passed": 0, "failed": 0}
+ for record in records:
+ result = client.execute_query(record["query"])
+ record["validation_status_code"] = result.status_code
+ if result.status_code == QueryStatus.SUCCESS:
+ record["validation"] = "passed"
+ record["validation_error"] = ""
+ record["result"] = self._preview(result.data, result_preview_chars)
+ summary["passed"] += 1
+ else:
+ record["validation"] = "failed"
+ record["validation_error"] = self._preview(result.error, result_preview_chars)
+ record["result"] = ""
+ summary["failed"] += 1
+ return summary
+
+ def _normalize_records(
+ self, payload: Any, source_path: Path
+ ) -> list[dict[str, Any]]:
+ if isinstance(payload, dict):
+ payload = [payload]
+ if not isinstance(payload, list):
+ return []
+
+ normalized = []
+ for index, item in enumerate(payload):
+ if not isinstance(item, dict):
+ continue
+ question = str(item.get("question") or "").strip()
+ query = str(item.get("query") or item.get("oracle_sqlpgq") or "").strip()
+ if not question and item.get("labels"):
+ question = self._question_from_labels(item["labels"])
+ if not question or not query:
+ continue
+ normalized.append(
+ {
+ "id": self._record_id(source_path, index),
+ "backend": self.backend,
+ "question": question,
+ "query": query,
+ "source_file": str(source_path),
+ "source_index": index,
+ "category": item.get("category") or self._infer_category(query),
+ "template_id": item.get("template_id", ""),
+ "labels": item.get("labels", []),
+ "validation": item.get("validation", "not_run"),
+ "result": item.get("result", ""),
+ }
+ )
+ return normalized
+
+ def _deduplicate(self, records: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ seen = set()
+ deduped = []
+ for record in records:
+ key = (
+ " ".join(record["question"].lower().split()),
+ " ".join(record["query"].lower().split()),
+ )
+ if key in seen:
+ continue
+ seen.add(key)
+ record["id"] = f"oracle_sqlpgq_{len(deduped) + 1:06d}"
+ deduped.append(record)
+ return deduped
+
+ def _assign_splits(self, records: list[dict[str, Any]]) -> None:
+ total = len(records)
+ train_cutoff = int(total * 0.8)
+ dev_cutoff = int(total * 0.9)
+ for index, record in enumerate(records):
+ if index < train_cutoff:
+ record["split"] = "train"
+ elif index < dev_cutoff:
+ record["split"] = "dev"
+ else:
+ record["split"] = "test"
+
+ def _record_id(self, source_path: Path, index: int) -> str:
+ return f"{source_path.stem}_{index + 1}"
+
+ def _infer_category(self, query: str) -> str:
+ normalized = " ".join(query.upper().split())
+ if "GROUP BY" in normalized or "COUNT(" in normalized or "AVG(" in normalized:
+ return "aggregation"
+ if "UNION" in normalized or " JOIN " in normalized:
+ return "relational_sql_composition"
+ if "->{1," in normalized or "ONE ROW PER STEP" in normalized:
+ return "path_query"
+ if "EDGE_ID(" in normalized or "VERTEX_ID(" in normalized:
+ return "element_identity"
+ return "graph_traversal"
+
+ def _question_from_labels(self, labels: list[str]) -> str:
+ return "Show Oracle SQL/PGQ results for " + " to ".join(str(label) for label in labels) + "."
+
+ def _preview(self, value: Any, max_chars: int) -> str:
+ text = str(value)
+ if len(text) <= max_chars:
+ return text
+ return text[:max_chars] + "..."
diff --git a/app/impl/oracle_sqlpgq/generator/query_generalizer.py b/app/impl/oracle_sqlpgq/generator/query_generalizer.py
new file mode 100644
index 0000000..cfa4526
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/generator/query_generalizer.py
@@ -0,0 +1,155 @@
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from app.core.clauses.clause import Clause
+from app.core.clauses.match_clause import EdgePattern, MatchClause, NodePattern, PathPattern
+from app.core.clauses.return_clause import ReturnBody, ReturnClause, ReturnItem
+from app.impl.oracle_sqlpgq.translator.oracle_sqlpgq_query_translator import (
+ OracleSqlPgqQueryTranslator,
+)
+
+
+@dataclass(frozen=True)
+class OracleGeneralizedQuery:
+ query: str
+ source_pattern_length: int
+ labels: list[str]
+
+ def to_dict(self) -> dict[str, Any]:
+ return {
+ "query": self.query,
+ "source_pattern_length": self.source_pattern_length,
+ "labels": self.labels,
+ }
+
+
+class OracleSqlPgqQueryGeneralizer:
+ """Generalize Graph-IL path patterns over an Oracle SQL/PGQ manifest."""
+
+ def __init__(self, manifest: dict[str, Any], graph_name: str | None = None):
+ self.manifest = manifest
+ self.graph_name = graph_name or manifest.get("graph_name") or "GRAPH"
+ self.vertices = list(manifest.get("vertices") or [])
+ self.edges = list(manifest.get("edges") or [])
+ self.vertex_by_label = {vertex["label"]: vertex for vertex in self.vertices}
+ self.translator = OracleSqlPgqQueryTranslator(graph_name=self.graph_name)
+
+ @classmethod
+ def from_file(
+ cls, manifest_path: str | Path, graph_name: str | None = None
+ ) -> "OracleSqlPgqQueryGeneralizer":
+ with open(manifest_path, encoding="utf-8") as file:
+ manifest = json.load(file)
+ return cls(manifest, graph_name=graph_name)
+
+ def generalize(
+ self, query_pattern: list[Clause], target_size: int | None = None
+ ) -> list[OracleGeneralizedQuery]:
+ path_length = self._path_length(query_pattern)
+ if path_length <= 0:
+ return []
+
+ generalized = []
+ for label_path in self._schema_paths(path_length):
+ graph_il = self._build_query_pattern(label_path, query_pattern)
+ generalized.append(
+ OracleGeneralizedQuery(
+ query=self.translator.translate(graph_il),
+ source_pattern_length=path_length,
+ labels=label_path,
+ )
+ )
+ if target_size is not None and len(generalized) >= target_size:
+ break
+ return generalized
+
+ def generalize_dicts(
+ self, query_pattern: list[Clause], target_size: int | None = None
+ ) -> list[dict[str, Any]]:
+ return [item.to_dict() for item in self.generalize(query_pattern, target_size)]
+
+ def _path_length(self, query_pattern: list[Clause]) -> int:
+ for clause in query_pattern:
+ if isinstance(clause, MatchClause):
+ path_pattern = clause.path_pattern
+ if isinstance(path_pattern, list):
+ path_pattern = path_pattern[0]
+ return len(path_pattern.edge_pattern_list)
+ return 0
+
+ def _schema_paths(self, path_length: int) -> list[list[str]]:
+ paths = []
+
+ def walk(current_label: str, remaining: int, labels: list[str]) -> None:
+ if remaining == 0:
+ paths.append(labels)
+ return
+ for edge in self.edges:
+ if edge["src"] != current_label:
+ continue
+ walk(edge["dst"], remaining - 1, labels + [edge["label"], edge["dst"]])
+
+ for vertex in self.vertices:
+ walk(vertex["label"], path_length, [vertex["label"]])
+ return paths
+
+ def _build_query_pattern(
+ self, label_path: list[str], source_pattern: list[Clause]
+ ) -> list[Clause]:
+ source_hops = self._source_hop_ranges(source_pattern)
+ nodes = []
+ edges = []
+ for index in range(0, len(label_path), 2):
+ nodes.append(NodePattern(f"n{len(nodes) + 1}", label_path[index], []))
+ for index in range(1, len(label_path), 2):
+ edge_number = len(edges) + 1
+ hop_range = source_hops[edge_number - 1] if edge_number <= len(source_hops) else (-1, -1)
+ edges.append(
+ EdgePattern(
+ symbolic_name=f"e{edge_number}",
+ label=label_path[index],
+ property_maps=[],
+ direction="right",
+ hop_range=hop_range,
+ )
+ )
+
+ final_node = nodes[-1]
+ final_property = self._display_property(final_node.label)
+ return [
+ MatchClause(PathPattern(nodes, edges)),
+ ReturnClause(
+ ReturnBody(
+ return_item_list=[
+ ReturnItem(
+ symbolic_name=final_node.symbolic_name,
+ property=final_property,
+ alias=f"{final_node.label}_{final_property}",
+ )
+ ],
+ sort_item_list=[],
+ limit=20,
+ )
+ ),
+ ]
+
+ def _source_hop_ranges(self, query_pattern: list[Clause]) -> list[tuple[int, int]]:
+ for clause in query_pattern:
+ if isinstance(clause, MatchClause):
+ path_pattern = clause.path_pattern
+ if isinstance(path_pattern, list):
+ path_pattern = path_pattern[0]
+ return [edge.hop_range for edge in path_pattern.edge_pattern_list]
+ return []
+
+ def _display_property(self, label: str) -> str:
+ vertex = self.vertex_by_label[label]
+ columns = list(vertex.get("columns") or [])
+ for column in columns:
+ type_name = str(column.get("type", "")).upper()
+ if "CHAR" in type_name or "CLOB" in type_name:
+ return column["name"]
+ return columns[0]["name"]
+
diff --git a/app/impl/oracle_sqlpgq/generator/template_instantiator.py b/app/impl/oracle_sqlpgq/generator/template_instantiator.py
new file mode 100644
index 0000000..2752eb1
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/generator/template_instantiator.py
@@ -0,0 +1,343 @@
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Iterable
+
+from app.impl.oracle_sqlpgq.utils.sqlpgq import OracleNameSanitizer, validate_graph_table_query
+
+
+INTERNAL_EDGE_COLUMNS = {"EDGE_ID", "SRC_ID", "DST_ID"}
+TEXT_TYPES = {"STRING", "TEXT", "VARCHAR2"}
+NUMERIC_TYPES = {
+ "BOOL",
+ "BOOLEAN",
+ "INT8",
+ "INT16",
+ "INT32",
+ "INT64",
+ "INTEGER",
+ "LONG",
+ "FLOAT",
+ "DOUBLE",
+ "NUMBER",
+}
+
+
+@dataclass(frozen=True)
+class OracleTemplateCorpusPair:
+ question: str
+ query: str
+ template_id: str
+ category: str
+ labels: list[str]
+
+ def to_dict(self, include_metadata: bool = True) -> dict[str, Any]:
+ item: dict[str, Any] = {"question": self.question, "query": self.query}
+ if include_metadata:
+ item.update(
+ {
+ "template_id": self.template_id,
+ "category": self.category,
+ "labels": self.labels,
+ }
+ )
+ return item
+
+
+class OracleSqlPgqTemplateInstantiator:
+ """Instantiate deterministic Oracle SQL/PGQ corpus templates from a manifest.
+
+ The manifest is the JSON emitted by OracleSqlPgqSchemaParser. The generated
+ queries intentionally avoid literal predicates, so they are schema-driven and
+ can be validated against any graph instance created from the manifest.
+ """
+
+ def __init__(self, manifest: dict[str, Any], graph_name: str | None = None):
+ self.manifest = manifest
+ self.graph_name = graph_name or manifest.get("graph_name") or "GRAPH"
+ self.vertices = list(manifest.get("vertices") or [])
+ self.edges = list(manifest.get("edges") or [])
+ self.vertex_by_label = {vertex["label"]: vertex for vertex in self.vertices}
+
+ @classmethod
+ def from_file(
+ cls, manifest_path: str | Path, graph_name: str | None = None
+ ) -> "OracleSqlPgqTemplateInstantiator":
+ with open(manifest_path, encoding="utf-8") as file:
+ manifest = json.load(file)
+ return cls(manifest, graph_name=graph_name)
+
+ def generate(self, target_size: int | None = None) -> list[OracleTemplateCorpusPair]:
+ pairs: list[OracleTemplateCorpusPair] = []
+ pairs.extend(self._one_hop_templates())
+ pairs.extend(self._edge_identity_templates())
+ pairs.extend(self._aggregate_templates())
+ pairs.extend(self._two_hop_templates())
+ pairs.extend(self._bounded_path_templates())
+
+ valid_pairs = [pair for pair in pairs if validate_graph_table_query(pair.query)]
+ if target_size is not None:
+ return valid_pairs[:target_size]
+ return valid_pairs
+
+ def generate_dicts(
+ self, target_size: int | None = None, include_metadata: bool = True
+ ) -> list[dict[str, Any]]:
+ return [
+ pair.to_dict(include_metadata=include_metadata)
+ for pair in self.generate(target_size=target_size)
+ ]
+
+ def _one_hop_templates(self) -> list[OracleTemplateCorpusPair]:
+ pairs = []
+ for index, edge in enumerate(self.edges, start=1):
+ src = self.vertex_by_label.get(edge["src"])
+ dst = self.vertex_by_label.get(edge["dst"])
+ if src is None or dst is None:
+ continue
+
+ src_prop = self._display_property(src)
+ dst_prop = self._display_property(dst)
+ edge_prop = self._display_property(edge, exclude=INTERNAL_EDGE_COLUMNS)
+ columns = [
+ self._property_column("src", src_prop, f"source_{edge['src']}_{src_prop['name']}"),
+ self._property_column("dst", dst_prop, f"target_{edge['dst']}_{dst_prop['name']}"),
+ ]
+ if edge_prop is not None:
+ columns.append(
+ self._property_column("rel", edge_prop, f"{edge['label']}_{edge_prop['name']}")
+ )
+ query = self._graph_table_query(
+ match=(
+ f'(src IS {self._q(edge["src"])})'
+ f'-[rel IS {self._q(edge["label"])}]->'
+ f'(dst IS {self._q(edge["dst"])})'
+ ),
+ columns=columns,
+ suffix="FETCH FIRST 20 ROWS ONLY",
+ )
+ pairs.append(
+ OracleTemplateCorpusPair(
+ question=(
+ f"Show sample {edge['src']} to {edge['dst']} relationships "
+ f"through {edge['label']}."
+ ),
+ query=query,
+ template_id=f"one_hop_{index}",
+ category="one_hop_traversal",
+ labels=[edge["src"], edge["label"], edge["dst"]],
+ )
+ )
+ return pairs
+
+ def _edge_identity_templates(self) -> list[OracleTemplateCorpusPair]:
+ pairs = []
+ for index, edge in enumerate(self.edges, start=1):
+ query = self._graph_table_query(
+ match=(
+ f'(src IS {self._q(edge["src"])})'
+ f'-[rel IS {self._q(edge["label"])}]->'
+ f'(dst IS {self._q(edge["dst"])})'
+ ),
+ columns=[
+ "VERTEX_ID(src) AS source_vertex_id",
+ "EDGE_ID(rel) AS edge_id",
+ "VERTEX_ID(dst) AS target_vertex_id",
+ ],
+ suffix="FETCH FIRST 20 ROWS ONLY",
+ )
+ pairs.append(
+ OracleTemplateCorpusPair(
+ question=(
+ f"List graph element identifiers for {edge['label']} edges "
+ f"from {edge['src']} to {edge['dst']}."
+ ),
+ query=query,
+ template_id=f"edge_identity_{index}",
+ category="element_identity",
+ labels=[edge["src"], edge["label"], edge["dst"]],
+ )
+ )
+ return pairs
+
+    def _aggregate_templates(self) -> list[OracleTemplateCorpusPair]:
+        """Build one "count relationships per target" template per edge.
+
+        Edges whose destination label is missing from ``self.vertex_by_label``
+        are skipped. The grouping key is the destination vertex's display
+        property, and results are ordered by descending count, then by key.
+        """
+        templates = []
+        for position, edge in enumerate(self.edges, start=1):
+            target_vertex = self.vertex_by_label.get(edge["dst"])
+            if target_vertex is None:
+                continue
+            group_prop = self._display_property(target_vertex)
+            group_alias = self._alias(f"{edge['dst']}_{group_prop['name']}")
+            hop_pattern = (
+                f"(src IS {self._q(edge['src'])})"
+                f"-[rel IS {self._q(edge['label'])}]->"
+                f"(dst IS {self._q(edge['dst'])})"
+            )
+            projected = [
+                self._property_column("dst", group_prop, group_alias),
+                "EDGE_ID(rel) AS edge_id",
+            ]
+            # Aggregation happens in the outer SELECT over the GRAPH_TABLE rows.
+            outer_select = f"SELECT gt.{group_alias}, COUNT(gt.edge_id) AS relationship_count"
+            outer_tail = (
+                f"GROUP BY gt.{group_alias} "
+                f"ORDER BY relationship_count DESC, gt.{group_alias} "
+                "FETCH FIRST 20 ROWS ONLY"
+            )
+            sql = self._graph_table_query(
+                match=hop_pattern,
+                columns=projected,
+                select_prefix=outer_select,
+                group_order_suffix=outer_tail,
+            )
+            question = (
+                f"Count {edge['label']} relationships grouped by each "
+                f"{edge['dst']} {group_prop['name']}."
+            )
+            templates.append(
+                OracleTemplateCorpusPair(
+                    question=question,
+                    query=sql,
+                    template_id=f"aggregate_by_target_{position}",
+                    category="aggregation",
+                    labels=[edge["src"], edge["label"], edge["dst"]],
+                )
+            )
+        return templates
+
+    def _two_hop_templates(self) -> list[OracleTemplateCorpusPair]:
+        """Build a two-hop traversal template for every chainable edge pair.
+
+        Two edges chain when the first edge's destination label equals the
+        second edge's source label. Pairs whose start or end vertex is not in
+        ``self.vertex_by_label`` are skipped but still consume a template index.
+        """
+        chained = [
+            (first, second)
+            for first in self.edges
+            for second in self.edges
+            if first["dst"] == second["src"]
+        ]
+
+        templates = []
+        for position, (first, second) in enumerate(chained, start=1):
+            start_vertex = self.vertex_by_label.get(first["src"])
+            end_vertex = self.vertex_by_label.get(second["dst"])
+            if start_vertex is None or end_vertex is None:
+                continue
+            start_prop = self._display_property(start_vertex)
+            end_prop = self._display_property(end_vertex)
+            path_pattern = (
+                f"(start_node IS {self._q(first['src'])})"
+                f"-[first_edge IS {self._q(first['label'])}]->"
+                f"(middle_node IS {self._q(first['dst'])})"
+                f"-[second_edge IS {self._q(second['label'])}]->"
+                f"(end_node IS {self._q(second['dst'])})"
+            )
+            start_column = self._property_column(
+                "start_node", start_prop, f"start_{first['src']}_{start_prop['name']}"
+            )
+            end_column = self._property_column(
+                "end_node", end_prop, f"end_{second['dst']}_{end_prop['name']}"
+            )
+            sql = self._graph_table_query(
+                match=path_pattern,
+                columns=[start_column, end_column],
+                suffix="FETCH FIRST 20 ROWS ONLY",
+            )
+            question = (
+                f"Show two-hop paths from {first['src']} through "
+                f"{first['label']} and {second['label']} to {second['dst']}."
+            )
+            templates.append(
+                OracleTemplateCorpusPair(
+                    question=question,
+                    query=sql,
+                    template_id=f"two_hop_{position}",
+                    category="two_hop_traversal",
+                    labels=[
+                        first["src"],
+                        first["label"],
+                        first["dst"],
+                        second["label"],
+                        second["dst"],
+                    ],
+                )
+            )
+        return templates
+
+    def _bounded_path_templates(self) -> list[OracleTemplateCorpusPair]:
+        """Build a 1..3-hop reachability template for each self-referencing edge.
+
+        Only edges whose source and destination labels are equal are used.
+        Edges whose vertex label is missing from ``self.vertex_by_label`` are
+        skipped.
+        """
+        templates = []
+        self_loops = [edge for edge in self.edges if edge["src"] == edge["dst"]]
+        for position, edge in enumerate(self_loops, start=1):
+            vertex = self.vertex_by_label.get(edge["src"])
+            if vertex is None:
+                continue
+            shown_prop = self._display_property(vertex)
+            # "{{1,3}}" renders as the literal quantifier "{1,3}".
+            path_pattern = (
+                f"(start_node IS {self._q(edge['src'])})"
+                f"-[path_edge IS {self._q(edge['label'])}]->{{1,3}}"
+                f"(end_node IS {self._q(edge['dst'])})"
+            )
+            projected = [
+                "MATCHNUM() AS match_number",
+                "VERTEX_ID(start_node) AS source_vertex_id",
+                "VERTEX_ID(end_node) AS target_vertex_id",
+                self._property_column(
+                    "end_node", shown_prop, f"{edge['dst']}_{shown_prop['name']}"
+                ),
+            ]
+            sql = self._graph_table_query(
+                match=path_pattern,
+                columns=projected,
+                suffix="FETCH FIRST 20 ROWS ONLY",
+            )
+            question = (
+                f"Find {edge['dst']} vertices reachable through one to three "
+                f"{edge['label']} hops."
+            )
+            templates.append(
+                OracleTemplateCorpusPair(
+                    question=question,
+                    query=sql,
+                    template_id=f"bounded_path_{position}",
+                    category="bounded_path",
+                    labels=[edge["src"], edge["label"], edge["dst"]],
+                )
+            )
+        return templates
+
+    def _graph_table_query(
+        self,
+        *,
+        match: str,
+        columns: Iterable[str],
+        suffix: str = "",
+        select_prefix: str = "SELECT *",
+        group_order_suffix: str = "",
+    ) -> str:
+        """Assemble a ``GRAPH_TABLE`` query over ``self.graph_name``.
+
+        Args:
+            match: The MATCH pattern text.
+            columns: Expressions for the COLUMNS clause, joined with ", ".
+            suffix: Trailing SQL, used only when ``group_order_suffix`` is empty.
+            select_prefix: The outer SELECT list.
+            group_order_suffix: Trailing SQL that takes precedence over ``suffix``.
+
+        Returns:
+            The query text, with the chosen trailer appended after one space.
+        """
+        column_list = ", ".join(columns)
+        base_query = (
+            f"{select_prefix} FROM GRAPH_TABLE ("
+            f"{self._q(self.graph_name)} MATCH {match} "
+            f"COLUMNS ({column_list})) gt"
+        )
+        trailer = group_order_suffix or suffix
+        return f"{base_query} {trailer}" if trailer else base_query
+
+    def _display_property(
+        self, item: dict[str, Any], exclude: set[str] | None = None
+    ) -> dict[str, Any]:
+        """Pick the column of ``item`` that is most suitable for display.
+
+        Preference order: the first column whose type family is in
+        ``TEXT_TYPES``, then the first in ``NUMERIC_TYPES``, then simply the
+        first remaining column.
+
+        Raises:
+            ValueError: If no column is left after applying ``exclude``.
+        """
+        excluded = exclude or set()
+        candidates = [
+            column
+            for column in (item.get("columns") or [])
+            if column.get("name") not in excluded
+        ]
+        if not candidates:
+            raise ValueError(f"No display properties found for {item.get('label')}")
+
+        for preferred_family in (TEXT_TYPES, NUMERIC_TYPES):
+            for candidate in candidates:
+                if self._type_family(candidate) in preferred_family:
+                    return candidate
+        return candidates[0]
+
+    def _property_column(self, variable: str, prop: dict[str, Any], alias: str) -> str:
+        """Render ``<variable>.<quoted property> AS <alias>`` for a COLUMNS clause."""
+        quoted_property = self._q(prop["name"])
+        column_alias = self._alias(alias)
+        return f"{variable}.{quoted_property} AS {column_alias}"
+
+    def _type_family(self, prop: dict[str, Any]) -> str:
+        """Return the upper-cased type name of ``prop`` without any "(...)" part."""
+        declared_type = str(prop.get("type", ""))
+        family, _, _ = declared_type.partition("(")
+        return family.upper()
+
+    def _q(self, name: str) -> str:
+        """Return ``name`` as rendered by ``OracleNameSanitizer.quote``."""
+        return OracleNameSanitizer.quote(name)
+
+    def _alias(self, name: str) -> str:
+        """Return ``name`` as rendered by ``OracleNameSanitizer.alias``."""
+        return OracleNameSanitizer.alias(name)
diff --git a/app/impl/oracle_sqlpgq/schema/__init__.py b/app/impl/oracle_sqlpgq/schema/__init__.py
new file mode 100644
index 0000000..ccc5b79
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/schema/__init__.py
@@ -0,0 +1,2 @@
+"""Oracle SQL/PGQ schema support."""
+
diff --git a/app/impl/oracle_sqlpgq/schema/schema_parser.py b/app/impl/oracle_sqlpgq/schema/schema_parser.py
new file mode 100644
index 0000000..20533b9
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/schema/schema_parser.py
@@ -0,0 +1,447 @@
+import json
+from pathlib import Path
+from typing import Any, Dict, List
+
+from app.core.schema.edge import Edge
+from app.core.schema.node import Node
+from app.core.schema.schema_graph import SchemaGraph
+from app.core.schema.schema_parser import SchemaParser
+from app.impl.oracle_sqlpgq.utils.sqlpgq import (
+ OracleNameSanitizer,
+ OracleTypeMapper,
+ property_list,
+)
+
+
+class OracleSqlPgqSchemaParser(SchemaParser):
+ """Parse and emit Oracle SQL Property Graph artifacts.
+
+ Oracle SQL property graphs are metadata over relational vertex and edge
+ tables. This adapter therefore emits table DDL first, then a
+ CREATE PROPERTY GRAPH statement over those tables.
+ """
+
+    def __init__(
+        self,
+        db_id: str = "",
+        instance_path: str = "",
+        enforced: bool = True,
+        include_foreign_keys: bool = True,
+        promote_mixed_property_types: bool = False,
+    ):
+        """Configure the parser and, if a path is given, load its schema.
+
+        Args:
+            db_id: Graph identifier; an empty value falls back to "graph".
+            instance_path: Schema file, or a directory searched by ``_load``.
+            enforced: Whether the graph DDL gets ``OPTIONS (ENFORCED MODE)``.
+            include_foreign_keys: Whether edge tables get FOREIGN KEY constraints.
+            promote_mixed_property_types: Whether ``build_manifest`` widens
+                same-named properties with differing types to VARCHAR2(4000).
+        """
+        # Emission options.
+        self.enforced = enforced
+        self.include_foreign_keys = include_foreign_keys
+        self.promote_mixed_property_types = promote_mixed_property_types
+
+        # Identity and input location.
+        self.db_id = db_id if db_id else "graph"
+        self.instance_path = None
+        if instance_path:
+            self.instance_path = Path(instance_path)
+
+        # Start from an empty graph; _load replaces it when a schema file exists.
+        self.schema_graph = SchemaGraph(self.db_id)
+        self._raw_config: Dict[str, Any] | List[Dict[str, Any]] | None = None
+        if self.instance_path:
+            self._load()
+
+    def _load(self) -> None:
+        """Read the first available schema JSON and rebuild ``self.schema_graph``.
+
+        When ``self.instance_path`` is a file it is used directly. Otherwise
+        four well-known file names are tried in order beneath it. If nothing
+        is found an error line is printed and the current graph is kept.
+        """
+        location = self.instance_path
+        if location is None:
+            return
+
+        if location.is_file():
+            candidates = [location]
+        else:
+            well_known = (
+                "oracle_schema.json",
+                "schema.json",
+                "import_config.json",
+                "example_schema.json",
+            )
+            candidates = [location / file_name for file_name in well_known]
+
+        schema_file = next((entry for entry in candidates if entry.exists()), None)
+        if schema_file is None:
+            print(f"[ERROR] Oracle SQL/PGQ schema file not found under {location}")
+            return
+
+        with open(schema_file, encoding="utf-8") as handle:
+            self._raw_config = json.load(handle)
+        self.schema_graph = self._schema_graph_from_json(self._raw_config)
+
+    def _schema_graph_from_json(self, data: Dict[str, Any] | List[Dict[str, Any]]) -> SchemaGraph:
+        """Build a ``SchemaGraph`` from parsed schema JSON.
+
+        ``data`` is either a list of schema items or a dict holding that list
+        under "schema". All VERTEX items are added before any EDGE item.
+        """
+        if isinstance(data, dict):
+            entries = data.get("schema", [])
+        else:
+            entries = data
+
+        result = SchemaGraph(self.db_id)
+
+        vertex_entries = (entry for entry in entries if entry.get("type") == "VERTEX")
+        for entry in vertex_entries:
+            node = Node(
+                label=entry["label"],
+                properties=entry.get("properties", []),
+                primary=entry.get("primary", "id"),
+            )
+            result.add_node(node)
+
+        edge_entries = (entry for entry in entries if entry.get("type") == "EDGE")
+        for entry in edge_entries:
+            endpoints = [list(pair) for pair in entry.get("constraints", [])]
+            edge = Edge(
+                label=entry["label"],
+                properties=entry.get("properties", []),
+                src_dst_list=endpoints,
+            )
+            result.add_edge(edge)
+
+        return result
+
+    def get_schema_graph(self) -> SchemaGraph:
+        """Return the schema graph currently held by this parser."""
+        return self.schema_graph
+
+    def save_schema_to_file(
+        self, output_dir, schema_graph: SchemaGraph, domain: str, subdomain: str
+    ):
+        """Write the manifest, table DDL, graph DDL and loader script to disk.
+
+        File names are prefixed with "<domain>_<subdomain>" (spaces become
+        underscores). When that prefix is empty, the graph's ``db_id`` or
+        "oracle_graph" is used instead.
+
+        Returns:
+            The path of the written manifest JSON, as a string.
+        """
+        target_dir = Path(output_dir)
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+        prefix = "_".join(
+            [domain.replace(" ", "_"), subdomain.replace(" ", "_")]
+        ).strip("_")
+        if not prefix:
+            prefix = schema_graph.db_id or "oracle_graph"
+
+        manifest = self.build_manifest(
+            schema_graph, OracleNameSanitizer.clean(prefix, fallback="GRAPH")
+        )
+
+        manifest_file = target_dir / f"{prefix}_oracle_schema.json"
+        tables_file = target_dir / f"{prefix}_oracle_tables.sql"
+        graph_file = target_dir / f"{prefix}_oracle_property_graph.sql"
+        loader_file = target_dir / f"{prefix}_oracle_loader.py"
+
+        with open(manifest_file, "w", encoding="utf-8") as handle:
+            json.dump(manifest, handle, indent=2, ensure_ascii=False)
+        tables_file.write_text(manifest["table_ddl"], encoding="utf-8")
+        graph_file.write_text(manifest["property_graph_ddl"], encoding="utf-8")
+        loader_file.write_text(self._loader_template(manifest), encoding="utf-8")
+
+        return str(manifest_file)
+
+    def build_manifest(self, schema_graph: SchemaGraph, graph_name: str | None = None) -> Dict:
+        """Assemble the manifest dict describing tables, graph DDL and load order.
+
+        Args:
+            schema_graph: The graph whose nodes and edges are emitted.
+            graph_name: Property-graph name; defaults to the cleaned ``db_id``.
+
+        Returns:
+            A dict with backend tag, graph name, schema JSON, vertex and edge
+            metadata, both DDL texts, and the load order (vertex tables first).
+        """
+        if not graph_name:
+            graph_name = OracleNameSanitizer.clean(schema_graph.db_id, fallback="GRAPH")
+
+        vertices = self._vertex_metadata(schema_graph)
+        edges = self._edge_metadata(schema_graph, vertices)
+        self._assign_unique_graph_labels(vertices, edges)
+        if self.promote_mixed_property_types:
+            self._promote_mixed_property_types(vertices, edges)
+
+        manifest = {"backend": "oracle_sqlpgq", "graph_name": graph_name}
+        # DDL is rendered before the schema JSON, as in the original ordering.
+        tables_sql = self._build_table_ddl(vertices, edges)
+        graph_sql = self._build_property_graph_ddl(graph_name, vertices, edges)
+        manifest["schema"] = self._schema_json(schema_graph)
+        manifest["vertices"] = vertices
+        manifest["edges"] = edges
+        manifest["table_ddl"] = tables_sql
+        manifest["property_graph_ddl"] = graph_sql
+        manifest["load_order"] = [entry["table"] for entry in [*vertices, *edges]]
+        return manifest
+
+    def _schema_json(self, schema_graph: SchemaGraph) -> List[Dict[str, Any]]:
+        """Serialize the graph into schema items, all vertices before all edges."""
+        vertex_items: List[Dict[str, Any]] = [
+            {
+                "type": "VERTEX",
+                "label": label,
+                "properties": node.properties,
+                "primary": node.primary,
+            }
+            for label, node in schema_graph.node_dict.items()
+        ]
+        edge_items: List[Dict[str, Any]] = [
+            {
+                "type": "EDGE",
+                "label": label,
+                "properties": edge.properties,
+                "constraints": edge.src_dst_list,
+            }
+            for label, edge in schema_graph.edge_dict.items()
+        ]
+        return vertex_items + edge_items
+
+    def _vertex_metadata(self, schema_graph: SchemaGraph) -> List[Dict[str, Any]]:
+        """Describe one relational vertex table per node label.
+
+        Every property becomes a column and only the primary-key column is
+        NOT NULL. If the node does not declare its primary key as a property,
+        a VARCHAR2(4000) key column is put first.
+        """
+        described = []
+        for label, node in schema_graph.node_dict.items():
+            key_name = node.primary or "id"
+            columns = [
+                {
+                    "name": prop_name,
+                    "quoted": OracleNameSanitizer.quote(prop_name, fallback="COL"),
+                    "type": OracleTypeMapper.to_oracle(prop_type),
+                    "nullable": prop_name != key_name,
+                }
+                for prop_name, prop_type in property_list(node.properties)
+            ]
+            key_declared = any(column["name"] == key_name for column in columns)
+            if not key_declared:
+                synthetic_key = {
+                    "name": key_name,
+                    "quoted": OracleNameSanitizer.quote(key_name, fallback="ID"),
+                    "type": "VARCHAR2(4000)",
+                    "nullable": False,
+                }
+                columns.insert(0, synthetic_key)
+            quoted_label = OracleNameSanitizer.quote(label, fallback="VERTEX")
+            described.append(
+                {
+                    "label": label,
+                    "graph_label": label,
+                    "table": OracleNameSanitizer.clean(label, fallback="VERTEX"),
+                    "quoted_table": quoted_label,
+                    "quoted_label": OracleNameSanitizer.quote(label, fallback="VERTEX"),
+                    "primary": key_name,
+                    "quoted_primary": OracleNameSanitizer.quote(key_name, fallback="ID"),
+                    "columns": columns,
+                }
+            )
+        return described
+
+    def _edge_metadata(
+        self, schema_graph: SchemaGraph, vertex_meta: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """Describe one relational edge table per (source, edge label, target).
+
+        Endpoint pairs that reference a vertex label missing from
+        ``vertex_meta`` are skipped. Each table starts with the structural
+        columns EDGE_ID, SRC_ID and DST_ID, followed by the edge's properties
+        as nullable columns.
+        """
+
+        def structural_column(name: str, sql_type: str) -> Dict[str, Any]:
+            # Structural columns are always NOT NULL.
+            return {
+                "name": name,
+                "quoted": OracleNameSanitizer.quote(name),
+                "type": sql_type,
+                "nullable": False,
+            }
+
+        vertices = {entry["label"]: entry for entry in vertex_meta}
+        described = []
+        for label, edge in schema_graph.edge_dict.items():
+            for src, dst in edge.src_dst_list:
+                source = vertices.get(src)
+                target = vertices.get(dst)
+                if source is None or target is None:
+                    continue
+                table_name = OracleNameSanitizer.clean(f"{src}_{label}_{dst}", fallback="EDGE")
+                columns = [
+                    structural_column("EDGE_ID", "NUMBER GENERATED BY DEFAULT AS IDENTITY"),
+                    structural_column("SRC_ID", self._primary_type(source)),
+                    structural_column("DST_ID", self._primary_type(target)),
+                ]
+                columns.extend(
+                    {
+                        "name": prop_name,
+                        "quoted": OracleNameSanitizer.quote(prop_name, fallback="COL"),
+                        "type": OracleTypeMapper.to_oracle(prop_type),
+                        "nullable": True,
+                    }
+                    for prop_name, prop_type in property_list(edge.properties)
+                )
+                described.append(
+                    {
+                        "label": label,
+                        "graph_label": label,
+                        "table": table_name,
+                        "quoted_table": OracleNameSanitizer.quote(table_name, fallback="EDGE"),
+                        "quoted_label": OracleNameSanitizer.quote(label, fallback="EDGE"),
+                        "primary": "EDGE_ID",
+                        "quoted_primary": OracleNameSanitizer.quote("EDGE_ID"),
+                        "src": src,
+                        "dst": dst,
+                        "src_table": source["quoted_table"],
+                        "dst_table": target["quoted_table"],
+                        "src_pk": source["quoted_primary"],
+                        "dst_pk": target["quoted_primary"],
+                        "columns": columns,
+                    }
+                )
+        return described
+
+    def _assign_unique_graph_labels(
+        self,
+        vertex_meta: List[Dict[str, Any]],
+        edge_meta: List[Dict[str, Any]],
+    ) -> None:
+        """Give every vertex and edge table a distinct graph label, in place.
+
+        Vertices are processed first and derive their label from "label".
+        Edges derive theirs from "table". All labels share one namespace.
+        """
+        taken = set()
+        plan = (
+            (vertex_meta, "label", "VERTEX"),
+            (edge_meta, "table", "EDGE"),
+        )
+        for entries, source_key, fallback in plan:
+            for entry in entries:
+                unique = self._unique_label(entry[source_key], taken)
+                entry["graph_label"] = unique
+                entry["quoted_label"] = OracleNameSanitizer.quote(unique, fallback=fallback)
+
+    def _unique_label(self, preferred: str, used_labels: set[str]) -> str:
+        """Return a label not yet in ``used_labels`` and record it there.
+
+        Uniqueness is case-insensitive. On collision a numeric suffix starting
+        at "_2" is appended, and the base is trimmed so the result stays within
+        128 characters.
+        """
+        base = OracleNameSanitizer.clean(preferred, fallback="LABEL")
+        candidate = base
+        next_number = 2
+        while candidate.upper() in used_labels:
+            tail = f"_{next_number}"
+            candidate = base[: 128 - len(tail)] + tail
+            next_number += 1
+        used_labels.add(candidate.upper())
+        return candidate
+
+    def _primary_type(self, vertex_meta: Dict[str, Any]) -> str:
+        """Return the SQL type of the vertex's primary-key column.
+
+        Falls back to "VARCHAR2(4000)" when no column carries the primary name.
+        """
+        key_name = vertex_meta["primary"]
+        matching_types = (
+            column["type"] for column in vertex_meta["columns"] if column["name"] == key_name
+        )
+        return next(matching_types, "VARCHAR2(4000)")
+
+    def _build_table_ddl(
+        self, vertex_meta: List[Dict[str, Any]], edge_meta: List[Dict[str, Any]]
+    ) -> str:
+        """Render CREATE TABLE statements, vertex tables first, blank-line separated."""
+        vertex_statements = [self._create_table(vertex) for vertex in vertex_meta]
+        edge_statements = [self._create_table(edge, is_edge=True) for edge in edge_meta]
+        return "\n\n".join(vertex_statements + edge_statements) + "\n"
+
+    def _create_table(self, item: Dict[str, Any], is_edge: bool = False) -> str:
+        """Render one CREATE TABLE statement from vertex or edge metadata.
+
+        The body lists every column, then the primary-key constraint. For edge
+        tables it also adds SRC_ID/DST_ID foreign keys when
+        ``self.include_foreign_keys`` is set.
+        """
+        body = [
+            f" {column['quoted']} {column['type']}" + ("" if column["nullable"] else " NOT NULL")
+            for column in item["columns"]
+        ]
+
+        pk_name = OracleNameSanitizer.quote(f"{item['table']}_PK", fallback="PK")
+        body.append(f" CONSTRAINT {pk_name} PRIMARY KEY ({item['quoted_primary']})")
+
+        if is_edge and self.include_foreign_keys:
+            src_fk = OracleNameSanitizer.quote(f"{item['table']}_SRC_FK", fallback="SRC_FK")
+            dst_fk = OracleNameSanitizer.quote(f"{item['table']}_DST_FK", fallback="DST_FK")
+            src_column = OracleNameSanitizer.quote("SRC_ID")
+            body.append(
+                f" CONSTRAINT {src_fk} FOREIGN KEY ({src_column}) "
+                f"REFERENCES {item['src_table']} ({item['src_pk']})"
+            )
+            dst_column = OracleNameSanitizer.quote("DST_ID")
+            body.append(
+                f" CONSTRAINT {dst_fk} FOREIGN KEY ({dst_column}) "
+                f"REFERENCES {item['dst_table']} ({item['dst_pk']})"
+            )
+
+        header = f"CREATE TABLE {item['quoted_table']} (\n"
+        return header + ",\n".join(body) + "\n);"
+
+    def _build_property_graph_ddl(
+        self, graph_name: str, vertex_meta: List[Dict[str, Any]], edge_meta: List[Dict[str, Any]]
+    ) -> str:
+        """Render the CREATE OR REPLACE PROPERTY GRAPH statement.
+
+        Args:
+            graph_name: Unquoted graph name; quoted here via the sanitizer.
+            vertex_meta: Vertex table metadata as built by ``_vertex_metadata``.
+            edge_meta: Edge table metadata as built by ``_edge_metadata``.
+
+        Returns:
+            The DDL text ending in ";\\n". The EDGE TABLES clause is emitted
+            only when there is at least one edge table, because an empty
+            ``EDGE TABLES ( )`` list is not valid DDL and the clause itself is
+            optional. ``OPTIONS (ENFORCED MODE)`` is appended when
+            ``self.enforced`` is true.
+        """
+        vertex_defs = []
+        for vertex in vertex_meta:
+            props = self._properties_clause(vertex, exclude=set())
+            vertex_defs.append(
+                f" {vertex['quoted_table']}\n"
+                f" KEY ({vertex['quoted_primary']})\n"
+                f" LABEL {vertex['quoted_label']}{props}"
+            )
+        edge_defs = []
+        for edge in edge_meta:
+            # Structural columns are keys, not user-facing properties.
+            props = self._properties_clause(edge, exclude={"EDGE_ID", "SRC_ID", "DST_ID"})
+            edge_defs.append(
+                f" {edge['quoted_table']}\n"
+                f" KEY ({edge['quoted_primary']})\n"
+                f" SOURCE KEY ({OracleNameSanitizer.quote('SRC_ID')}) "
+                f"REFERENCES {edge['src_table']} ({edge['src_pk']})\n"
+                f" DESTINATION KEY ({OracleNameSanitizer.quote('DST_ID')}) "
+                f"REFERENCES {edge['dst_table']} ({edge['dst_pk']})\n"
+                f" LABEL {edge['quoted_label']}{props}"
+            )
+        quoted_graph = OracleNameSanitizer.quote(graph_name, fallback="GRAPH")
+        sections = [
+            f"CREATE OR REPLACE PROPERTY GRAPH {quoted_graph}\n",
+            " VERTEX TABLES (\n" + ",\n".join(vertex_defs) + "\n )",
+        ]
+        if edge_defs:
+            sections.append("\n EDGE TABLES (\n" + ",\n".join(edge_defs) + "\n )")
+        if self.enforced:
+            sections.append("\n OPTIONS (ENFORCED MODE)")
+        return "".join(sections) + ";\n"
+
+    def _properties_clause(self, item: Dict[str, Any], exclude: set[str]) -> str:
+        """Render a PROPERTIES (...) clause for the columns not in ``exclude``.
+
+        Returns an empty string when every column is excluded.
+        """
+        exposed = []
+        for column in item["columns"]:
+            if column["name"] in exclude:
+                continue
+            exposed.append(column["quoted"])
+        if exposed:
+            return "\n PROPERTIES (" + ", ".join(exposed) + ")"
+        return ""
+
+    def _promote_mixed_property_types(
+        self,
+        vertex_meta: List[Dict[str, Any]],
+        edge_meta: List[Dict[str, Any]],
+    ) -> None:
+        """Widen same-named properties with differing types to VARCHAR2(4000).
+
+        Both metadata lists are modified in place. Columns named EDGE_ID,
+        SRC_ID or DST_ID are ignored when looking for type conflicts.
+
+        When a promoted name is also a vertex's primary key, the SRC_ID/DST_ID
+        columns of edges referencing that vertex are widened as well. Those
+        columns were typed from the primary key earlier; without this step the
+        foreign keys and graph keys would reference a column of another type.
+        """
+        structural = {"EDGE_ID", "SRC_ID", "DST_ID"}
+        promoted_type = "VARCHAR2(4000)"
+        items = [*vertex_meta, *edge_meta]
+
+        types_by_name: Dict[str, set[str]] = {}
+        for item in items:
+            for column in item["columns"]:
+                if column["name"] in structural:
+                    continue
+                types_by_name.setdefault(column["name"], set()).add(column["type"].upper())
+
+        mixed_names = {name for name, types in types_by_name.items() if len(types) > 1}
+        if not mixed_names:
+            return
+
+        for item in items:
+            for column in item["columns"]:
+                if column["name"] in mixed_names:
+                    column["type"] = promoted_type
+
+        # Keep edge endpoint columns in step with any primary key just promoted.
+        promoted_vertices = {
+            vertex.get("label") for vertex in vertex_meta if vertex.get("primary") in mixed_names
+        }
+        if not promoted_vertices:
+            return
+        endpoint_of = {"SRC_ID": "src", "DST_ID": "dst"}
+        for edge in edge_meta:
+            for column in edge["columns"]:
+                endpoint_key = endpoint_of.get(column["name"])
+                if endpoint_key and edge.get(endpoint_key) in promoted_vertices:
+                    column["type"] = promoted_type
+
+    def _loader_template(self, manifest: Dict[str, Any]) -> str:
+        """Render the source of a standalone CSV loader script.
+
+        Args:
+            manifest: Manifest dict; only "load_order", "vertices" and "edges"
+                are embedded in the script.
+
+        Returns:
+            Python source text. The embedded ``MANIFEST`` is written as a
+            Python literal rather than JSON text, because JSON ``true``,
+            ``false`` and ``null`` are undefined names in Python and every
+            column carries a boolean "nullable" flag. The generated script
+            closes its connection even when loading fails.
+        """
+        # Function-scope import: pprint is only needed to render the literal.
+        from pprint import pformat
+
+        manifest_literal = pformat(
+            {
+                "load_order": manifest["load_order"],
+                "vertices": manifest["vertices"],
+                "edges": manifest["edges"],
+            },
+            indent=2,
+            width=100,
+            sort_dicts=False,
+        )
+        return f'''"""CSV loader template for Oracle SQL/PGQ artifacts.
+
+Set ORACLE_DSN, ORACLE_USER, ORACLE_PASSWORD and place CSV files next to this script.
+CSV filenames are expected to match table names with .csv suffix.
+"""
+
+import csv
+import os
+from pathlib import Path
+
+import oracledb
+
+
+MANIFEST = {manifest_literal}
+
+
+def main():
+    connection = oracledb.connect(
+        user=os.environ["ORACLE_USER"],
+        password=os.environ["ORACLE_PASSWORD"],
+        dsn=os.environ["ORACLE_DSN"],
+    )
+    connection.autocommit = False
+    root = Path(__file__).resolve().parent
+    try:
+        with connection.cursor() as cursor:
+            for table_name in MANIFEST["load_order"]:
+                csv_path = root / f"{{table_name}}.csv"
+                if not csv_path.exists():
+                    print(f"Skipping missing {{csv_path}}")
+                    continue
+                with open(csv_path, newline="", encoding="utf-8") as file:
+                    reader = csv.DictReader(file)
+                    rows = list(reader)
+                if not rows:
+                    continue
+                columns = list(rows[0].keys())
+                column_sql = ", ".join(f'"{{column}}"' for column in columns)
+                bind_sql = ", ".join(f":{{index + 1}}" for index in range(len(columns)))
+                sql = f'INSERT INTO "{{table_name}}" ({{column_sql}}) VALUES ({{bind_sql}})'
+                cursor.executemany(sql, [[row.get(column) for column in columns] for row in rows])
+                print(f"Loaded {{len(rows)}} rows into {{table_name}}")
+        connection.commit()
+    finally:
+        connection.close()
+
+
+if __name__ == "__main__":
+    main()
+'''
diff --git a/app/impl/oracle_sqlpgq/translator/__init__.py b/app/impl/oracle_sqlpgq/translator/__init__.py
new file mode 100644
index 0000000..b687f72
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/translator/__init__.py
@@ -0,0 +1,2 @@
+"""Oracle SQL/PGQ query translators."""
+
diff --git a/app/impl/oracle_sqlpgq/translator/oracle_sqlpgq_query_translator.py b/app/impl/oracle_sqlpgq/translator/oracle_sqlpgq_query_translator.py
new file mode 100644
index 0000000..badac69
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/translator/oracle_sqlpgq_query_translator.py
@@ -0,0 +1,5485 @@
+from collections import Counter
+import re
+from typing import Dict, Iterable, List, Tuple
+
+from app.core.clauses.clause import Clause
+from app.core.clauses.match_clause import EdgePattern, MatchClause, NodePattern, PathPattern
+from app.core.clauses.return_clause import ReturnBody, ReturnClause, ReturnItem, SortItem
+from app.core.clauses.where_clause import CompareExpression, WhereClause
+from app.core.clauses.with_clause import WithClause
+from app.core.translator.query_translator import QueryTranslator
+from app.impl.oracle_sqlpgq.utils.sqlpgq import (
+ OracleNameSanitizer,
+ validate_graph_table_query,
+ validate_property_graph_ddl,
+)
+
+
+class OracleSqlPgqQueryTranslator(QueryTranslator):
+ """Translate the framework's graph-query IR into Oracle SQL/PGQ."""
+
+ AGGREGATE_FUNCTIONS = {"COUNT", "AVG", "SUM", "MIN", "MAX"}
+
+    def __init__(
+        self,
+        graph_name: str = "GRAPH",
+        node_label_map: Dict[str, List[str]] | None = None,
+        edge_label_map: Dict[str, List[str]] | None = None,
+        property_type_map: Dict[str, Dict[str, str]] | None = None,
+        node_primary_key_map: Dict[str, str] | None = None,
+        edge_primary_key_map: Dict[str, str] | None = None,
+        strict_property_validation: bool = False,
+    ):
+        """Store the translation configuration and initialise per-query state.
+
+        Missing maps default to empty dicts. The two label maps are passed
+        through ``_normalize_label_map`` before being stored.
+        """
+        # Static configuration.
+        self.graph_name = graph_name
+        self.node_label_map = self._normalize_label_map(node_label_map if node_label_map else {})
+        self.edge_label_map = self._normalize_label_map(edge_label_map if edge_label_map else {})
+        self.property_type_map = property_type_map if property_type_map else {}
+        self.node_primary_key_map = node_primary_key_map if node_primary_key_map else {}
+        self.edge_primary_key_map = edge_primary_key_map if edge_primary_key_map else {}
+        self.strict_property_validation = strict_property_validation
+
+        # Variable bookkeeping.
+        self._var_kinds: Dict[str, str] = {}
+        self._var_sql_names: Dict[str, str] = {}
+        self._var_labels: Dict[str, str] = {}
+        self._var_property_redirects: Dict[Tuple[str, str], str] = {}
+        self._reserved_variables: set[str] = set()
+
+        # Path-variable bookkeeping.
+        self._path_variables: Dict[str, List[Tuple[str, str]]] = {}
+        self._path_variable_has_quantifier: Dict[str, bool] = {}
+        self._path_variable_quantified_edges: Dict[str, set[str]] = {}
+
+        # Predicates collected from patterns, and counters for generated names.
+        self._pattern_where_expressions: List[str] = []
+        self._auto_node_index = 0
+        self._auto_edge_index = 0
+
+    def _normalize_label_map(self, label_map: Dict[str, List[str]]) -> Dict[str, List[str]]:
+        """Return a copy of ``label_map`` that also answers to key variants.
+
+        Each source label is stored as given. Its lower-cased form, its
+        sanitized form and the lower-cased sanitized form are added only when
+        those keys are not present yet.
+        """
+        expanded: Dict[str, List[str]] = {}
+        for source, targets in label_map.items():
+            target_labels = list(targets or [])
+            expanded[source] = target_labels
+            expanded.setdefault(str(source).lower(), target_labels)
+            sanitized = OracleNameSanitizer.clean(source)
+            for variant in (sanitized, sanitized.lower()):
+                expanded.setdefault(variant, target_labels)
+        return expanded
+
+    def grammar_check(self, query: str) -> bool:
+        """Validate ``query`` as property-graph DDL or as a GRAPH_TABLE query.
+
+        Text containing both "CREATE" and "PROPERTY GRAPH" (case-insensitive,
+        whitespace-collapsed) goes to the DDL validator; everything else goes
+        to the query validator.
+        """
+        collapsed = " ".join(str(query or "").split()).upper()
+        looks_like_ddl = "CREATE" in collapsed and "PROPERTY GRAPH" in collapsed
+        if not looks_like_ddl:
+            return validate_graph_table_query(query)
+        return validate_property_graph_ddl(query)
+
+    def translate(self, query_pattern: List[Clause]) -> str:
+        """Translate a clause list into one Oracle SQL/PGQ query.
+
+        Queries with an OPTIONAL MATCH or a WITH clause are delegated to
+        their dedicated translators. Everything else is rendered as a single
+        ``GRAPH_TABLE`` query, or as a CTE-based query when the WHERE or
+        RETURN clause contains pattern predicates.
+
+        Args:
+            query_pattern: The clauses of the query, in source order.
+
+        Returns:
+            The SQL text.
+
+        Raises:
+            ValueError: If the query has no MATCH clause.
+        """
+        self._reset()
+        if any(isinstance(clause, MatchClause) and clause.optional for clause in query_pattern):
+            return self._translate_optional_match(query_pattern)
+        if any(isinstance(clause, WithClause) for clause in query_pattern):
+            return self._translate_supported_with(query_pattern)
+
+        # No WithClause can reach this point, so only MATCH, WHERE and RETURN
+        # clauses need collecting.
+        match_clauses: List[MatchClause] = []
+        where_expressions: List[CompareExpression] = []
+        return_body: ReturnBody | None = None
+        distinct = False
+
+        for clause in query_pattern:
+            if isinstance(clause, MatchClause):
+                match_clauses.append(clause)
+            elif isinstance(clause, WhereClause):
+                where_expressions.extend(self._as_compare_list(clause.compare_expression_list))
+            elif isinstance(clause, ReturnClause):
+                return_body = clause.return_body
+                distinct = clause.distinct
+
+        if not match_clauses:
+            raise ValueError("Oracle SQL/PGQ translation requires at least one MatchClause.")
+
+        match_parts = [self._translate_match_clause(clause) for clause in match_clauses]
+        pattern_predicates, where_expressions = self._extract_pattern_predicates(where_expressions)
+        return_pattern_predicates = self._return_pattern_predicates(return_body)
+        # Pattern predicates in RETURN take precedence over those in WHERE.
+        if return_pattern_predicates:
+            return self._translate_return_pattern_predicate_cte(
+                match_parts,
+                where_expressions,
+                return_body,
+                distinct,
+                return_pattern_predicates,
+            )
+        if pattern_predicates:
+            return self._translate_pattern_predicate_cte(
+                match_parts,
+                where_expressions,
+                return_body,
+                distinct,
+                pattern_predicates,
+            )
+        graph_table_parts = [
+            OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+            "MATCH " + ", ".join(match_parts),
+        ]
+
+        where_parts = self._where_parts(where_expressions)
+        if where_parts:
+            graph_table_parts.append("WHERE " + " AND ".join(where_parts))
+
+        aggregate_query = self._has_aggregate(return_body)
+        hidden_sort_aliases = self._hidden_sort_aliases(return_body, aggregate_query)
+        graph_table_parts.append(
+            f"COLUMNS ({self._translate_columns(return_body, aggregate_query)})"
+        )
+
+        if return_body is not None and distinct and hidden_sort_aliases and not aggregate_query:
+            # DISTINCT with sort keys that are not returned: de-duplicate over
+            # visible plus hidden columns in an inner query, then project only
+            # the visible columns in the outer query.
+            visible_aliases = [
+                OracleNameSanitizer.alias(alias)
+                for alias in self._resolved_return_aliases(return_body)
+            ]
+            hidden_aliases = [OracleNameSanitizer.alias(alias) for alias in hidden_sort_aliases]
+            inner_select = (
+                "SELECT DISTINCT " + ", ".join([*visible_aliases, *hidden_aliases]) + "\n"
+                f"FROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+            )
+            query = (
+                "SELECT "
+                + ", ".join(visible_aliases)
+                + "\nFROM (\n"
+                + self._indent_sql(inner_select)
+                + "\n) distinct_rows"
+            )
+            query += self._outer_group_order_and_paging(return_body, aggregate_query)
+            return query
+
+        select_keyword = self._outer_select(
+            return_body,
+            distinct,
+            aggregate_query,
+            hidden_sort_aliases,
+        )
+        query = f"{select_keyword}\nFROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+
+        if return_body is not None:
+            query += self._outer_group_order_and_paging(return_body, aggregate_query)
+
+        return query
+
+    def _extract_pattern_predicates(
+        self,
+        where_expressions: List[CompareExpression],
+    ) -> Tuple[List[Tuple[bool, str]], List[CompareExpression]]:
+        """Separate pattern predicates from ordinary WHERE expressions.
+
+        Args:
+            where_expressions: Expressions whose ``raw_expression`` text, when
+                present, is scanned for pattern predicates.
+
+        Returns:
+            ``(predicates, residual)``. ``predicates`` holds
+            ``(negated, pattern_text)`` tuples. ``residual`` holds the
+            expressions that are not pattern predicates; the non-pattern
+            remainder of a mixed raw expression is wrapped in a new "raw"
+            ``CompareExpression``.
+        """
+        predicates: List[Tuple[bool, str]] = []
+        residual: List[CompareExpression] = []
+        for expression in where_expressions:
+            raw_expression = getattr(expression, "raw_expression", "")
+            raw_predicates, raw_residual = self._extract_raw_path_predicate_parts(
+                raw_expression or "",
+            )
+            if raw_predicates:
+                predicates.extend(raw_predicates)
+                if raw_residual:
+                    residual.append(CompareExpression("", "", "raw", "", raw_residual))
+                continue
+            # The named groups are required: group("not") and group("pattern")
+            # are read below.
+            match = re.fullmatch(
+                r"\s*(?P<not>NOT\s+)?EXISTS\s*\(\s*(?P<pattern>\(.+\))\s*\)\s*",
+                raw_expression or "",
+                flags=re.IGNORECASE | re.DOTALL,
+            )
+            if match:
+                predicates.append((bool(match.group("not")), match.group("pattern")))
+                continue
+            residual.append(expression)
+        return predicates, residual
+
+    def _extract_raw_path_predicate_parts(
+        self,
+        raw_expression: str,
+    ) -> Tuple[List[Tuple[bool, str]], str]:
+        """Split a raw WHERE text into pattern predicates and the remainder.
+
+        The text is split on top-level ANDs. A part is a pattern predicate
+        when it is ``[NOT] EXISTS(<pattern>)`` or a bare ``[NOT] <path>``,
+        where a path is two or more parenthesised nodes joined by ``<-[..]-``,
+        ``-[..]->`` or ``-[..]-``.
+
+        Returns:
+            ``(predicates, residual)``: a list of ``(negated, pattern_text)``
+            tuples, and the remaining parts re-joined with " AND ".
+        """
+        if not raw_expression:
+            return [], ""
+        # The named groups are required: group("not") and group("pattern")
+        # are read below.
+        exists_regex = r"\s*(?P<not>NOT\s+)?EXISTS\s*\(\s*(?P<pattern>\(.+\))\s*\)\s*"
+        edge_regex = r"(?:(?:<-\[[^\]]*\]-)|(?:-\[[^\]]*\]->)|(?:-\[[^\]]*\]-))"
+        node_regex = r"\([^()]*\)"
+        path_regex = (
+            r"\s*(?P<not>NOT\s+)?(?P<pattern>"
+            + node_regex
+            + r"\s*"
+            + edge_regex
+            + r"\s*"
+            + node_regex
+            + r"(?:\s*"
+            + edge_regex
+            + r"\s*"
+            + node_regex
+            + r")*"
+            r")\s*"
+        )
+        flags = re.IGNORECASE | re.DOTALL
+
+        predicates: List[Tuple[bool, str]] = []
+        residual_parts: List[str] = []
+        for part in self._split_top_level_and(raw_expression):
+            match = re.fullmatch(exists_regex, part, flags=flags) or re.fullmatch(
+                path_regex, part, flags=flags
+            )
+            if match:
+                predicates.append((bool(match.group("not")), match.group("pattern")))
+            else:
+                residual_parts.append(part)
+        return predicates, " AND ".join(residual_parts)
+
+    def _split_top_level_and(self, expression: str) -> List[str]:
+        """Split ``expression`` on AND keywords that are not nested or quoted.
+
+        An AND counts only when it sits outside single or double quotes,
+        outside any ``()``, ``[]`` or ``{}`` nesting, and is a whole word.
+        Underscores count as word characters, so identifiers such as
+        ``brand_and_model`` are left intact.
+
+        Returns:
+            The stripped parts in order. An empty trailing part is dropped.
+        """
+
+        def is_word_char(char: str) -> bool:
+            return char.isalnum() or char == "_"
+
+        parts: List[str] = []
+        start = 0
+        depth = 0
+        open_quote = ""
+        index = 0
+        length = len(expression)
+        while index < length:
+            char = expression[index]
+            if open_quote:
+                # Inside a quoted run only the matching quote is significant.
+                if char == open_quote:
+                    open_quote = ""
+                index += 1
+                continue
+            if char in "'\"":
+                open_quote = char
+                index += 1
+                continue
+            if char in "([{":
+                depth += 1
+            elif char in ")]}":
+                depth -= 1
+            elif (
+                depth == 0
+                and expression[index : index + 3].upper() == "AND"
+                and (index == 0 or not is_word_char(expression[index - 1]))
+                and (index + 3 == length or not is_word_char(expression[index + 3]))
+            ):
+                parts.append(expression[start:index].strip())
+                start = index + 3
+                index += 3
+                continue
+            index += 1
+        tail = expression[start:].strip()
+        if tail:
+            parts.append(tail)
+        return parts
+
+    def _return_pattern_predicates(
+        self,
+        return_body: ReturnBody | None,
+    ) -> Dict[int, Tuple[bool, str]]:
+        """Find RETURN items of the form ``[NOT] EXISTS(<pattern>)``.
+
+        Returns:
+            A dict mapping the item's position in ``return_item_list`` to
+            ``(negated, pattern_text)``. Empty when ``return_body`` is None or
+            no item matches.
+        """
+        if return_body is None:
+            return {}
+        # The named groups are required: group("not") and group("pattern")
+        # are read below.
+        exists_regex = re.compile(
+            r"\s*(?P<not>NOT\s+)?EXISTS\s*\(\s*(?P<pattern>\(.+\))\s*\)\s*",
+            flags=re.IGNORECASE | re.DOTALL,
+        )
+        predicates: Dict[int, Tuple[bool, str]] = {}
+        for index, item in enumerate(return_body.return_item_list):
+            match = exists_regex.fullmatch(item.expression or "")
+            if match:
+                predicates[index] = (bool(match.group("not")), match.group("pattern"))
+        return predicates
+
+    def _translate_optional_match(self, query_pattern: List[Clause]) -> str:
+        """Translate a query containing exactly one trailing OPTIONAL MATCH.
+
+        The OPTIONAL MATCH must follow at least one regular MATCH and be the
+        last MATCH stage. If no WITH precedes it, a synthetic WITH carrying
+        every variable declared so far is inserted in front of it. The result
+        is then handed to ``_translate_supported_with``.
+
+        Raises:
+            ValueError: If the query shape violates any rule above, or no
+                named variable is available to carry forward.
+        """
+        optional_positions = [
+            position
+            for position, clause in enumerate(query_pattern)
+            if isinstance(clause, MatchClause) and clause.optional
+        ]
+        if len(optional_positions) != 1:
+            raise ValueError("Only one OPTIONAL MATCH clause is supported.")
+
+        (position,) = optional_positions
+        if position == 0:
+            raise ValueError("Standalone OPTIONAL MATCH is not supported by Graph IR translation.")
+
+        following = query_pattern[position + 1 :]
+        if any(isinstance(clause, MatchClause) for clause in following):
+            raise ValueError("OPTIONAL MATCH must be the final MATCH stage.")
+
+        preceding = query_pattern[:position]
+        preceding_matches = [clause for clause in preceding if isinstance(clause, MatchClause)]
+        if not preceding_matches:
+            raise ValueError("OPTIONAL MATCH requires a preceding MATCH.")
+
+        # An explicit WITH already defines what is carried forward.
+        if any(isinstance(clause, WithClause) for clause in preceding):
+            return self._translate_supported_with(query_pattern)
+
+        carried = sorted(self._declared_variables_in_match_clauses(preceding_matches))
+        if not carried:
+            raise ValueError("OPTIONAL MATCH requires named variables to carry forward.")
+
+        carry_items = [
+            ReturnItem(
+                symbolic_name=name,
+                property="",
+                alias="",
+                function_name="",
+                expression=name,
+            )
+            for name in carried
+        ]
+        bridge = WithClause(ReturnBody(carry_items, []), [], False)
+        return self._translate_supported_with([*preceding, bridge, *query_pattern[position:]])
+
+    def _translate_return_pattern_predicate_cte(
+        self,
+        match_parts: List[str],
+        where_expressions: List[CompareExpression],
+        return_body: ReturnBody,
+        distinct: bool,
+        return_pattern_predicates: Dict[int, Tuple[bool, str]],
+    ) -> str:
+        """Translate a query whose RETURN list contains pattern predicates.
+
+        The ordinary RETURN items plus the variables the predicates correlate
+        on are projected by a ``base`` CTE. The outer SELECT re-emits the
+        ordinary items by alias and renders each predicate as a
+        ``CASE WHEN ... THEN 1 ELSE 0 END`` column.
+
+        Args:
+            match_parts: Already-translated MATCH pattern texts.
+            where_expressions: Non-pattern WHERE expressions for the base query.
+            return_body: The original RETURN body.
+            distinct: Whether the outer SELECT is DISTINCT.
+            return_pattern_predicates: RETURN item index mapped to
+                ``(negated, pattern_text)``.
+
+        Raises:
+            ValueError: If the RETURN body aggregates, or a predicate shares no
+                variable with the outer MATCH.
+        """
+        if self._has_aggregate(return_body):
+            raise ValueError("Aggregate RETURN pattern predicates are not supported.")
+        # Variables known so far; predicates must reference at least one of them.
+        outer_variables = set(self._var_kinds)
+        parsed_predicates = {}
+        correlated_variables = set()
+        for index, (negated, pattern) in return_pattern_predicates.items():
+            path = self._parse_pattern_predicate_path(pattern)
+            variables = self._path_pattern_variables(path) & outer_variables
+            if not variables:
+                raise ValueError("RETURN pattern predicate requires outer correlation.")
+            correlated_variables.update(variables)
+            parsed_predicates[index] = (negated, path, variables)
+
+        # The base query returns only the non-predicate items.
+        visible_items = [
+            item
+            for index, item in enumerate(return_body.return_item_list)
+            if index not in return_pattern_predicates
+        ]
+        base_return_body = self._return_body_with_join_variables(
+            ReturnBody(
+                visible_items, return_body.sort_item_list, return_body.skip, return_body.limit
+            ),
+            sorted(correlated_variables),
+        )
+        graph_table_parts = [
+            OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+            "MATCH " + ", ".join(match_parts),
+        ]
+        where_parts = self._where_parts(where_expressions)
+        if where_parts:
+            graph_table_parts.append("WHERE " + " AND ".join(where_parts))
+        graph_table_parts.append(f"COLUMNS ({self._translate_columns(base_return_body)})")
+        base_query = f"SELECT *\nFROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+
+        # The outer SELECT keeps the original RETURN order.
+        select_items = []
+        for index, item in enumerate(return_body.return_item_list):
+            alias = OracleNameSanitizer.alias(
+                item.alias
+                or item.property
+                or self._default_expression_alias(item.symbolic_name, item.expression)
+            )
+            if index in parsed_predicates:
+                negated, path, variables = parsed_predicates[index]
+                predicate = self._pattern_predicate_sql(negated, path, variables)
+                select_items.append(f"CASE WHEN {predicate} THEN 1 ELSE 0 END AS {alias}")
+            else:
+                select_items.append(OracleNameSanitizer.alias(alias))
+        keyword = "SELECT DISTINCT" if distinct else "SELECT"
+        query = (
+            "WITH base AS (\n"
+            f"{self._indent_sql(base_query)}\n"
+            ")\n"
+            f"{keyword} " + ", ".join(select_items) + "\n"
+            "FROM base"
+        )
+        query += self._outer_order_and_paging_for_with(return_body, include_group_by=False)
+        return query
+
+ def _translate_pattern_predicate_cte(
+ self,
+ match_parts: List[str],
+ where_expressions: List[CompareExpression],
+ return_body: ReturnBody | None,
+ distinct: bool,
+ pattern_predicates: List[Tuple[bool, str]],
+ ) -> str:
+ """Translate a MATCH filtered by pattern predicates into a CTE query.
+
+ The MATCH/WHERE part becomes a GRAPH_TABLE query inside a ``base`` CTE
+ that also projects the correlated variables. Each pattern predicate is
+ turned into an ``EXISTS``/``NOT EXISTS`` filter (see
+ ``_pattern_predicate_sql``) and ANDed into the outer WHERE clause.
+
+ Args:
+ match_parts: MATCH fragments, joined here with ``", "``.
+ where_expressions: Expressions passed to ``self._where_parts``.
+ return_body: The RETURN body, or ``None``; ``None`` is replaced by an
+ empty ``ReturnBody`` for the base projection.
+ distinct: Forwarded to the outer SELECT builder.
+ pattern_predicates: ``(negated, pattern_text)`` pairs.
+
+ Returns:
+ The SQL text of the resulting query.
+
+ Raises:
+ ValueError: If a predicate pattern shares no variable with
+ ``self._var_kinds``.
+ """
+ outer_variables = set(self._var_kinds)
+ correlated_variables = set()
+ parsed_predicates = []
+ for negated, pattern in pattern_predicates:
+ path = self._parse_pattern_predicate_path(pattern)
+ variables = self._path_pattern_variables(path) & outer_variables
+ if not variables:
+ raise ValueError("Pattern predicate requires correlation to the outer MATCH.")
+ correlated_variables.update(variables)
+ parsed_predicates.append((negated, path, variables))
+
+ aggregate_query = self._has_aggregate(return_body)
+ base_return_body = self._return_body_with_join_variables(
+ return_body or ReturnBody([], []),
+ sorted(correlated_variables),
+ )
+ graph_table_parts = [
+ OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+ "MATCH " + ", ".join(match_parts),
+ ]
+ where_parts = self._where_parts(where_expressions)
+ if where_parts:
+ graph_table_parts.append("WHERE " + " AND ".join(where_parts))
+ graph_table_parts.append(
+ f"COLUMNS ({self._translate_columns(base_return_body, aggregate_query)})"
+ )
+ base_query = f"SELECT *\nFROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+
+ # Aggregate queries use the regular outer SELECT builder; otherwise only
+ # the visible RETURN aliases are selected from the base CTE.
+ final_select = (
+ self._outer_select(return_body, distinct, aggregate_query)
+ if aggregate_query
+ else self._visible_outer_select(return_body, distinct)
+ )
+ query = f"WITH base AS (\n{self._indent_sql(base_query)}\n)\n{final_select}\nFROM base"
+ filters = [
+ self._pattern_predicate_sql(negated, path, variables)
+ for negated, path, variables in parsed_predicates
+ ]
+ if filters:
+ query += "\nWHERE " + " AND ".join(filters)
+ if return_body is not None:
+ query += self._outer_group_order_and_paging(return_body, aggregate_query)
+ return query
+
+ def _visible_outer_select(
+     self,
+     return_body: ReturnBody | None,
+     distinct: bool,
+ ) -> str:
+     """Build the outer SELECT list that exposes only the RETURN aliases.
+
+     Falls back to ``SELECT *`` (or ``SELECT DISTINCT *``) when there is no
+     RETURN body or it has no items.
+     """
+     select_keyword = "SELECT DISTINCT" if distinct else "SELECT"
+     has_items = return_body is not None and bool(return_body.return_item_list)
+     if not has_items:
+         return f"{select_keyword} *"
+     columns = [
+         OracleNameSanitizer.alias(name)
+         for name in self._resolved_return_aliases(return_body)
+     ]
+     return select_keyword + " " + ", ".join(columns)
+
+ def _pattern_predicate_sql(
+ self,
+ negated: bool,
+ path: PathPattern,
+ variables: set[str],
+ ) -> str:
+ """Render a correlated ``EXISTS`` / ``NOT EXISTS`` predicate for ``path``.
+
+ A separate translator configured like this one translates the path into a
+ GRAPH_TABLE subquery (alias ``pp``) projecting ``variables``. The subquery
+ is correlated to the ``base`` relation by equating the element projection
+ alias of every variable.
+
+ Args:
+ negated: Emit ``NOT EXISTS`` when true, ``EXISTS`` otherwise.
+ path: The parsed pattern-predicate path.
+ variables: Variables shared with the outer query; iterated sorted.
+
+ Returns:
+ The SQL predicate text.
+ """
+ # Fresh translator so translating the sub-pattern does not touch this
+ # instance's per-query state.
+ translator = OracleSqlPgqQueryTranslator(
+ graph_name=self.graph_name,
+ node_label_map=self.node_label_map,
+ edge_label_map=self.edge_label_map,
+ property_type_map=self.property_type_map,
+ node_primary_key_map=self.node_primary_key_map,
+ edge_primary_key_map=self.edge_primary_key_map,
+ strict_property_validation=self.strict_property_validation,
+ )
+ match_part = translator._translate_path_pattern(path)
+ inline_where = ""
+ # Conditions collected by the sub-translator while translating the path.
+ if translator._pattern_where_expressions:
+ inline_where = "WHERE " + " AND ".join(translator._pattern_where_expressions) + " "
+ return_body = ReturnBody(
+ [
+ ReturnItem(
+ symbolic_name=variable,
+ property="",
+ alias=translator._element_projection_alias(variable),
+ function_name="",
+ expression=variable,
+ )
+ for variable in sorted(variables)
+ ],
+ [],
+ )
+ subquery = (
+ "SELECT 1\n"
+ "FROM GRAPH_TABLE (\n"
+ f" {OracleNameSanitizer.quote(self.graph_name, fallback='GRAPH')} "
+ f"MATCH {match_part} "
+ f"{inline_where}"
+ f"COLUMNS ({translator._translate_columns(return_body)})\n"
+ ") pp\n"
+ "WHERE "
+ + " AND ".join(
+ f"pp.{translator._element_projection_alias(variable)} = "
+ f"base.{self._element_projection_alias(variable)}"
+ for variable in sorted(variables)
+ )
+ )
+ operator = "NOT EXISTS" if negated else "EXISTS"
+ return f"{operator} (\n{self._indent_sql(subquery)}\n)"
+
+ def _pattern_predicate_relation_sql(
+ self,
+ path: PathPattern,
+ variables: set[str],
+ ) -> str:
+ """Render ``path`` as a standalone ``SELECT DISTINCT *`` relation.
+
+ Uses the same sub-translator setup as ``_pattern_predicate_sql`` but
+ returns an uncorrelated GRAPH_TABLE query (alias ``pp``) projecting
+ ``variables``; ``_translate_with_cte_pattern_predicates`` places it in a
+ CTE and joins it to ``base``.
+
+ Args:
+ path: The parsed pattern-predicate path.
+ variables: Variables to project; iterated in sorted order.
+
+ Returns:
+ The SQL query text.
+ """
+ translator = OracleSqlPgqQueryTranslator(
+ graph_name=self.graph_name,
+ node_label_map=self.node_label_map,
+ edge_label_map=self.edge_label_map,
+ property_type_map=self.property_type_map,
+ node_primary_key_map=self.node_primary_key_map,
+ edge_primary_key_map=self.edge_primary_key_map,
+ strict_property_validation=self.strict_property_validation,
+ )
+ match_part = translator._translate_path_pattern(path)
+ inline_where = ""
+ if translator._pattern_where_expressions:
+ inline_where = "WHERE " + " AND ".join(translator._pattern_where_expressions) + " "
+ return_body = ReturnBody(
+ [
+ ReturnItem(
+ symbolic_name=variable,
+ property="",
+ alias=translator._element_projection_alias(variable),
+ function_name="",
+ expression=variable,
+ )
+ for variable in sorted(variables)
+ ],
+ [],
+ )
+ return (
+ "SELECT DISTINCT *\n"
+ "FROM GRAPH_TABLE (\n"
+ f" {OracleNameSanitizer.quote(self.graph_name, fallback='GRAPH')} "
+ f"MATCH {match_part} "
+ f"{inline_where}"
+ f"COLUMNS ({translator._translate_columns(return_body)})\n"
+ ") pp"
+ )
+
+ def _parse_pattern_predicate_path(self, pattern: str) -> PathPattern:
+     """Parse pattern-predicate text into a ``PathPattern``.
+
+     The text is read as ``node (edge node)*``; parsing errors raised by the
+     node/edge helpers propagate unchanged.
+     """
+     source = self._unwrap_pattern_predicate_path(pattern)
+     first_node, cursor = self._parse_pattern_predicate_node(source, 0)
+     node_patterns: List[NodePattern] = [first_node]
+     edge_patterns: List[EdgePattern] = []
+     # Every further hop is one edge followed by one node.
+     while cursor < len(source):
+         edge_pattern, cursor = self._parse_pattern_predicate_edge(source, cursor)
+         edge_patterns.append(edge_pattern)
+         node_pattern, cursor = self._parse_pattern_predicate_node(source, cursor)
+         node_patterns.append(node_pattern)
+     return PathPattern(node_patterns, edge_patterns, "")
+
+ def _unwrap_pattern_predicate_path(self, pattern: str) -> str:
+     """Strip one redundant pair of parentheses wrapping a whole path.
+
+     ``((a)-[:R]->(b))`` becomes ``(a)-[:R]->(b)``. The outer pair is only
+     removed when the inner text starts with a node followed by an edge;
+     otherwise the trimmed input is returned as is.
+     """
+     candidate = str(pattern or "").strip()
+     is_wrapped = candidate.startswith("((") and candidate.endswith("))")
+     if is_wrapped:
+         unwrapped = candidate[1:-1].strip()
+         starts_with_hop = re.match(
+             r"^\([^()]*\)\s*(?:<-\[[^\]]*\]-|-\[[^\]]*\]->|-\[[^\]]*\]-)",
+             unwrapped,
+             flags=re.DOTALL,
+         )
+         if starts_with_hop:
+             return unwrapped
+     return candidate
+
+ def _parse_pattern_predicate_node(
+     self,
+     text: str,
+     index: int,
+ ) -> Tuple[NodePattern, int]:
+     """Parse one ``( ... )`` node starting at ``index``.
+
+     Returns the node and the index of the next non-space character after
+     the closing parenthesis.
+
+     Raises:
+         ValueError: If no ``(`` is found at the position or the node is not
+             closed.
+     """
+     opening = self._skip_spaces(text, index)
+     at_node_start = opening < len(text) and text[opening] == "("
+     if not at_node_start:
+         raise ValueError("Pattern predicate expected a node pattern.")
+     closing = text.find(")", opening)
+     if closing == -1:
+         raise ValueError("Pattern predicate node is missing a closing parenthesis.")
+     inner = text[opening + 1 : closing].strip()
+     name, label_name, properties = self._parse_pattern_predicate_body(inner)
+     node_pattern = NodePattern(name, label_name, properties)
+     return node_pattern, self._skip_spaces(text, closing + 1)
+
+ def _parse_pattern_predicate_edge(
+     self,
+     text: str,
+     index: int,
+ ) -> Tuple[EdgePattern, int]:
+     """Parse one edge (``<-[..]-``, ``-[..]->`` or ``-[..]-``) at ``index``.
+
+     Returns the edge, with direction ``"left"``, ``"right"`` or ``"both"``,
+     and the index of the next non-space character after the edge.
+
+     Raises:
+         ValueError: If the text at the position is not an edge, the bracket
+             is not closed, or the direction suffix is missing.
+     """
+     cursor = self._skip_spaces(text, index)
+     if text.startswith("<-[", cursor):
+         direction = "left"
+         body_start = cursor + 3
+         close = text.find("]-", body_start)
+         if close == -1:
+             raise ValueError("Pattern predicate edge is missing a closing bracket.")
+         next_index = close + 2
+     elif text.startswith("-[", cursor):
+         body_start = cursor + 2
+         close = text.find("]", body_start)
+         if close == -1:
+             raise ValueError("Pattern predicate edge is missing a closing bracket.")
+         suffix_at = close + 1
+         if text.startswith("->", suffix_at):
+             direction, next_index = "right", close + 3
+         elif text.startswith("-", suffix_at):
+             direction, next_index = "both", close + 2
+         else:
+             raise ValueError("Pattern predicate edge is missing a direction suffix.")
+     else:
+         raise ValueError("Pattern predicate expected an edge pattern.")
+
+     inner = text[body_start:close].strip()
+     name, label_name, properties = self._parse_pattern_predicate_body(inner)
+     edge_pattern = EdgePattern(name, label_name, properties, direction, (-1, -1))
+     return edge_pattern, self._skip_spaces(text, next_index)
+
+ def _parse_pattern_predicate_body(
+     self,
+     body: str,
+ ) -> Tuple[str, str, List[Tuple[str, str]]]:
+     """Split a node/edge body into ``(variable, label, property_maps)``.
+
+     The property map (``{...}``) is extracted first; the remainder is split
+     at the first ``:`` into variable and label. Without a ``:`` the whole
+     remainder is the variable and the label is empty.
+     """
+     remainder, property_maps = self._extract_pattern_predicate_property_map(body)
+     # partition() yields an empty label when there is no ":" at all.
+     raw_variable, _separator, raw_label = remainder.partition(":")
+     cleaned_variable = self._clean_pattern_identifier(raw_variable)
+     cleaned_label = self._clean_pattern_identifier(raw_label)
+     return cleaned_variable, cleaned_label, property_maps
+
+ def _extract_pattern_predicate_property_map(
+     self,
+     body: str,
+ ) -> Tuple[str, List[Tuple[str, str]]]:
+     """Separate an inline ``{key: value, ...}`` map from a pattern body.
+
+     Returns the body with the map removed (trimmed) and the list of
+     ``(property_name, raw_value_text)`` pairs. When there is no ``{`` or the
+     brace is never closed, the body is returned untouched with an empty list.
+     Map entries without a ``:`` are ignored.
+     """
+     brace_open = str(body or "").find("{")
+     brace_close = -1
+     if brace_open != -1:
+         brace_close = self._matching_brace_index(body, brace_open)
+     if brace_close == -1:
+         return body, []
+
+     entries = self._split_top_level_commas(body[brace_open + 1 : brace_close])
+     pairs = [
+         (self._clean_pattern_identifier(key), value.strip())
+         for key, separator, value in (entry.partition(":") for entry in entries)
+         if separator
+     ]
+     remainder = body[:brace_open] + body[brace_close + 1 :]
+     return remainder.strip(), pairs
+
+ def _matching_brace_index(self, text: str, start: int) -> int:
+     """Return the index of the ``}`` that closes the brace opened at ``start``.
+
+     Braces inside single- or double-quoted string literals are ignored, and a
+     backslash inside a literal escapes the next character, so ``'O\\'Brien'``
+     does not end the literal early.
+
+     Args:
+         text: The text to scan.
+         start: Index at which scanning begins (expected to be a ``{``).
+
+     Returns:
+         The index of the matching ``}``, or ``-1`` when the brace is never
+         closed.
+     """
+     depth = 0
+     active_quote = ""  # quote character of the literal being scanned, if any
+     escaped = False
+     for index in range(start, len(text)):
+         char = text[index]
+         if active_quote:
+             if escaped:
+                 escaped = False
+             elif char == "\\":
+                 escaped = True
+             elif char == active_quote:
+                 active_quote = ""
+             continue
+         if char in ("'", '"'):
+             active_quote = char
+         elif char == "{":
+             depth += 1
+         elif char == "}":
+             depth -= 1
+             if depth == 0:
+                 return index
+     return -1
+
+ def _split_top_level_commas(self, text: str) -> List[str]:
+     """Split ``text`` at commas that are outside brackets and string literals.
+
+     Commas nested in ``()``, ``{}`` or ``[]`` or inside single/double quoted
+     literals do not split. A backslash inside a literal escapes the next
+     character, so an escaped quote does not end the literal. Each part is
+     stripped; an empty trailing part is dropped.
+
+     Args:
+         text: The text to split.
+
+     Returns:
+         The list of stripped parts.
+     """
+     parts: List[str] = []
+     part_start = 0
+     depth = 0
+     active_quote = ""  # quote character of the literal being scanned, if any
+     escaped = False
+     for index, char in enumerate(text):
+         if active_quote:
+             if escaped:
+                 escaped = False
+             elif char == "\\":
+                 escaped = True
+             elif char == active_quote:
+                 active_quote = ""
+             continue
+         if char in ("'", '"'):
+             active_quote = char
+         elif char in "({[":
+             depth += 1
+         elif char in ")}]":
+             # Tolerate unbalanced closers instead of going negative.
+             depth = max(depth - 1, 0)
+         elif char == "," and depth == 0:
+             parts.append(text[part_start:index].strip())
+             part_start = index + 1
+     tail = text[part_start:].strip()
+     if tail:
+         parts.append(tail)
+     return parts
+
+ def _path_pattern_variables(self, path: PathPattern) -> set[str]:
+     """Collect the non-empty variable names of all nodes and edges in ``path``."""
+     node_names = {
+         node.symbolic_name for node in path.node_pattern_list if node.symbolic_name
+     }
+     edge_names = {
+         edge.symbolic_name for edge in path.edge_pattern_list if edge.symbolic_name
+     }
+     return node_names | edge_names
+
+ def _clean_pattern_identifier(self, value: str) -> str:
+     """Trim whitespace, then surrounding backticks, then surrounding double quotes."""
+     trimmed = str(value or "").strip()
+     without_backticks = trimmed.strip("`")
+     return without_backticks.strip('"')
+
+ def _skip_spaces(self, text: str, index: int) -> int:
+     """Return the first position at or after ``index`` that is not whitespace."""
+     position = index
+     limit = len(text)
+     while position < limit:
+         if not text[position].isspace():
+             break
+         position += 1
+     return position
+
+ def _translate_supported_with(self, query_pattern: List[Clause]) -> str:
+ """Dispatch a query containing WITH clauses to the matching translation.
+
+ Two WITH clauses go to ``_translate_with_match_then_with_cte`` when a
+ MATCH sits between them, otherwise to ``_translate_two_stage_with_cte``.
+ A single WITH followed by MATCH/WHERE goes to ``_translate_with_match_cte``;
+ a single WITH that aggregates, calls functions, sorts, pages, filters or
+ has pass-through variables goes to ``_translate_with_cte``. Any other
+ single WITH is folded into one GRAPH_TABLE query.
+
+ Args:
+ query_pattern: The clause list of the query.
+
+ Returns:
+ The translated SQL text.
+
+ Raises:
+ ValueError: If the number of WITH clauses is not one or two, there is
+ not exactly one RETURN after the WITH, or no MATCH precedes it.
+ """
+ with_indexes = [
+ index for index, clause in enumerate(query_pattern) if isinstance(clause, WithClause)
+ ]
+ if len(with_indexes) == 2:
+ if any(
+ isinstance(clause, MatchClause)
+ for clause in query_pattern[with_indexes[0] + 1 : with_indexes[1]]
+ ):
+ return self._translate_with_match_then_with_cte(
+ query_pattern,
+ with_indexes,
+ )
+ return self._translate_two_stage_with_cte(query_pattern, with_indexes)
+ # Reached for zero WITH clauses or three and more.
+ if len(with_indexes) != 1:
+ raise ValueError("Only one-stage WITH pipelines are supported.")
+
+ with_index = with_indexes[0]
+ before_with = query_pattern[:with_index]
+ after_with = query_pattern[with_index + 1 :]
+ with_clause = query_pattern[with_index]
+ assert isinstance(with_clause, WithClause)
+
+ if any(isinstance(clause, (MatchClause, WhereClause)) for clause in after_with):
+ return self._translate_with_match_cte(before_with, with_clause, after_with)
+ return_clauses = [clause for clause in after_with if isinstance(clause, ReturnClause)]
+ if len(return_clauses) != 1:
+ raise ValueError("WITH pipeline requires one final RETURN clause.")
+
+ match_clauses = [clause for clause in before_with if isinstance(clause, MatchClause)]
+ where_expressions: List[CompareExpression] = []
+ for clause in before_with:
+ if isinstance(clause, WhereClause):
+ where_expressions.extend(self._as_compare_list(clause.compare_expression_list))
+ if not match_clauses:
+ raise ValueError("WITH translation requires a preceding MATCH.")
+
+ # Any of these features needs a separate SQL stage, so use the CTE form.
+ if (
+ self._has_aggregate(with_clause.return_body)
+ or any(item.function_name for item in with_clause.return_body.return_item_list)
+ or with_clause.return_body.sort_item_list
+ or with_clause.return_body.skip != -1
+ or with_clause.return_body.limit != -1
+ or with_clause.compare_expression_list
+ or self._with_passthrough_variable_names(with_clause.return_body)
+ ):
+ return self._translate_with_cte(
+ match_clauses,
+ where_expressions,
+ with_clause,
+ return_clauses[0],
+ )
+
+ # Plain projection WITH: emit a single GRAPH_TABLE query whose COLUMNS
+ # come from the WITH projection.
+ self._reset()
+ match_parts = [self._translate_match_clause(clause) for clause in match_clauses]
+ graph_table_parts = [
+ OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+ "MATCH " + ", ".join(match_parts),
+ ]
+ where_parts = self._where_parts(where_expressions)
+ if where_parts:
+ graph_table_parts.append("WHERE " + " AND ".join(where_parts))
+ graph_table_parts.append(
+ f"COLUMNS ({self._translate_with_projection_columns(with_clause)})"
+ )
+
+ return_clause = return_clauses[0]
+ select_keyword = self._outer_select_for_with(
+ return_clause.return_body,
+ return_clause.distinct,
+ )
+ query = f"{select_keyword}\nFROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+ query += self._outer_order_and_paging_for_with(return_clause.return_body)
+ return query
+
+ def _translate_with_cte(
+ self,
+ match_clauses: List[MatchClause],
+ where_expressions: List[CompareExpression],
+ with_clause: WithClause,
+ return_clause: ReturnClause,
+ ) -> str:
+ """Translate ``MATCH ... WITH ... RETURN`` into a one-stage CTE query.
+
+ The MATCH/WHERE/WITH part becomes a ``stage_1`` CTE built on GRAPH_TABLE;
+ the final RETURN selects from ``stage_1``, with the WITH clause's filters
+ applied in the outer WHERE. When the WHERE expressions contain pattern
+ predicates, the work is delegated to
+ ``_translate_with_cte_pattern_predicates``.
+
+ Args:
+ match_clauses: MATCH clauses preceding the WITH.
+ where_expressions: WHERE expressions preceding the WITH.
+ with_clause: The WITH clause.
+ return_clause: The final RETURN clause.
+
+ Returns:
+ The translated SQL text.
+ """
+ # Clears per-query translator state before translating the MATCH clauses.
+ self._reset()
+ match_parts = [self._translate_match_clause(clause) for clause in match_clauses]
+ pattern_predicates, where_expressions = self._extract_pattern_predicates(where_expressions)
+ if pattern_predicates:
+ return self._translate_with_cte_pattern_predicates(
+ match_parts,
+ where_expressions,
+ with_clause,
+ return_clause,
+ pattern_predicates,
+ )
+ graph_table_parts = [
+ OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+ "MATCH " + ", ".join(match_parts),
+ ]
+ where_parts = self._where_parts(where_expressions)
+ if where_parts:
+ graph_table_parts.append("WHERE " + " AND ".join(where_parts))
+
+ stage_return_body = self._with_stage_return_body(
+ with_clause.return_body,
+ return_clause.return_body,
+ with_clause,
+ )
+ aggregate_query = self._has_aggregate(stage_return_body)
+ graph_table_parts.append(
+ f"COLUMNS ({self._translate_columns(stage_return_body, aggregate_query)})"
+ )
+ stage_select = self._outer_select(
+ stage_return_body,
+ with_clause.distinct,
+ aggregate_query,
+ )
+ stage_query = f"{stage_select}\nFROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+ stage_query += self._outer_group_order_and_paging(stage_return_body, aggregate_query)
+
+ carried_variables = self._with_carried_variables(with_clause.return_body)
+ final_select = self._outer_select_for_with_stage(
+ return_clause.return_body,
+ return_clause.distinct,
+ carried_variables,
+ )
+ query = (
+ f"WITH stage_1 AS (\n{self._indent_sql(stage_query)}\n)\n{final_select}\nFROM stage_1"
+ )
+ # The WITH clause's own filters are applied on the stage output.
+ filters = self._with_filters(with_clause, carried_variables)
+ if filters:
+ query += "\nWHERE " + " AND ".join(filters)
+ query += self._outer_group_order_and_paging_for_with_stage(
+ return_clause.return_body,
+ carried_variables,
+ )
+ return query
+
+ def _translate_with_cte_pattern_predicates(
+ self,
+ match_parts: List[str],
+ where_expressions: List[CompareExpression],
+ with_clause: WithClause,
+ return_clause: ReturnClause,
+ pattern_predicates: List[Tuple[bool, str]],
+ ) -> str:
+ """Translate ``MATCH ... WHERE <pattern predicates> WITH ... RETURN``.
+
+ Emits a ``base`` CTE (GRAPH_TABLE query that also projects the correlated
+ variables), then a ``stage_1`` CTE selecting from ``base``. When every
+ predicate is positive, each one becomes its own ``predicate_N`` CTE that
+ is JOINed to ``base``; if any predicate is negated, all of them are
+ applied as ``EXISTS``/``NOT EXISTS`` filters instead. The final RETURN
+ selects from ``stage_1``.
+
+ Args:
+ match_parts: Already translated MATCH fragments.
+ where_expressions: WHERE expressions without the pattern predicates.
+ with_clause: The WITH clause.
+ return_clause: The final RETURN clause.
+ pattern_predicates: ``(negated, pattern_text)`` pairs.
+
+ Returns:
+ The translated SQL text.
+
+ Raises:
+ ValueError: If a predicate pattern shares no variable with
+ ``self._var_kinds``.
+ """
+ outer_variables = set(self._var_kinds)
+ correlated_variables = set()
+ parsed_predicates = []
+ for negated, pattern in pattern_predicates:
+ path = self._parse_pattern_predicate_path(pattern)
+ variables = self._path_pattern_variables(path) & outer_variables
+ if not variables:
+ raise ValueError("Pattern predicate requires correlation to the outer MATCH.")
+ correlated_variables.update(variables)
+ parsed_predicates.append((negated, path, variables))
+
+ graph_table_parts = [
+ OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+ "MATCH " + ", ".join(match_parts),
+ ]
+ where_parts = self._where_parts(where_expressions)
+ if where_parts:
+ graph_table_parts.append("WHERE " + " AND ".join(where_parts))
+
+ stage_return_body = self._with_stage_return_body(
+ with_clause.return_body,
+ return_clause.return_body,
+ with_clause,
+ )
+ aggregate_query = self._has_aggregate(stage_return_body)
+ base_return_body = self._return_body_with_join_variables(
+ stage_return_body,
+ sorted(correlated_variables),
+ )
+ graph_table_parts.append(
+ f"COLUMNS ({self._translate_columns(base_return_body, aggregate_query)})"
+ )
+ base_query = f"SELECT *\nFROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+
+ # JOIN-based form is used only when no predicate is negated.
+ predicate_ctes: List[Tuple[str, str, set[str]]] = []
+ if parsed_predicates and all(
+ not negated for negated, _path, _variables in parsed_predicates
+ ):
+ predicate_ctes = [
+ (
+ f"predicate_{index}",
+ self._pattern_predicate_relation_sql(path, variables),
+ variables,
+ )
+ for index, (_negated, path, variables) in enumerate(parsed_predicates, start=1)
+ ]
+
+ stage_query = (
+ f"{self._outer_select(stage_return_body, with_clause.distinct, aggregate_query)}\n"
+ "FROM base"
+ )
+ if predicate_ctes:
+ for cte_name, _cte_sql, variables in predicate_ctes:
+ join_conditions = [
+ f"{cte_name}.{self._element_projection_alias(variable)} = "
+ f"base.{self._element_projection_alias(variable)}"
+ for variable in sorted(variables)
+ ]
+ stage_query += "\nJOIN " + cte_name + " ON " + " AND ".join(join_conditions)
+ else:
+ # Fallback: correlated EXISTS / NOT EXISTS filters on base.
+ filters = [
+ self._pattern_predicate_sql(negated, path, variables)
+ for negated, path, variables in parsed_predicates
+ ]
+ if filters:
+ stage_query += "\nWHERE " + " AND ".join(filters)
+ stage_query += self._outer_group_order_and_paging(stage_return_body, aggregate_query)
+
+ carried_variables = self._with_carried_variables(with_clause.return_body)
+ final_select = self._outer_select_for_with_stage(
+ return_clause.return_body,
+ return_clause.distinct,
+ carried_variables,
+ )
+ # CTE order: base, then any predicate_N relations, then stage_1.
+ cte_parts = [
+ ("base", base_query),
+ *[(name, sql) for name, sql, _variables in predicate_ctes],
+ ]
+ cte_parts.append(("stage_1", stage_query))
+ query = "WITH " + ",\n".join(
+ f"{name} AS (\n{self._indent_sql(sql)}\n)" for name, sql in cte_parts
+ )
+ query += f"\n{final_select}\nFROM stage_1"
+ with_filters = self._with_filters(with_clause, carried_variables)
+ if with_filters:
+ query += "\nWHERE " + " AND ".join(with_filters)
+ query += self._outer_group_order_and_paging_for_with_stage(
+ return_clause.return_body,
+ carried_variables,
+ )
+ return query
+
+ def _translate_two_stage_with_cte(
+     self,
+     query_pattern: List[Clause],
+     with_indexes: List[int],
+ ) -> str:
+     """Translate ``MATCH ... WITH ... WITH ... RETURN`` into two chained CTEs.
+
+     ``stage_1`` is the GRAPH_TABLE query for the first WITH; ``stage_2``
+     selects from ``stage_1`` for the second WITH; the final RETURN selects
+     from ``stage_2``. A WHERE between the two WITH clauses is treated as a
+     filter of the first WITH. The clauses in ``query_pattern`` are not
+     modified: the merged filter list is stored on a shallow copy of the
+     first WITH clause, so translating the same parsed query again gives the
+     same result.
+
+     Args:
+         query_pattern: The clause list of the query.
+         with_indexes: The two positions of the WITH clauses in
+             ``query_pattern``.
+
+     Returns:
+         The translated SQL text.
+
+     Raises:
+         ValueError: If a MATCH or RETURN sits between the WITH clauses,
+             anything but a single RETURN follows the second WITH, or no
+             MATCH precedes the first WITH.
+     """
+     import copy
+
+     first_with_index, second_with_index = with_indexes
+     first_with = query_pattern[first_with_index]
+     second_with = query_pattern[second_with_index]
+     assert isinstance(first_with, WithClause)
+     assert isinstance(second_with, WithClause)
+
+     before_first_with = query_pattern[:first_with_index]
+     between_withs = query_pattern[first_with_index + 1 : second_with_index]
+     after_second_with = query_pattern[second_with_index + 1 :]
+     if any(isinstance(clause, (MatchClause, ReturnClause)) for clause in between_withs):
+         raise ValueError("Two-stage WITH only supports adjacent SQL stages.")
+     if any(
+         isinstance(clause, (MatchClause, WhereClause, WithClause))
+         for clause in after_second_with
+     ):
+         raise ValueError("Two-stage WITH only supports a final RETURN clause.")
+     return_clauses = [
+         clause for clause in after_second_with if isinstance(clause, ReturnClause)
+     ]
+     if len(return_clauses) != 1:
+         raise ValueError("Two-stage WITH pipeline requires one final RETURN clause.")
+
+     match_clauses = [clause for clause in before_first_with if isinstance(clause, MatchClause)]
+     where_expressions: List[CompareExpression] = []
+     for clause in before_first_with:
+         if isinstance(clause, WhereClause):
+             where_expressions.extend(self._as_compare_list(clause.compare_expression_list))
+
+     # Fold WHERE clauses between the WITHs into the first WITH's filters, on
+     # a copy so the caller's clause object keeps its original filter list.
+     trailing_wheres = [clause for clause in between_withs if isinstance(clause, WhereClause)]
+     if trailing_wheres:
+         merged_filters = first_with.compare_expression_list
+         for clause in trailing_wheres:
+             merged_filters = self._as_compare_list(merged_filters) + self._as_compare_list(
+                 clause.compare_expression_list
+             )
+         first_with = copy.copy(first_with)
+         first_with.compare_expression_list = merged_filters
+     if not match_clauses:
+         raise ValueError("Two-stage WITH translation requires a preceding MATCH.")
+
+     return_clause = return_clauses[0]
+     second_stage_body = self._sql_stage_return_body(
+         second_with.return_body,
+         return_clause.return_body,
+     )
+     stage_1_query = self._first_with_stage_query(
+         match_clauses,
+         where_expressions,
+         first_with,
+         second_stage_body,
+     )
+     stage_1_aliases = {
+         OracleNameSanitizer.alias(alias)
+         for alias in self._resolved_return_aliases(
+             self._with_stage_return_body(
+                 first_with.return_body,
+                 second_stage_body,
+                 first_with,
+             )
+         )
+     }
+     stage_2_query = self._sql_with_stage_query(
+         second_stage_body,
+         second_with.distinct,
+         "stage_1",
+         stage_1_aliases,
+     )
+     second_filters = self._sql_stage_filters(second_with)
+     if second_filters:
+         # Filters of the second WITH apply to its own output, so wrap it.
+         stage_2_query = (
+             "SELECT *\n"
+             "FROM (\n"
+             f"{self._indent_sql(stage_2_query)}\n"
+             ") stage_2_raw\n"
+             "WHERE " + " AND ".join(second_filters)
+         )
+
+     stage_2_aliases = set(self._resolved_sql_stage_aliases(second_stage_body))
+     final_select = self._sql_stage_final_select(
+         return_clause.return_body,
+         return_clause.distinct,
+         stage_2_aliases,
+     )
+     query = (
+         "WITH stage_1 AS (\n"
+         f"{self._indent_sql(stage_1_query)}\n"
+         "),\n"
+         "stage_2 AS (\n"
+         f"{self._indent_sql(stage_2_query)}\n"
+         ")\n"
+         f"{final_select}\n"
+         "FROM stage_2"
+     )
+     query += self._sql_stage_final_group_order_and_paging(
+         return_clause.return_body,
+         stage_2_aliases,
+     )
+     return query
+
+ def _translate_with_match_then_with_cte(
+ self,
+ query_pattern: List[Clause],
+ with_indexes: List[int],
+ ) -> str:
+ """Translate ``... WITH ... MATCH ... WITH ... RETURN`` into chained CTEs.
+
+ The part up to the second WITH is translated by
+ ``_translate_with_match_cte`` using a synthetic RETURN built from the
+ second WITH. The final SELECT of that result is moved into a ``stage_3``
+ CTE, and the real RETURN selects from ``stage_3`` with the second WITH's
+ filters in the outer WHERE.
+
+ Args:
+ query_pattern: The clause list of the query.
+ with_indexes: The two positions of the WITH clauses in
+ ``query_pattern``.
+
+ Returns:
+ The translated SQL text.
+
+ Raises:
+ ValueError: If anything but a single RETURN follows the second WITH.
+ """
+ first_with_index, second_with_index = with_indexes
+ first_with = query_pattern[first_with_index]
+ second_with = query_pattern[second_with_index]
+ assert isinstance(first_with, WithClause)
+ assert isinstance(second_with, WithClause)
+
+ before_first_with = query_pattern[:first_with_index]
+ between_withs = query_pattern[first_with_index + 1 : second_with_index]
+ after_second_with = query_pattern[second_with_index + 1 :]
+ if any(isinstance(clause, WithClause) for clause in between_withs):
+ raise ValueError("Nested WITH stages are not supported.")
+ if any(
+ isinstance(clause, (MatchClause, WhereClause, WithClause))
+ for clause in after_second_with
+ ):
+ raise ValueError("WITH MATCH WITH only supports a final RETURN clause.")
+ return_clauses = [
+ clause for clause in after_second_with if isinstance(clause, ReturnClause)
+ ]
+ if len(return_clauses) != 1:
+ raise ValueError("WITH MATCH WITH pipeline requires one final RETURN clause.")
+ return_clause = return_clauses[0]
+
+ intermediate_body = self._sql_stage_return_body(
+ second_with.return_body,
+ return_clause.return_body,
+ )
+ # The second WITH acts as the RETURN of the inner WITH ... MATCH query.
+ intermediate_return = ReturnClause(intermediate_body, second_with.distinct)
+ intermediate_query = self._translate_with_match_cte(
+ before_first_with,
+ first_with,
+ [*between_withs, intermediate_return],
+ )
+ available_aliases = set(self._resolved_sql_stage_aliases(intermediate_body))
+ final_select = self._sql_stage_final_select(
+ return_clause.return_body,
+ return_clause.distinct,
+ available_aliases,
+ )
+ # Split off the inner query's final SELECT so it can become stage_3.
+ cte_prefix, intermediate_select = self._split_with_cte_final_select(intermediate_query)
+ query = (
+ f"{cte_prefix},\n"
+ "stage_3 AS (\n"
+ f"{self._indent_sql(intermediate_select)}\n"
+ ")\n"
+ f"{final_select}\n"
+ "FROM stage_3"
+ )
+ filters = self._sql_stage_filters(second_with)
+ if filters:
+ query += "\nWHERE " + " AND ".join(filters)
+ query += self._sql_stage_final_group_order_and_paging(
+ return_clause.return_body,
+ available_aliases,
+ )
+ return query
+
+ def _split_with_cte_final_select(self, query: str) -> Tuple[str, str]:
+     """Split a ``WITH ...`` query into its CTE prefix and its final SELECT.
+
+     The split point is the last line that starts with ``SELECT ``; the
+     newline before it belongs to neither part.
+
+     Raises:
+         ValueError: If the query does not start with ``WITH `` or has no
+             line starting with ``SELECT ``.
+     """
+     split_at = query.rfind("\nSELECT ")
+     is_cte_query = query.startswith("WITH ")
+     if split_at == -1 or not is_cte_query:
+         raise ValueError("Expected WITH CTE query with a final SELECT.")
+     cte_prefix = query[:split_at]
+     final_select = query[split_at + 1 :]
+     return cte_prefix, final_select
+
+ def _first_with_stage_query(
+ self,
+ match_clauses: List[MatchClause],
+ where_expressions: List[CompareExpression],
+ with_clause: WithClause,
+ next_with_body: ReturnBody,
+ ) -> str:
+ """Build the SQL for the first stage of a two-stage WITH pipeline.
+
+ Produces a GRAPH_TABLE query projecting what the WITH clause and the next
+ stage need. If the WITH clause has filters, the query is wrapped in
+ ``SELECT * FROM (...) stage_1_raw WHERE ...``.
+
+ Args:
+ match_clauses: MATCH clauses preceding the WITH.
+ where_expressions: WHERE expressions preceding the WITH.
+ with_clause: The first WITH clause.
+ next_with_body: Return body of the following stage.
+
+ Returns:
+ The SQL text of the stage query (without a CTE wrapper).
+ """
+ self._reset()
+ match_parts = [self._translate_match_clause(clause) for clause in match_clauses]
+ graph_table_parts = [
+ OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+ "MATCH " + ", ".join(match_parts),
+ ]
+ where_parts = self._where_parts(where_expressions)
+ if where_parts:
+ graph_table_parts.append("WHERE " + " AND ".join(where_parts))
+
+ stage_return_body = self._with_stage_return_body(
+ with_clause.return_body,
+ next_with_body,
+ with_clause,
+ )
+ aggregate_query = self._has_aggregate(stage_return_body)
+ graph_table_parts.append(
+ f"COLUMNS ({self._translate_columns(stage_return_body, aggregate_query)})"
+ )
+ stage_query = (
+ f"{self._outer_select(stage_return_body, with_clause.distinct, aggregate_query)}\n"
+ f"FROM GRAPH_TABLE (\n {' '.join(graph_table_parts)}\n) gt"
+ )
+ stage_query += self._outer_group_order_and_paging(stage_return_body, aggregate_query)
+
+ carried_variables = self._with_carried_variables(with_clause.return_body)
+ filters = self._with_filters(with_clause, carried_variables)
+ if filters:
+ # Filters apply to the stage output, so wrap the stage in a subquery.
+ stage_query = (
+ "SELECT *\n"
+ "FROM (\n"
+ f"{self._indent_sql(stage_query)}\n"
+ ") stage_1_raw\n"
+ "WHERE " + " AND ".join(filters)
+ )
+ return stage_query
+
+ def _sql_with_stage_query(
+ self,
+ return_body: ReturnBody,
+ distinct: bool,
+ from_name: str,
+ available_aliases: set[str],
+ ) -> str:
+ """Build a plain-SQL stage that selects from an earlier stage.
+
+ Each return item is rendered as ``<expression> AS <alias>``. For
+ aggregate bodies, a GROUP BY over the non-aggregate item expressions is
+ added (duplicates removed, order kept), followed by ordering and paging.
+
+ Args:
+ return_body: Items, ordering and paging of this stage.
+ distinct: Emit ``SELECT DISTINCT`` when true.
+ from_name: Name of the relation to select from.
+ available_aliases: Column aliases exposed by ``from_name``.
+
+ Returns:
+ The SQL text of the stage query.
+ """
+ aggregate_query = self._has_aggregate(return_body)
+ select_items = []
+ return_aliases = self._resolved_sql_stage_aliases(return_body)
+ # strict=True: a length mismatch between items and aliases raises.
+ for item, alias in zip(return_body.return_item_list, return_aliases, strict=True):
+ if self._is_complex_aggregate_item(item):
+ expression = self._translate_sql_expression(item.expression)
+ elif self._is_aggregate_item(item):
+ expression = self._sql_stage_aggregate_expression(item, available_aliases)
+ else:
+ expression = self._sql_stage_item_expression(item, available_aliases)
+ select_items.append(f"{expression} AS {OracleNameSanitizer.alias(alias)}")
+ keyword = "SELECT DISTINCT" if distinct else "SELECT"
+ query = f"{keyword} " + ", ".join(select_items) + f"\nFROM {from_name}"
+ if aggregate_query:
+ group_aliases = [
+ self._sql_stage_item_expression(item, available_aliases)
+ for item in return_body.return_item_list
+ if not self._is_aggregate_item(item)
+ ]
+ if group_aliases:
+ query += "\nGROUP BY " + ", ".join(dict.fromkeys(group_aliases))
+ query += self._outer_order_and_paging_for_with(return_body, include_group_by=False)
+ return query
+
+ def _sql_stage_return_body(
+ self,
+ stage_body: ReturnBody,
+ final_body: ReturnBody,
+ ) -> ReturnBody:
+ """Extend a stage's return body with properties the final body needs.
+
+ For every property referenced by ``final_body`` on a variable that the
+ stage carries as a bare variable, a ``variable.property`` item with a
+ stage alias is appended, unless that alias is already present. Sort items,
+ skip and limit are taken from ``stage_body`` unchanged.
+
+ Args:
+ stage_body: Return body of the intermediate stage.
+ final_body: Return body of the consumer of that stage.
+
+ Returns:
+ A new ``ReturnBody``; ``stage_body``'s item list is copied, not
+ modified.
+ """
+ items = list(stage_body.return_item_list)
+ # Bare variables: no property, no function, expression empty or the name itself.
+ carried_variables = {
+ item.symbolic_name
+ for item in items
+ if (
+ not item.property
+ and not item.function_name
+ and (not item.expression or item.expression == item.symbolic_name)
+ )
+ }
+ existing_aliases = set(self._resolved_sql_stage_aliases(stage_body))
+ for variable, property_name in self._return_body_property_references(final_body):
+ if variable not in carried_variables:
+ continue
+ property_name = self._canonical_property_name(variable, property_name)
+ alias = self._with_property_stage_alias(variable, property_name)
+ if alias in existing_aliases:
+ continue
+ items.append(
+ ReturnItem(
+ symbolic_name=variable,
+ property=property_name,
+ alias=alias,
+ function_name="",
+ expression=f"{variable}.{property_name}",
+ )
+ )
+ existing_aliases.add(alias)
+ return ReturnBody(
+ items,
+ stage_body.sort_item_list,
+ stage_body.skip,
+ stage_body.limit,
+ )
+
+ def _return_body_property_references(
+     self,
+     return_body: ReturnBody,
+ ) -> List[Tuple[str, str]]:
+     """List the distinct ``(variable, property)`` pairs a return body uses.
+
+     Return items are scanned first, then sort items; for each entry the
+     explicit ``property`` comes before the references found in its
+     expression. First occurrence order is kept.
+     """
+     # A dict keeps insertion order and drops later duplicates.
+     ordered: Dict[Tuple[str, str], None] = {}
+     for group in (return_body.return_item_list, return_body.sort_item_list):
+         for entry in group:
+             if entry.property:
+                 ordered.setdefault((entry.symbolic_name, entry.property), None)
+             for reference in self._all_property_references(entry.expression):
+                 ordered.setdefault(reference, None)
+     return list(ordered)
+
+ def _all_property_references(self, expression: str) -> List[Tuple[str, str]]:
+     """Find every ``variable.property`` reference in an expression.
+
+     String literals are masked first via ``self._protect_string_literals`` so
+     that text inside them is not matched. The property may be a bare
+     identifier or a double-quoted name.
+
+     Args:
+         expression: Expression text; ``None`` or empty yields no references.
+
+     Returns:
+         ``(variable, property)`` pairs in order of appearance; duplicates
+         are kept.
+     """
+     protected, _ = self._protect_string_literals(expression or "")
+     # The group names must match the match.group(...) lookups below.
+     reference_pattern = (
+         r"\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?:\"(?P<quoted_property>[^\"]+)\"|"
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#-]*)\b)"
+     )
+     return [
+         (
+             match.group("variable"),
+             match.group("quoted_property") or match.group("property"),
+         )
+         for match in re.finditer(reference_pattern, protected)
+     ]
+
+ def _sql_stage_final_select(
+     self,
+     return_body: ReturnBody,
+     distinct: bool,
+     available_aliases: set[str],
+ ) -> str:
+     """Build the final SELECT list over a SQL stage.
+
+     Every return item becomes ``<expression> AS <alias>``; the expression
+     depends on whether the item is a complex aggregate, a plain aggregate
+     or an ordinary item.
+     """
+
+     def expression_for(item: ReturnItem) -> str:
+         # Pick the renderer that matches the kind of return item.
+         if self._is_complex_aggregate_item(item):
+             return self._translate_sql_expression(item.expression)
+         if self._is_aggregate_item(item):
+             return self._sql_stage_aggregate_expression(item, available_aliases)
+         return self._sql_stage_item_expression(item, available_aliases)
+
+     items = return_body.return_item_list
+     aliases = self._resolved_sql_stage_aliases(return_body)
+     columns = [
+         f"{expression_for(item)} AS {OracleNameSanitizer.alias(alias)}"
+         for item, alias in zip(items, aliases, strict=True)
+     ]
+     select_keyword = "SELECT DISTINCT" if distinct else "SELECT"
+     return select_keyword + " " + ", ".join(columns)
+
+ def _sql_stage_final_group_order_and_paging(
+     self,
+     return_body: ReturnBody,
+     available_aliases: set[str],
+ ) -> str:
+     """Build the GROUP BY / ORDER BY / OFFSET / FETCH suffix of a final SELECT.
+
+     Each emitted clause starts with a newline; the result is an empty string
+     when none applies. ``-1`` in ``skip`` or ``limit`` means "not set".
+     """
+     clauses: List[str] = []
+     if self._has_aggregate(return_body):
+         grouping = [
+             self._sql_stage_item_expression(item, available_aliases)
+             for item in return_body.return_item_list
+             if not self._is_aggregate_item(item)
+         ]
+         if grouping:
+             # dict.fromkeys removes duplicates while keeping order.
+             clauses.append("\nGROUP BY " + ", ".join(dict.fromkeys(grouping)))
+     orderings = return_body.sort_item_list
+     if orderings:
+         rendered = [self._translate_sort_item(entry, return_body) for entry in orderings]
+         clauses.append("\nORDER BY " + ", ".join(rendered))
+     if return_body.skip != -1:
+         clauses.append(f"\nOFFSET {return_body.skip} ROWS")
+     if return_body.limit != -1:
+         clauses.append(f"\nFETCH FIRST {return_body.limit} ROWS ONLY")
+     return "".join(clauses)
+
+ def _translate_sort_item_for_with_match(
+     self,
+     sort_item: SortItem,
+     return_body: ReturnBody,
+ ) -> str:
+     """Render one ORDER BY term for a query that selects from ``stage_2``.
+
+     A ``variable.property`` sort item refers to the property's stage alias
+     qualified with ``stage_2.``; any other sort item uses the alias resolved
+     by ``self._sort_alias``. The sort direction suffix is appended.
+     """
+     sorts_on_property = bool(sort_item.property and sort_item.symbolic_name)
+     if sorts_on_property:
+         canonical_property = self._canonical_property_name(
+             sort_item.symbolic_name,
+             sort_item.property,
+         )
+         stage_alias = self._with_property_stage_alias(
+             sort_item.symbolic_name,
+             canonical_property,
+         )
+         column = f"stage_2.{OracleNameSanitizer.alias(stage_alias)}"
+     else:
+         column = OracleNameSanitizer.alias(self._sort_alias(sort_item, return_body))
+     direction = self._sql_sort_order(sort_item.order)
+     return f"{column}{direction}"
+
+ def _resolved_sql_stage_aliases(self, return_body: ReturnBody) -> List[str]:
+ aliases = []
+ for item in return_body.return_item_list:
+ alias = (
+ item.alias
+ or item.property
+ or self._default_expression_alias(
+ item.symbolic_name,
+ item.expression or item.symbolic_name,
+ )
+ )
+ aliases.append(OracleNameSanitizer.alias(alias))
+ return aliases
+
+ def _sql_stage_item_expression(
+ self,
+ item: ReturnItem | SortItem,
+ available_aliases: set[str],
+ ) -> str:
+ if item.property:
+ property_alias = self._with_property_stage_alias(
+ item.symbolic_name,
+ self._canonical_property_name(item.symbolic_name, item.property),
+ )
+ if property_alias in available_aliases:
+ return property_alias
+ if item.expression and (
+ item.expression != item.symbolic_name
+ or not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", item.expression)
+ ):
+ return self._sql_stage_expression(item.expression, available_aliases)
+ alias = OracleNameSanitizer.alias(item.alias or item.property or item.symbolic_name)
+ if alias in available_aliases:
+ return alias
+ element_alias = self._element_projection_alias(item.symbolic_name)
+ if element_alias in available_aliases:
+ return element_alias
+ return OracleNameSanitizer.alias(item.symbolic_name)
+
+ def _sql_stage_expression(self, expression: str, available_aliases: set[str]) -> str:
+ translated = self._translate_sql_expression(expression)
+ for variable, property_name in self._property_references(expression):
+ property_name = self._canonical_property_name(variable, property_name)
+ sql_variable = self._var_sql_names.get(variable, variable)
+ property_ref = OracleNameSanitizer.quote(property_name, fallback="PROP")
+ property_alias = self._with_property_stage_alias(variable, property_name)
+ if property_alias not in available_aliases:
+ continue
+ translated = translated.replace(f"{sql_variable}.{property_ref}", property_alias)
+ translated = re.sub(
+ rf"\b{re.escape(variable)}\.{re.escape(property_name)}\b",
+ property_alias,
+ translated,
+ )
+ return translated
+
+ def _sql_stage_aggregate_expression(
+ self,
+ item: ReturnItem | SortItem,
+ available_aliases: set[str],
+ ) -> str:
+ function_name = item.function_name.upper()
+ if function_name == "COUNT" and item.symbolic_name == "*":
+ return "COUNT(*)"
+ symbolic_name = self._strip_distinct_prefix(item.symbolic_name)
+ expression = self._aggregate_argument_expression_text(item, symbolic_name)
+ expression = self._translate_sql_expression(expression)
+ element_alias = self._element_projection_alias(symbolic_name)
+ if element_alias in available_aliases:
+ expression = element_alias
+ distinct = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+ if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$#]*", expression):
+ expression = OracleNameSanitizer.alias(expression)
+ return self._aggregate_sql_call(function_name, distinct, expression)
+
+ def _sql_stage_filters(self, with_clause: WithClause) -> List[str]:
+ filters = []
+ for expr in self._as_compare_list(with_clause.compare_expression_list):
+ if getattr(expr, "raw_expression", ""):
+ filters.append(self._translate_sql_expression(expr.raw_expression))
+ else:
+ filters.append(self._translate_compare_expression(expr))
+ return filters
+
+ def _translate_with_match_cte(
+     self,
+     before_with: List[Clause],
+     with_clause: WithClause,
+     after_with: List[Clause],
+ ) -> str:
+     """Translate ``MATCH ... WITH ... MATCH ... RETURN`` into a two-CTE SQL query.
+
+     The clauses before the WITH become ``stage_1`` and the MATCH clauses
+     after it become ``stage_2``; each stage is a ``GRAPH_TABLE`` query. The
+     final SELECT joins both stages: ``LEFT JOIN`` when the second stage is
+     OPTIONAL MATCH, ``CROSS JOIN`` when nothing correlates the stages and
+     ``_with_match_allows_cross_join`` permits it, and an inner ``JOIN``
+     otherwise.
+
+     Args:
+         before_with: Clauses preceding the WITH clause.
+         with_clause: The WITH clause separating the two stages.
+         after_with: Clauses following the WITH clause.
+
+     Returns:
+         The complete SQL text starting with ``WITH stage_1 AS (``.
+
+     Raises:
+         ValueError: If OPTIONAL MATCH and regular MATCH are mixed in the
+             second stage, if another WITH follows, if either stage has no
+             MATCH or there is not exactly one RETURN, if nothing carries
+             over or correlates the two stages, or if an OPTIONAL second
+             stage would need a cross join.
+     """
+     # --- Split the clauses on both sides of the WITH by kind. ---
+     first_match_clauses = [clause for clause in before_with if isinstance(clause, MatchClause)]
+     first_where_expressions: List[CompareExpression] = []
+     for clause in before_with:
+         if isinstance(clause, WhereClause):
+             first_where_expressions.extend(
+                 self._as_compare_list(clause.compare_expression_list)
+             )
+     second_match_clauses = [clause for clause in after_with if isinstance(clause, MatchClause)]
+     # The second stage must be entirely OPTIONAL or entirely regular.
+     optional_second_stage = any(clause.optional for clause in second_match_clauses)
+     if optional_second_stage and not all(clause.optional for clause in second_match_clauses):
+         raise ValueError("OPTIONAL MATCH cannot be mixed with regular MATCH in one stage.")
+     second_where_expressions: List[CompareExpression] = []
+     return_clauses = []
+     for clause in after_with:
+         if isinstance(clause, WhereClause):
+             second_where_expressions.extend(
+                 self._as_compare_list(clause.compare_expression_list)
+             )
+         elif isinstance(clause, ReturnClause):
+             return_clauses.append(clause)
+         elif isinstance(clause, WithClause):
+             raise ValueError("Multiple WITH stages are not supported.")
+     if not first_match_clauses or not second_match_clauses or len(return_clauses) != 1:
+         raise ValueError("WITH MATCH pipeline requires MATCH ... WITH ... MATCH ... RETURN.")
+
+     # --- Work out how the second stage refers back to the first. ---
+     carried_variables = self._with_passthrough_variable_names(with_clause.return_body)
+     scalar_aliases = self._with_scalar_aliases(with_clause.return_body)
+     self._assign_with_match_property_map_correlation_variables(
+         second_match_clauses,
+         scalar_aliases,
+     )
+     second_declared_variables = self._declared_variables_in_match_clauses(second_match_clauses)
+     # The second WHERE is split into several kinds of cross-stage
+     # correlation, filters evaluated against stage 1, and a residual that
+     # stays inside the second GRAPH_TABLE.
+     (
+         correlations,
+         scalar_correlations,
+         expression_scalar_correlations,
+         stage_expression_correlations,
+         element_correlations,
+         stage_one_filters,
+         residual_second_where,
+     ) = self._with_match_correlations(
+         second_where_expressions,
+         carried_variables,
+         second_declared_variables,
+         scalar_aliases,
+     )
+     property_map_scalar_correlations = self._extract_with_match_property_map_correlations(
+         second_match_clauses,
+         second_declared_variables,
+         scalar_aliases,
+     )
+     scalar_correlations.extend(property_map_scalar_correlations)
+     if (
+         not carried_variables
+         and not scalar_correlations
+         and not expression_scalar_correlations
+         and not stage_expression_correlations
+     ):
+         raise ValueError(
+             "WITH MATCH pipeline requires carried graph variables or scalar correlations."
+         )
+
+     # --- Stage 1: translate the first MATCH. ---
+     # NOTE(review): _reset() presumably clears per-stage translation state
+     # such as self._var_kinds before each stage is translated -- confirm.
+     self._reset()
+     first_match_parts = [self._translate_match_clause(clause) for clause in first_match_clauses]
+     # Only carried variables that the first MATCH actually declared can
+     # serve as join keys.
+     first_join_variables = {
+         variable for variable in carried_variables if variable in self._var_kinds
+     }
+     if (
+         not first_join_variables
+         and not scalar_correlations
+         and not expression_scalar_correlations
+         and not stage_expression_correlations
+     ):
+         raise ValueError("WITH MATCH pipeline has no declared carried variables.")
+     first_graph_table_parts = [
+         OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+         "MATCH " + ", ".join(first_match_parts),
+     ]
+     first_where_parts = self._where_parts(first_where_expressions)
+     if first_where_parts:
+         first_graph_table_parts.append("WHERE " + " AND ".join(first_where_parts))
+     return_clause = return_clauses[0]
+     # The WITH gets its own projected ("staged") result when it aggregates,
+     # filters, or feeds scalar / expression correlations.
+     staged_with = (
+         self._has_aggregate(with_clause.return_body)
+         or bool(with_clause.compare_expression_list)
+         or bool(scalar_correlations)
+         or bool(expression_scalar_correlations)
+         or bool(stage_expression_correlations)
+     )
+     stage_one_aliases: set[str] = set()
+     nonstaged_stage_one_aliases: set[str] = set()
+     first_join_aliases: Dict[str, str] = {}
+     if staged_with:
+         stage_return_body = self._with_stage_return_body(
+             with_clause.return_body,
+             return_clause.return_body,
+             with_clause,
+         )
+         # Also project the first-stage side of every property correlation.
+         stage_return_body = self._return_body_with_property_projections(
+             stage_return_body,
+             [
+                 (first_variable, first_property)
+                 for (
+                     _second_variable,
+                     _second_property,
+                     first_variable,
+                     first_property,
+                 ) in correlations
+             ],
+         )
+         first_join_aliases = {
+             variable: self._element_projection_alias(variable)
+             for variable in sorted(first_join_variables)
+         }
+         stage_one_aliases = {
+             OracleNameSanitizer.alias(alias)
+             for alias in self._resolved_return_aliases(stage_return_body)
+         }
+         aggregate_with_query = self._has_aggregate(stage_return_body)
+         first_graph_table_parts.append(
+             f"COLUMNS ({self._translate_columns(stage_return_body, aggregate_with_query)})"
+         )
+         first_select = self._outer_select(
+             stage_return_body,
+             with_clause.distinct,
+             aggregate_with_query,
+         )
+         first_stage_query = (
+             f"{first_select}\nFROM GRAPH_TABLE (\n {' '.join(first_graph_table_parts)}\n) gt"
+         )
+         first_stage_query += self._outer_group_order_and_paging(
+             stage_return_body,
+             aggregate_with_query,
+         )
+         # WITH ... WHERE predicates are applied by wrapping the stage in an
+         # outer SELECT that filters its rows.
+         filters = self._with_filters(with_clause, first_join_variables)
+         if filters:
+             first_stage_query = (
+                 "SELECT *\n"
+                 "FROM (\n"
+                 f"{self._indent_sql(first_stage_query)}\n"
+                 ") stage_1_raw\n"
+                 "WHERE " + " AND ".join(filters)
+             )
+     else:
+         # Non-staged WITH: project the element ids of the carried variables
+         # plus every value the RETURN, the WITH and the correlations need.
+         first_join_aliases = {
+             variable: self._stage_one_join_alias(variable)
+             for variable in sorted(first_join_variables)
+         }
+         first_stage_columns = [
+             (f"{self._element_id_expression(variable)} AS {first_join_aliases[variable]}")
+             for variable in sorted(first_join_variables)
+         ]
+         first_stage_columns.extend(
+             self._stage_one_return_projections(
+                 return_clause.return_body,
+                 first_join_variables,
+                 second_declared_variables,
+             )
+         )
+         first_stage_columns.extend(
+             self._stage_one_with_scalar_projections(with_clause.return_body)
+         )
+         first_stage_columns.extend(self._stage_one_sort_projections(with_clause.return_body))
+         first_stage_columns.extend(
+             self._stage_one_property_projections(
+                 [
+                     (first_variable, first_property)
+                     for (
+                         _second_variable,
+                         _second_property,
+                         first_variable,
+                         first_property,
+                     ) in correlations
+                 ]
+             )
+         )
+         first_stage_columns.extend(
+             self._stage_one_property_projections(
+                 self._stage_one_aggregate_property_references(
+                     return_clause.return_body,
+                     first_join_variables,
+                 )
+             )
+         )
+         # dict.fromkeys drops duplicate projections, keeping first-seen order.
+         first_graph_table_parts.append(
+             "COLUMNS (" + ", ".join(dict.fromkeys(first_stage_columns)) + ")"
+         )
+         # Recover the projected alias names by parsing the trailing
+         # ``AS <identifier>`` of each projection string.
+         nonstaged_stage_one_aliases = {
+             OracleNameSanitizer.alias(match.group(1))
+             for projection in first_stage_columns
+             if (
+                 match := re.search(
+                     r"\bAS\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
+                     projection,
+                     flags=re.IGNORECASE,
+                 )
+             )
+         }
+         first_select = "SELECT DISTINCT *" if with_clause.distinct else "SELECT *"
+         first_stage_query = (
+             f"{first_select}\nFROM GRAPH_TABLE (\n {' '.join(first_graph_table_parts)}\n) gt"
+         )
+         first_stage_query += self._outer_order_and_paging_for_with(with_clause.return_body)
+
+     # --- Stage 2: translate the second MATCH. ---
+     self._reset()
+     second_match_parts = [
+         self._translate_match_clause(clause) for clause in second_match_clauses
+     ]
+     # Element join keys: carried variables that are declared again in the
+     # second MATCH.
+     join_variables = [
+         variable
+         for variable in sorted(first_join_variables)
+         if variable in second_declared_variables and variable in self._var_kinds
+     ]
+     second_graph_table_parts = [
+         OracleNameSanitizer.quote(self.graph_name, fallback="GRAPH"),
+         "MATCH " + ", ".join(second_match_parts),
+     ]
+     second_where_parts = self._where_parts(residual_second_where)
+     if second_where_parts:
+         second_graph_table_parts.append("WHERE " + " AND ".join(second_where_parts))
+
+     aggregate_query = self._has_aggregate(return_clause.return_body)
+     # In the non-staged path the aliases come from the parsed projections.
+     if not stage_one_aliases:
+         stage_one_aliases = nonstaged_stage_one_aliases
+     stage_one_filter_sql = [
+         self._with_match_stage_one_filter_sql(filter_expression, stage_one_aliases)
+         for filter_expression in stage_one_filters
+     ]
+     # A cross join is considered only when no join key or correlation of
+     # any kind links the two stages.
+     cross_join = (
+         not join_variables
+         and not correlations
+         and not scalar_correlations
+         and not expression_scalar_correlations
+         and not stage_expression_correlations
+         and not element_correlations
+         and self._with_match_allows_cross_join(
+             return_clause.return_body,
+             aggregate_query,
+             second_declared_variables,
+         )
+     )
+     if (
+         not join_variables
+         and not correlations
+         and not scalar_correlations
+         and not expression_scalar_correlations
+         and not stage_expression_correlations
+         and not element_correlations
+         and not cross_join
+     ):
+         raise ValueError("WITH MATCH pipeline has no carried variables in second MATCH.")
+     if optional_second_stage and cross_join:
+         raise ValueError("OPTIONAL MATCH requires correlation to prior bindings.")
+     second_return_body = self._return_body_for_second_match_stage(
+         return_clause.return_body,
+         join_variables,
+         second_declared_variables,
+         first_join_variables,
+         stage_one_aliases,
+         correlations,
+         scalar_correlations,
+         expression_scalar_correlations,
+         stage_expression_correlations,
+         element_correlations,
+     )
+     # A cross join with nothing to project from stage 2 gets a constant
+     # placeholder column so the second stage still has a return item.
+     if cross_join and not second_return_body.return_item_list:
+         second_return_body = ReturnBody(
+             [
+                 ReturnItem(
+                     symbolic_name="",
+                     property="",
+                     alias="dummy_value",
+                     function_name="",
+                     expression="1",
+                 )
+             ],
+             [],
+         )
+     second_graph_table_parts.append(
+         f"COLUMNS ({self._translate_columns(second_return_body, aggregate_query)})"
+     )
+     second_stage_query = (
+         f"SELECT *\nFROM GRAPH_TABLE (\n {' '.join(second_graph_table_parts)}\n) gt"
+     )
+
+     # --- Join conditions, one group per correlation kind. ---
+     # Shared graph elements: element alias in stage 2 = join alias in stage 1.
+     join_conditions = [
+         f"stage_2.{self._element_projection_alias(variable)} = "
+         f"stage_1.{first_join_aliases[variable]}"
+         for variable in join_variables
+     ]
+     # Property-to-property equality across the stages.
+     join_conditions.extend(
+         f"stage_2.{self._with_property_stage_alias(second_variable, second_property)} = "
+         f"stage_1.{self._with_property_stage_alias(first_variable, first_property)}"
+         for second_variable, second_property, first_variable, first_property in correlations
+     )
+     # Second-stage property compared with a stage-1 alias.
+     join_conditions.extend(
+         f"stage_2.{self._with_property_stage_alias(second_variable, second_property)} "
+         f"{operator} stage_1.{stage_alias}"
+         for second_variable, second_property, operator, stage_alias in scalar_correlations
+     )
+     # Second-stage expression alias compared with a stage-1 alias.
+     join_conditions.extend(
+         f"stage_2.{expression_alias} {operator} stage_1.{stage_alias}"
+         for (
+             _expression,
+             expression_alias,
+             operator,
+             stage_alias,
+         ) in expression_scalar_correlations
+     )
+     # Second-stage property compared with an expression over stage-1 aliases.
+     join_conditions.extend(
+         f"stage_2.{self._with_property_stage_alias(second_variable, second_property)} "
+         f"{operator} "
+         f"{self._with_match_stage_expression_sql(stage_expression, stage_one_aliases)}"
+         for (
+             second_variable,
+             second_property,
+             operator,
+             stage_expression,
+         ) in stage_expression_correlations
+     )
+     # Element-to-element comparison using the correlation's own operator.
+     join_conditions.extend(
+         f"stage_2.{self._element_projection_alias(second_variable)} "
+         f"{operator} stage_1.{first_join_aliases[first_variable]}"
+         for second_variable, operator, first_variable in element_correlations
+     )
+     final_select = self._outer_select_for_with_match(
+         return_clause.return_body,
+         return_clause.distinct,
+         aggregate_query,
+         stage_one_aliases,
+         first_join_variables,
+         second_return_body,
+     )
+     # --- Assemble the final statement. ---
+     # OPTIONAL second stage: stage_1 drives and stage_2 is LEFT JOINed.
+     # Otherwise stage_2 drives and stage_1 is CROSS JOINed or inner JOINed.
+     query = (
+         "WITH stage_1 AS (\n"
+         f"{self._indent_sql(first_stage_query)}\n"
+         "),\n"
+         "stage_2 AS (\n"
+         f"{self._indent_sql(second_stage_query)}\n"
+         ")\n"
+         f"{final_select}\n"
+         + (
+             "FROM stage_1\nLEFT JOIN stage_2 ON " + " AND ".join(join_conditions)
+             if optional_second_stage
+             else "FROM stage_2\n"
+             + (
+                 "CROSS JOIN stage_1"
+                 if cross_join
+                 else "JOIN stage_1 ON " + " AND ".join(join_conditions)
+             )
+         )
+     )
+     # Filters that refer to stage-1 values are applied after the join.
+     if stage_one_filter_sql:
+         query += "\nWHERE " + " AND ".join(stage_one_filter_sql)
+     query += self._outer_group_order_and_paging_for_with_match(
+         return_clause.return_body,
+         aggregate_query,
+         stage_one_aliases,
+         first_join_variables,
+         second_return_body,
+     )
+     return query
+
+ def _with_match_allows_cross_join(
+ self,
+ return_body: ReturnBody,
+ aggregate_query: bool,
+ second_declared_variables: set[str],
+ ) -> bool:
+ for item in return_body.return_item_list:
+ if aggregate_query and not self._is_aggregate_item(item):
+ return False
+ symbolic_name = self._strip_distinct_prefix(item.symbolic_name)
+ if symbolic_name in second_declared_variables:
+ return False
+ if any(
+ variable in second_declared_variables
+ for variable, _property_name in self._property_references(item.expression)
+ ):
+ return False
+ for sort_item in return_body.sort_item_list:
+ symbolic_name = self._strip_distinct_prefix(sort_item.symbolic_name)
+ if symbolic_name in second_declared_variables:
+ return False
+ if any(
+ variable in second_declared_variables
+ for variable, _property_name in self._property_references(sort_item.expression)
+ ):
+ return False
+ return True
+
+ def _outer_group_order_and_paging_for_with_match(
+ self,
+ return_body: ReturnBody,
+ aggregate_query: bool,
+ stage_one_aliases: set[str],
+ first_join_variables: set[str],
+ second_return_body: ReturnBody | None = None,
+ ) -> str:
+ suffix = ""
+ if aggregate_query:
+ group_aliases = []
+ return_aliases = self._resolved_return_aliases(return_body)
+ stage_two_aliases = self._stage_two_aliases_by_return_item(second_return_body)
+ for item, resolved_alias in zip(
+ return_body.return_item_list,
+ return_aliases,
+ strict=True,
+ ):
+ if self._is_aggregate_item(item):
+ continue
+ stage_alias = self._stage_alias_for_return_item(
+ item,
+ first_join_variables,
+ stage_one_aliases,
+ )
+ if stage_alias in stage_one_aliases:
+ group_aliases.append(f"stage_1.{stage_alias}")
+ elif OracleNameSanitizer.alias(resolved_alias) in stage_one_aliases:
+ group_aliases.append(f"stage_1.{OracleNameSanitizer.alias(resolved_alias)}")
+ else:
+ stage_two_alias = stage_two_aliases.get(
+ id(item),
+ OracleNameSanitizer.alias(resolved_alias),
+ )
+ group_aliases.append(f"stage_2.{stage_two_alias}")
+ for item in return_body.return_item_list:
+ if not self._is_complex_aggregate_item(item):
+ continue
+ for variable, property_name in self._property_references(item.expression):
+ property_name = self._canonical_property_name(variable, property_name)
+ property_alias = self._with_property_stage_alias(variable, property_name)
+ if property_alias in stage_one_aliases:
+ group_aliases.append(f"stage_1.{property_alias}")
+ if group_aliases:
+ suffix += "\nGROUP BY " + ", ".join(dict.fromkeys(group_aliases))
+ if return_body.sort_item_list:
+ suffix += "\nORDER BY " + ", ".join(
+ self._translate_sort_item_for_with_match(item, return_body)
+ for item in return_body.sort_item_list
+ )
+ if return_body.skip != -1:
+ suffix += f"\nOFFSET {return_body.skip} ROWS"
+ if return_body.limit != -1:
+ suffix += f"\nFETCH FIRST {return_body.limit} ROWS ONLY"
+ return suffix
+
+ def _return_body_with_join_variables(
+ self,
+ return_body: ReturnBody,
+ join_variables: List[str],
+ ) -> ReturnBody:
+ items = list(return_body.return_item_list)
+ existing_aliases = {
+ OracleNameSanitizer.alias(self._return_alias(item, self._return_expression(item)))
+ for item in items
+ }
+ for variable in join_variables:
+ alias = self._element_projection_alias(variable)
+ if alias in existing_aliases:
+ continue
+ items.append(
+ ReturnItem(
+ symbolic_name=variable,
+ property="",
+ alias=alias,
+ function_name="",
+ expression=variable,
+ )
+ )
+ existing_aliases.add(alias)
+ return ReturnBody(
+ items,
+ return_body.sort_item_list,
+ return_body.skip,
+ return_body.limit,
+ )
+
+ def _return_body_for_second_match_stage(
+ self,
+ return_body: ReturnBody,
+ join_variables: List[str],
+ second_declared_variables: set[str],
+ first_join_variables: set[str],
+ stage_one_aliases: set[str],
+ correlations: List[Tuple[str, str, str, str]] | None = None,
+ scalar_correlations: List[Tuple[str, str, str, str]] | None = None,
+ expression_scalar_correlations: List[Tuple[str, str, str, str]] | None = None,
+ stage_expression_correlations: List[Tuple[str, str, str, str]] | None = None,
+ element_correlations: List[Tuple[str, str, str]] | None = None,
+ ) -> ReturnBody:
+ correlations = correlations or []
+ scalar_correlations = scalar_correlations or []
+ expression_scalar_correlations = expression_scalar_correlations or []
+ stage_expression_correlations = stage_expression_correlations or []
+ element_correlations = element_correlations or []
+ items = [
+ item
+ for item in return_body.return_item_list
+ if self._return_item_can_project_from_second_match(
+ item,
+ second_declared_variables,
+ first_join_variables,
+ stage_one_aliases,
+ )
+ ]
+ sort_items = [
+ item
+ for item in return_body.sort_item_list
+ if self._return_item_can_project_from_second_match(
+ item,
+ second_declared_variables,
+ first_join_variables,
+ stage_one_aliases,
+ )
+ ]
+ stage_body = ReturnBody(
+ items,
+ sort_items,
+ return_body.skip,
+ return_body.limit,
+ )
+ stage_body = self._return_body_with_join_variables(stage_body, join_variables)
+ aggregate_element_references = []
+ for item in return_body.return_item_list:
+ aggregate_element_references.extend(
+ variable
+ for variable in self._aggregate_element_references(item.expression)
+ if variable in second_declared_variables
+ )
+ for sort_item in return_body.sort_item_list:
+ aggregate_element_references.extend(
+ variable
+ for variable in self._aggregate_element_references(sort_item.expression)
+ if variable in second_declared_variables
+ )
+ stage_body = self._return_body_with_join_variables(
+ stage_body,
+ list(dict.fromkeys(aggregate_element_references)),
+ )
+ stage_body = self._return_body_with_join_variables(
+ stage_body,
+ [
+ second_variable
+ for second_variable, _operator, _first_variable in element_correlations
+ ],
+ )
+ stage_body = self._return_body_with_property_projections(
+ stage_body,
+ [
+ (second_variable, second_property)
+ for (
+ second_variable,
+ second_property,
+ _first_variable,
+ _first_property,
+ ) in correlations
+ ]
+ + [
+ (second_variable, second_property)
+ for second_variable, second_property, _operator, _stage_alias in scalar_correlations
+ ]
+ + [
+ (second_variable, second_property)
+ for (
+ second_variable,
+ second_property,
+ _operator,
+ _stage_expression,
+ ) in stage_expression_correlations
+ ]
+ + [
+ (item.symbolic_name, item.property)
+ for item in return_body.return_item_list
+ if item.symbolic_name in second_declared_variables and item.property
+ ]
+ + [
+ (variable, property_name)
+ for item in return_body.return_item_list
+ for variable, property_name in self._property_references(item.expression)
+ if variable in second_declared_variables
+ ]
+ + [
+ (sort_item.symbolic_name, sort_item.property)
+ for sort_item in return_body.sort_item_list
+ if sort_item.symbolic_name in second_declared_variables and sort_item.property
+ ],
+ )
+ return self._return_body_with_expression_projections(
+ stage_body,
+ [
+ (expression, expression_alias)
+ for (
+ expression,
+ expression_alias,
+ _operator,
+ _stage_alias,
+ ) in expression_scalar_correlations
+ ],
+ )
+
+ def _return_item_can_project_from_second_match(
+ self,
+ item: ReturnItem | SortItem,
+ second_declared_variables: set[str],
+ first_join_variables: set[str],
+ stage_one_aliases: set[str],
+ ) -> bool:
+ if self._stage_alias_for_return_item(
+ item
+ ) in stage_one_aliases and not self._is_aggregate_item(item):
+ return False
+ if (
+ item.expression
+ and not self._is_aggregate_item(item)
+ and self._expression_uses_stage_one_alias(item.expression, stage_one_aliases)
+ ):
+ return False
+ if (
+ item.property
+ and item.symbolic_name in first_join_variables
+ and not self._is_aggregate_item(item)
+ ):
+ return False
+ symbolic_name = self._strip_distinct_prefix(item.symbolic_name)
+ if symbolic_name in second_declared_variables:
+ return True
+ references = self._property_references(item.expression)
+ if not references:
+ return False
+ return all(variable in second_declared_variables for variable, _property_name in references)
+
+ def _expression_uses_stage_one_alias(
+ self,
+ expression: str,
+ stage_one_aliases: set[str],
+ ) -> bool:
+ if not expression:
+ return False
+ for variable, property_name in self._property_references(expression):
+ property_name = self._canonical_property_name(variable, property_name)
+ if self._with_property_stage_alias(variable, property_name) in stage_one_aliases:
+ return True
+ return any(re.search(rf"\b{re.escape(alias)}\b", expression) for alias in stage_one_aliases)
+
+ def _outer_select_for_with_match(
+ self,
+ return_body: ReturnBody,
+ distinct: bool,
+ aggregate_query: bool,
+ stage_one_aliases: set[str],
+ first_join_variables: set[str],
+ second_return_body: ReturnBody | None = None,
+ ) -> str:
+ if not stage_one_aliases:
+ return self._outer_select(return_body, distinct, aggregate_query, ["__join__"])
+ select_items = []
+ return_aliases = self._resolved_return_aliases(return_body)
+ stage_two_aliases = self._stage_two_aliases_by_return_item(second_return_body)
+ for item, resolved_alias in zip(return_body.return_item_list, return_aliases, strict=True):
+ if self._is_complex_aggregate_item(item):
+ expression = self._outer_complex_aggregate_expression_for_with_match(
+ item.expression,
+ stage_one_aliases,
+ )
+ alias = self._return_alias(item, expression)
+ select_items.append(f"{expression} AS {OracleNameSanitizer.alias(alias)}")
+ elif self._is_aggregate_item(item):
+ expression = self._outer_aggregate_expression_for_with_match(
+ item,
+ stage_one_aliases,
+ second_return_body,
+ )
+ alias = self._return_alias(item, expression)
+ select_items.append(f"{expression} AS {OracleNameSanitizer.alias(alias)}")
+ elif item.property:
+ stage_alias = self._stage_alias_for_return_item(
+ item,
+ first_join_variables,
+ stage_one_aliases,
+ )
+ resolved_sql_alias = OracleNameSanitizer.alias(resolved_alias)
+ if stage_alias in stage_one_aliases:
+ expression = f"stage_1.{stage_alias}"
+ elif resolved_sql_alias in stage_one_aliases:
+ expression = f"stage_1.{resolved_sql_alias}"
+ else:
+ stage_two_alias = stage_two_aliases.get(
+ id(item),
+ ) or self._stage_two_property_alias(
+ second_return_body,
+ item.symbolic_name,
+ item.property,
+ )
+ expression = f"stage_2.{stage_two_alias or stage_alias}"
+ select_items.append(f"{expression} AS {resolved_sql_alias}")
+ elif item.expression and (
+ item.expression != item.symbolic_name
+ or not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", item.expression)
+ ):
+ if self._expression_uses_stage_one_alias(item.expression, stage_one_aliases):
+ expression = self._outer_expression_for_with_match(
+ item.expression,
+ stage_one_aliases,
+ )
+ else:
+ expression = f"stage_2.{OracleNameSanitizer.alias(resolved_alias)}"
+ select_items.append(f"{expression} AS {OracleNameSanitizer.alias(resolved_alias)}")
+ else:
+ stage_alias = self._stage_alias_for_return_item(
+ item,
+ first_join_variables,
+ stage_one_aliases,
+ )
+ resolved_sql_alias = OracleNameSanitizer.alias(resolved_alias)
+ if stage_alias in stage_one_aliases:
+ expression = f"stage_1.{stage_alias}"
+ else:
+ expression = f"stage_2.{stage_two_aliases.get(id(item), stage_alias)}"
+ select_items.append(f"{expression} AS {resolved_sql_alias}")
+ keyword = "SELECT DISTINCT" if distinct else "SELECT"
+ return f"{keyword} " + ", ".join(select_items)
+
+ def _stage_two_aliases_by_return_item(
+ self,
+ second_return_body: ReturnBody | None,
+ ) -> Dict[int, str]:
+ if second_return_body is None:
+ return {}
+ return {
+ id(item): OracleNameSanitizer.alias(alias)
+ for item, alias in zip(
+ second_return_body.return_item_list,
+ self._resolved_return_aliases(second_return_body),
+ strict=True,
+ )
+ }
+
+ def _stage_two_property_alias(
+ self,
+ second_return_body: ReturnBody | None,
+ variable: str,
+ property_name: str,
+ ) -> str:
+ if second_return_body is None:
+ return ""
+ property_name = self._canonical_property_name(variable, property_name)
+ for item, alias in zip(
+ second_return_body.return_item_list,
+ self._resolved_return_aliases(second_return_body),
+ strict=True,
+ ):
+ if (
+ item.symbolic_name == variable
+ and item.property
+ and not self._is_aggregate_item(item)
+ ):
+ item_property = self._canonical_property_name(variable, item.property)
+ if item_property == property_name:
+ return OracleNameSanitizer.alias(alias)
+ for expr_variable, expr_property in self._property_references(item.expression):
+ if self._is_aggregate_item(item):
+ continue
+ if expr_variable != variable:
+ continue
+ expr_property = self._canonical_property_name(variable, expr_property)
+ if expr_property == property_name:
+ return OracleNameSanitizer.alias(alias)
+ return ""
+
+ def _outer_complex_aggregate_expression_for_with_match(
+ self,
+ expression: str,
+ stage_one_aliases: set[str],
+ ) -> str:
+ translated = self._translate_sql_expression(expression)
+ translated = self._outer_expression_for_with_match(translated, stage_one_aliases)
+ for variable in self._aggregate_element_reference_names(expression):
+ stage_one_element_alias = self._stage_one_join_alias(variable)
+ element_alias = self._element_projection_alias(variable)
+ if stage_one_element_alias in stage_one_aliases:
+ qualified = f"stage_1.{stage_one_element_alias}"
+ elif element_alias in stage_one_aliases:
+ qualified = f"stage_1.{element_alias}"
+ else:
+ qualified = f"stage_2.{element_alias}"
+ translated = re.sub(
+ rf"\b(COUNT)\s*\(\s*DISTINCT\s+{re.escape(variable)}\s*\)",
+ rf"\1(DISTINCT {qualified})",
+ translated,
+ flags=re.IGNORECASE,
+ )
+ translated = re.sub(
+ rf"\b(COUNT|MIN|MAX)\s*\(\s*{re.escape(variable)}\s*\)",
+ rf"\1({qualified})",
+ translated,
+ flags=re.IGNORECASE,
+ )
+ return self._coalesce_sum_calls(translated)
+
+ def _outer_expression_for_with_match(
+ self,
+ expression: str,
+ stage_one_aliases: set[str],
+ ) -> str:
+ translated = self._translate_sql_expression(expression)
+ for variable, property_name in self._property_references(expression):
+ property_name = self._canonical_property_name(variable, property_name)
+ property_alias = self._with_property_stage_alias(variable, property_name)
+ if property_alias not in stage_one_aliases:
+ continue
+ sql_variable = self._var_sql_names.get(variable, variable)
+ property_ref = OracleNameSanitizer.quote(property_name, fallback="PROP")
+ translated = translated.replace(
+ f"{sql_variable}.{property_ref}",
+ f"stage_1.{property_alias}",
+ )
+ translated = re.sub(
+ rf"\b{re.escape(variable)}\.{re.escape(property_name)}\b",
+ f"stage_1.{property_alias}",
+ translated,
+ )
+ for alias in sorted(stage_one_aliases, key=len, reverse=True):
+ translated = re.sub(
+ rf"(? str:
+ function_name = item.function_name.upper()
+ if function_name == "COUNT" and item.symbolic_name == "*":
+ return "COUNT(*)"
+ symbolic_name = self._strip_distinct_prefix(item.symbolic_name)
+ argument = self._aggregate_argument_expression_text(item, symbolic_name)
+ if item.property:
+ property_name = self._canonical_property_name(symbolic_name, item.property)
+ property_alias = self._with_property_stage_alias(symbolic_name, property_name)
+ if property_alias in stage_one_aliases:
+ distinct = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+ return self._aggregate_sql_call(
+ function_name,
+ distinct,
+ f"stage_1.{property_alias}",
+ )
+ stage_two_alias = self._stage_two_property_alias(
+ second_return_body,
+ symbolic_name,
+ property_name,
+ )
+ if stage_two_alias:
+ distinct = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+ return self._aggregate_sql_call(
+ function_name,
+ distinct,
+ f"stage_2.{stage_two_alias}",
+ )
+ stage_one_element_alias = self._stage_one_join_alias(symbolic_name)
+ if (
+ not item.property
+ and "." not in argument
+ and stage_one_element_alias in stage_one_aliases
+ ):
+ distinct = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+ return self._aggregate_sql_call(
+ function_name,
+ distinct,
+ f"stage_1.{stage_one_element_alias}",
+ )
+ if symbolic_name in self._var_kinds and not item.property and "." not in argument:
+ distinct = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+ return (
+ f"{function_name}({distinct}"
+ f"stage_2.{self._element_projection_alias(symbolic_name)})"
+ )
+ argument = self._translate_sql_expression(argument)
+ argument_alias = OracleNameSanitizer.alias(argument)
+ if argument_alias in stage_one_aliases:
+ distinct = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+ return self._aggregate_sql_call(
+ function_name,
+ distinct,
+ f"stage_1.{argument_alias}",
+ )
+ return self._outer_aggregate_expression(item)
+
def _stage_alias_for_return_item(
    self,
    item: ReturnItem | SortItem,
    first_join_variables: set[str] | None = None,
    stage_one_aliases: set[str] | None = None,
) -> str:
    """Pick the column alias under which ``item`` is addressed in a stage.

    Property items on a first-join variable resolve to their property stage
    alias (provided stage one exposes it, when the alias set is given).
    Property-less items resolve to whichever element alias stage one exposes.
    Anything else falls back to the sanitized explicit alias, property name
    or default expression alias.
    """
    name = item.symbolic_name
    if item.property:
        if first_join_variables and name in first_join_variables:
            canonical = self._canonical_property_name(name, item.property)
            candidate = self._with_property_stage_alias(name, canonical)
            if not stage_one_aliases or candidate in stage_one_aliases:
                return candidate
    elif stage_one_aliases:
        # Try the plain element alias first, then the stage-one join alias.
        for make_alias in (self._element_projection_alias, self._stage_one_join_alias):
            candidate = make_alias(name)
            if candidate in stage_one_aliases:
                return candidate
    fallback = (
        getattr(item, "alias", "")
        or item.property
        or self._default_expression_alias(item.symbolic_name, item.expression)
    )
    return OracleNameSanitizer.alias(fallback)
+
def _stage_one_return_projections(
    self,
    return_body: ReturnBody,
    first_join_variables: set[str],
    second_declared_variables: set[str],
) -> List[str]:
    """Translate the RETURN items that stage one can project by itself.

    Aggregates, items on variables outside ``first_join_variables`` and bare
    variables that the second MATCH declares again are left out.
    """
    resolved_aliases = self._resolved_return_aliases(return_body)
    selected: List[str] = []
    for entry, resolved in zip(return_body.return_item_list, resolved_aliases, strict=True):
        excluded = (
            self._is_aggregate_item(entry)
            or entry.symbolic_name not in first_join_variables
            or (not entry.property and entry.symbolic_name in second_declared_variables)
        )
        if not excluded:
            selected.append(self._translate_return_item(entry, resolved))
    return selected
+
+ def _stage_one_with_scalar_projections(self, return_body: ReturnBody) -> List[str]:
+ projections = []
+ for item in return_body.return_item_list:
+ if self._is_aggregate_item(item):
+ continue
+ if (
+ not item.property
+ and not item.function_name
+ and (not item.expression or item.expression == item.symbolic_name)
+ and item.symbolic_name in self._var_kinds
+ ):
+ continue
+ alias = item.alias or item.property or item.symbolic_name
+ if not alias:
+ continue
+ projections.append(
+ f"{self._return_expression(item)} AS {OracleNameSanitizer.alias(alias)}"
+ )
+ return projections
+
+ def _stage_one_sort_projections(self, return_body: ReturnBody) -> List[str]:
+ projections = []
+ projected_aliases = {
+ self._element_projection_alias(item.symbolic_name)
+ for item in return_body.return_item_list
+ if (
+ not item.property
+ and not item.function_name
+ and item.symbolic_name in self._var_kinds
+ )
+ }
+ projected_aliases.update(
+ OracleNameSanitizer.alias(item.alias or item.property or item.symbolic_name)
+ for item in return_body.return_item_list
+ if (item.alias or item.property or item.symbolic_name)
+ )
+ for sort_item in return_body.sort_item_list:
+ alias = OracleNameSanitizer.alias(self._sort_alias(sort_item, return_body))
+ if alias in projected_aliases:
+ continue
+ if sort_item.expression:
+ expression = self._translate_sql_expression(sort_item.expression)
+ else:
+ expression = self._value_expression(sort_item.symbolic_name, sort_item.property)
+ projections.append(f"{expression} AS {alias}")
+ projected_aliases.add(alias)
+ return projections
+
+ def _stage_one_property_projections(
+ self,
+ references: List[Tuple[str, str]],
+ ) -> List[str]:
+ projections = []
+ for variable, property_name in references:
+ property_name = self._canonical_property_name(variable, property_name)
+ sql_variable = self._var_sql_names.get(variable, variable)
+ property_ref = OracleNameSanitizer.quote(property_name, fallback="PROP")
+ alias = self._with_property_stage_alias(variable, property_name)
+ projections.append(f"{sql_variable}.{property_ref} AS {alias}")
+ return projections
+
+ def _stage_one_aggregate_property_references(
+ self,
+ return_body: ReturnBody,
+ first_join_variables: set[str],
+ ) -> List[Tuple[str, str]]:
+ references: List[Tuple[str, str]] = []
+ for item in return_body.return_item_list:
+ if not self._is_aggregate_item(item):
+ continue
+ symbolic_name = self._strip_distinct_prefix(item.symbolic_name)
+ if item.property and symbolic_name in first_join_variables:
+ references.append((symbolic_name, item.property))
+ for variable, property_name in self._property_references(item.expression):
+ if variable in first_join_variables:
+ references.append((variable, property_name))
+ unique: List[Tuple[str, str]] = []
+ for variable, property_name in references:
+ property_name = self._canonical_property_name(variable, property_name)
+ if (variable, property_name) not in unique:
+ unique.append((variable, property_name))
+ return unique
+
+ def _return_body_with_property_projections(
+ self,
+ return_body: ReturnBody,
+ references: List[Tuple[str, str]],
+ ) -> ReturnBody:
+ if not references:
+ return return_body
+ items = list(return_body.return_item_list)
+ existing_aliases = {
+ OracleNameSanitizer.alias(self._return_alias(item, self._return_expression(item)))
+ for item in items
+ }
+ for variable, property_name in references:
+ property_name = self._canonical_property_name(variable, property_name)
+ alias = self._with_property_stage_alias(variable, property_name)
+ if alias in existing_aliases:
+ continue
+ items.append(
+ ReturnItem(
+ symbolic_name=variable,
+ property=property_name,
+ alias=alias,
+ function_name="",
+ expression=f"{variable}.{property_name}",
+ )
+ )
+ existing_aliases.add(alias)
+ return ReturnBody(
+ items,
+ return_body.sort_item_list,
+ return_body.skip,
+ return_body.limit,
+ )
+
+ def _return_body_with_expression_projections(
+ self,
+ return_body: ReturnBody,
+ references: List[Tuple[str, str]],
+ ) -> ReturnBody:
+ if not references:
+ return return_body
+ items = list(return_body.return_item_list)
+ existing_aliases = {
+ OracleNameSanitizer.alias(self._return_alias(item, self._return_expression(item)))
+ for item in items
+ }
+ for expression, alias in references:
+ alias = OracleNameSanitizer.alias(alias)
+ if alias in existing_aliases:
+ continue
+ items.append(
+ ReturnItem(
+ symbolic_name="",
+ property="",
+ alias=alias,
+ function_name="",
+ expression=expression,
+ )
+ )
+ existing_aliases.add(alias)
+ return ReturnBody(
+ items,
+ return_body.sort_item_list,
+ return_body.skip,
+ return_body.limit,
+ )
+
+ def _with_match_correlations(
+ self,
+ where_expressions: List[CompareExpression],
+ carried_variables: set[str],
+ second_declared_variables: set[str],
+ scalar_aliases: set[str],
+ ) -> Tuple[
+ List[Tuple[str, str, str, str]],
+ List[Tuple[str, str, str, str]],
+ List[Tuple[str, str, str, str]],
+ List[Tuple[str, str, str, str]],
+ List[Tuple[str, str, str]],
+ List[str],
+ List[CompareExpression],
+ ]:
+ correlations: List[Tuple[str, str, str, str]] = []
+ scalar_correlations: List[Tuple[str, str, str, str]] = []
+ expression_scalar_correlations: List[Tuple[str, str, str, str]] = []
+ stage_expression_correlations: List[Tuple[str, str, str, str]] = []
+ element_correlations: List[Tuple[str, str, str]] = []
+ stage_one_filters: List[str] = []
+ residual: List[CompareExpression] = []
+ for expression in where_expressions:
+ raw_expression = getattr(expression, "raw_expression", "")
+ if self._is_with_match_stage_one_filter(
+ raw_expression or "",
+ second_declared_variables,
+ scalar_aliases,
+ ):
+ stage_one_filters.append(raw_expression)
+ continue
+ element_match = re.fullmatch(
+ r"\s*(?P[A-Za-z_][A-Za-z0-9_]*)\s*"
+ r"(?P=|<>)\s*"
+ r"(?P[A-Za-z_][A-Za-z0-9_]*)\s*",
+ raw_expression or "",
+ )
+ if element_match:
+ left = element_match.group("left")
+ right = element_match.group("right")
+ operator = element_match.group("operator")
+ if left in second_declared_variables and right in carried_variables:
+ element_correlations.append((left, operator, right))
+ continue
+ if right in second_declared_variables and left in carried_variables:
+ element_correlations.append((right, operator, left))
+ continue
+ cast_expression_scalar = self._with_match_cast_expression_scalar_correlation(
+ raw_expression or "",
+ second_declared_variables,
+ scalar_aliases,
+ )
+ if cast_expression_scalar:
+ expression_scalar_correlations.append(cast_expression_scalar)
+ continue
+ cast_scalar = self._with_match_cast_scalar_correlation(
+ raw_expression or "",
+ second_declared_variables,
+ scalar_aliases,
+ )
+ if cast_scalar:
+ scalar_correlations.append(cast_scalar)
+ continue
+ expression_scalar = self._with_match_expression_scalar_correlation(
+ raw_expression or "",
+ second_declared_variables,
+ scalar_aliases,
+ )
+ if expression_scalar:
+ expression_scalar_correlations.append(expression_scalar)
+ continue
+ stage_expression = self._with_match_stage_expression_correlation(
+ raw_expression or "",
+ second_declared_variables,
+ scalar_aliases,
+ )
+ if stage_expression:
+ stage_expression_correlations.append(stage_expression)
+ continue
+ match = re.fullmatch(
+ r"\s*(?P[A-Za-z_][A-Za-z0-9_]*)\."
+ r"(?P[A-Za-z_][A-Za-z0-9_$#-]*)\s*"
+ r"(?P=|<>|<=|>=|<|>)\s*"
+ r"(?:(?P[A-Za-z_][A-Za-z0-9_]*)\."
+ r"(?P[A-Za-z_][A-Za-z0-9_$#-]*)|"
+ r"(?P[A-Za-z_][A-Za-z0-9_]*))\s*",
+ raw_expression or "",
+ )
+ if not match:
+ reverse_match = re.fullmatch(
+ r"\s*(?P[A-Za-z_][A-Za-z0-9_]*)\s*"
+ r"(?P=|<>|<=|>=|<|>)\s*"
+ r"(?P[A-Za-z_][A-Za-z0-9_]*)\."
+ r"(?P[A-Za-z_][A-Za-z0-9_$#-]*)\s*",
+ raw_expression or "",
+ )
+ if (
+ reverse_match
+ and reverse_match.group("left_alias") in scalar_aliases
+ and reverse_match.group("right_var") in second_declared_variables
+ ):
+ scalar_correlations.append(
+ (
+ reverse_match.group("right_var"),
+ reverse_match.group("right_prop"),
+ self._reverse_operator(reverse_match.group("operator")),
+ OracleNameSanitizer.alias(reverse_match.group("left_alias")),
+ )
+ )
+ else:
+ residual.append(expression)
+ continue
+ left_var = match.group("left_var")
+ left_prop = match.group("left_prop")
+ operator = match.group("operator")
+ right_var = match.group("right_var")
+ right_prop = match.group("right_prop")
+ right_alias = match.group("right_alias")
+ if (
+ left_var in second_declared_variables
+ and right_alias
+ and right_alias in scalar_aliases
+ ):
+ scalar_correlations.append(
+ (left_var, left_prop, operator, OracleNameSanitizer.alias(right_alias))
+ )
+ elif left_var in second_declared_variables and right_var in carried_variables:
+ if operator == "=":
+ correlations.append((left_var, left_prop, right_var, right_prop))
+ else:
+ scalar_correlations.append(
+ (
+ left_var,
+ left_prop,
+ operator,
+ self._with_property_stage_alias(right_var, right_prop),
+ )
+ )
+ elif right_var in second_declared_variables and left_var in carried_variables:
+ if operator == "=":
+ correlations.append((right_var, right_prop, left_var, left_prop))
+ else:
+ scalar_correlations.append(
+ (
+ right_var,
+ right_prop,
+ self._reverse_operator(operator),
+ self._with_property_stage_alias(left_var, left_prop),
+ )
+ )
+ else:
+ residual.append(expression)
+ return (
+ correlations,
+ scalar_correlations,
+ expression_scalar_correlations,
+ stage_expression_correlations,
+ element_correlations,
+ stage_one_filters,
+ residual,
+ )
+
+ def _is_with_match_stage_one_filter(
+ self,
+ expression: str,
+ second_declared_variables: set[str],
+ scalar_aliases: set[str],
+ ) -> bool:
+ if not expression:
+ return False
+ protected, _ = self._protect_string_literals(expression)
+ if not any(
+ re.search(rf"\b{re.escape(alias)}\b", protected) for alias in scalar_aliases if alias
+ ):
+ return False
+ for variable, _property_name in self._property_references(protected):
+ if variable in second_declared_variables:
+ return False
+ words = set(re.findall(r"\b[A-Za-z_][A-Za-z0-9_]*\b", protected))
+ ignored = {
+ "AND",
+ "OR",
+ "NOT",
+ "NULL",
+ "IS",
+ "TRUE",
+ "FALSE",
+ "TOFLOAT",
+ "TOINTEGER",
+ "TOSTRING",
+ "DATE",
+ "DATETIME",
+ "SIZE",
+ "SPLIT",
+ "REPLACE",
+ "LEFT",
+ "SUBSTRING",
+ }
+ return not any(word in second_declared_variables for word in words - ignored)
+
+ def _with_match_stage_one_filter_sql(
+ self,
+ expression: str,
+ stage_one_aliases: set[str],
+ ) -> str:
+ translated = self._translate_sql_expression(expression)
+ for alias in sorted(stage_one_aliases, key=len, reverse=True):
+ translated = re.sub(
+ rf"(? str:
+ translated = self._translate_sql_expression(expression)
+ for alias in sorted(stage_one_aliases, key=len, reverse=True):
+ translated = re.sub(
+ rf"(? Tuple[str, str, str, str] | None:
+ property_ref = (
+ r"(?P<{prefix}_var>[A-Za-z_][A-Za-z0-9_]*)\."
+ r"(?P<{prefix}_prop>[A-Za-z_][A-Za-z0-9_$#-]*)"
+ )
+ left_property = re.fullmatch(
+ property_ref.format(prefix="left")
+ + r"\s*(?P=|<>|<=|>=|<|>)\s*"
+ + r"(?P.+?)\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ left_property
+ and left_property.group("left_var") in second_declared_variables
+ and self._is_stage_one_scalar_expression(
+ left_property.group("right_expr"),
+ scalar_aliases,
+ )
+ ):
+ return (
+ left_property.group("left_var"),
+ left_property.group("left_prop"),
+ left_property.group("operator"),
+ left_property.group("right_expr").strip(),
+ )
+ right_property = re.fullmatch(
+ r"\s*(?P.+?)\s*"
+ + r"(?P=|<>|<=|>=|<|>)\s*"
+ + property_ref.format(prefix="right")
+ + r"\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ right_property
+ and right_property.group("right_var") in second_declared_variables
+ and self._is_stage_one_scalar_expression(
+ right_property.group("left_expr"),
+ scalar_aliases,
+ )
+ ):
+ return (
+ right_property.group("right_var"),
+ right_property.group("right_prop"),
+ self._reverse_operator(right_property.group("operator")),
+ right_property.group("left_expr").strip(),
+ )
+ return None
+
+ def _is_stage_one_scalar_expression(
+ self,
+ expression: str,
+ scalar_aliases: set[str],
+ ) -> bool:
+ if not expression:
+ return False
+ protected, _ = self._protect_string_literals(expression)
+ if not any(
+ re.search(rf"\b{re.escape(alias)}\b", protected) for alias in scalar_aliases if alias
+ ):
+ return False
+ stripped = protected
+ for alias in sorted(scalar_aliases, key=len, reverse=True):
+ stripped = re.sub(rf"\b{re.escape(alias)}\b", "", stripped)
+ return bool(re.fullmatch(r"[\s\d.+\-*/%()]+", stripped))
+
+ def _with_match_expression_scalar_correlation(
+ self,
+ expression: str,
+ second_declared_variables: set[str],
+ scalar_aliases: set[str],
+ ) -> Tuple[str, str, str, str] | None:
+ property_expr = (
+ r"size\s*\(\s*"
+ r"(?P<{prefix}_var>[A-Za-z_][A-Za-z0-9_]*)\."
+ r"(?P<{prefix}_prop>[A-Za-z_][A-Za-z0-9_$#-]*)"
+ r"\s*\)"
+ )
+ alias_ref = r"(?P<{prefix}_alias>[A-Za-z_][A-Za-z0-9_]*)"
+ left_property = re.fullmatch(
+ property_expr.format(prefix="left")
+ + r"\s*(?P=|<>|<=|>=|<|>)\s*"
+ + alias_ref.format(prefix="right")
+ + r"\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ left_property
+ and left_property.group("left_var") in second_declared_variables
+ and left_property.group("right_alias") in scalar_aliases
+ ):
+ variable = left_property.group("left_var")
+ property_name = left_property.group("left_prop")
+ return (
+ f"size({variable}.{property_name})",
+ self._with_expression_stage_alias(variable, property_name, "size"),
+ left_property.group("operator"),
+ OracleNameSanitizer.alias(left_property.group("right_alias")),
+ )
+ right_property = re.fullmatch(
+ alias_ref.format(prefix="left")
+ + r"\s*(?P=|<>|<=|>=|<|>)\s*"
+ + property_expr.format(prefix="right")
+ + r"\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ right_property
+ and right_property.group("right_var") in second_declared_variables
+ and right_property.group("left_alias") in scalar_aliases
+ ):
+ variable = right_property.group("right_var")
+ property_name = right_property.group("right_prop")
+ return (
+ f"size({variable}.{property_name})",
+ self._with_expression_stage_alias(variable, property_name, "size"),
+ self._reverse_operator(right_property.group("operator")),
+ OracleNameSanitizer.alias(right_property.group("left_alias")),
+ )
+ return None
+
def _with_expression_stage_alias(
    self,
    variable: str,
    property_name: str,
    function_name: str,
) -> str:
    """Build the sanitized alias ``<variable>_<property>_<function>``."""
    raw_alias = f"{variable}_{property_name}_{function_name}"
    return OracleNameSanitizer.alias(raw_alias)
+
+ def _with_match_cast_scalar_correlation(
+ self,
+ expression: str,
+ second_declared_variables: set[str],
+ scalar_aliases: set[str],
+ ) -> Tuple[str, str, str, str] | None:
+ property_ref = (
+ r"(?:(?:toFloat|toInteger)\s*\(\s*)?"
+ r"(?P<{prefix}_var>[A-Za-z_][A-Za-z0-9_]*)\."
+ r"(?P<{prefix}_prop>[A-Za-z_][A-Za-z0-9_$#-]*)"
+ r"\s*\)?"
+ )
+ alias_ref = r"(?P<{prefix}_alias>[A-Za-z_][A-Za-z0-9_]*)"
+ left_property = re.fullmatch(
+ property_ref.format(prefix="left")
+ + r"\s*(?P=|<>|<=|>=|<|>)\s*"
+ + alias_ref.format(prefix="right")
+ + r"\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ left_property
+ and left_property.group("left_var") in second_declared_variables
+ and left_property.group("right_alias") in scalar_aliases
+ ):
+ return (
+ left_property.group("left_var"),
+ left_property.group("left_prop"),
+ left_property.group("operator"),
+ OracleNameSanitizer.alias(left_property.group("right_alias")),
+ )
+ right_property = re.fullmatch(
+ alias_ref.format(prefix="left")
+ + r"\s*(?P=|<>|<=|>=|<|>)\s*"
+ + property_ref.format(prefix="right")
+ + r"\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ right_property
+ and right_property.group("right_var") in second_declared_variables
+ and right_property.group("left_alias") in scalar_aliases
+ ):
+ return (
+ right_property.group("right_var"),
+ right_property.group("right_prop"),
+ self._reverse_operator(right_property.group("operator")),
+ OracleNameSanitizer.alias(right_property.group("left_alias")),
+ )
+ return None
+
+ def _with_match_cast_expression_scalar_correlation(
+ self,
+ expression: str,
+ second_declared_variables: set[str],
+ scalar_aliases: set[str],
+ ) -> Tuple[str, str, str, str] | None:
+ property_ref = (
+ r"(?P<{prefix}_func>toFloat|toInteger)\s*\(\s*"
+ r"(?P<{prefix}_var>[A-Za-z_][A-Za-z0-9_]*)\."
+ r"(?P<{prefix}_prop>[A-Za-z_][A-Za-z0-9_$#-]*)"
+ r"\s*\)"
+ )
+ alias_ref = r"(?P<{prefix}_alias>[A-Za-z_][A-Za-z0-9_]*)"
+ left_property = re.fullmatch(
+ property_ref.format(prefix="left")
+ + r"\s*(?P=|<>|<=|>=|<|>)\s*"
+ + alias_ref.format(prefix="right")
+ + r"\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ left_property
+ and left_property.group("left_var") in second_declared_variables
+ and left_property.group("right_alias") in scalar_aliases
+ ):
+ variable = left_property.group("left_var")
+ property_name = left_property.group("left_prop")
+ function_name = left_property.group("left_func")
+ return (
+ f"{function_name}({variable}.{property_name})",
+ self._with_expression_stage_alias(variable, property_name, function_name),
+ left_property.group("operator"),
+ OracleNameSanitizer.alias(left_property.group("right_alias")),
+ )
+ right_property = re.fullmatch(
+ alias_ref.format(prefix="left")
+ + r"\s*(?P=|<>|<=|>=|<|>)\s*"
+ + property_ref.format(prefix="right")
+ + r"\s*",
+ expression,
+ flags=re.IGNORECASE,
+ )
+ if (
+ right_property
+ and right_property.group("right_var") in second_declared_variables
+ and right_property.group("left_alias") in scalar_aliases
+ ):
+ variable = right_property.group("right_var")
+ property_name = right_property.group("right_prop")
+ function_name = right_property.group("right_func")
+ return (
+ f"{function_name}({variable}.{property_name})",
+ self._with_expression_stage_alias(variable, property_name, function_name),
+ self._reverse_operator(right_property.group("operator")),
+ OracleNameSanitizer.alias(right_property.group("left_alias")),
+ )
+ return None
+
+ def _extract_with_match_property_map_correlations(
+ self,
+ match_clauses: List[MatchClause],
+ second_declared_variables: set[str],
+ scalar_aliases: set[str],
+ ) -> List[Tuple[str, str, str, str]]:
+ correlations: List[Tuple[str, str, str, str]] = []
+ for clause in match_clauses:
+ patterns = (
+ clause.path_pattern
+ if isinstance(clause.path_pattern, list)
+ else [clause.path_pattern]
+ )
+ for path in patterns:
+ for node in path.node_pattern_list:
+ correlations.extend(
+ self._extract_property_map_scalar_correlations(
+ node.symbolic_name,
+ node.property_maps,
+ second_declared_variables,
+ scalar_aliases,
+ )
+ )
+ for edge in path.edge_pattern_list:
+ correlations.extend(
+ self._extract_property_map_scalar_correlations(
+ edge.symbolic_name,
+ edge.property_maps,
+ second_declared_variables,
+ scalar_aliases,
+ )
+ )
+ return correlations
+
+ def _assign_with_match_property_map_correlation_variables(
+ self,
+ match_clauses: List[MatchClause],
+ scalar_aliases: set[str],
+ ) -> None:
+ if not scalar_aliases:
+ return
+ used_variables = self._declared_variables_in_match_clauses(match_clauses)
+ node_index = 0
+ edge_index = 0
+
+ def has_scalar_map(property_maps: List[Tuple[str, str]]) -> bool:
+ return any(
+ OracleNameSanitizer.alias(str(value or "").strip()) in scalar_aliases
+ for _property_name, value in property_maps
+ )
+
+ def next_variable(prefix: str, current_index: int) -> tuple[str, int]:
+ while True:
+ current_index += 1
+ variable = f"{prefix}{current_index}"
+ if variable not in used_variables:
+ used_variables.add(variable)
+ return variable, current_index
+
+ for clause in match_clauses:
+ patterns = (
+ clause.path_pattern
+ if isinstance(clause.path_pattern, list)
+ else [clause.path_pattern]
+ )
+ for path in patterns:
+ for node in path.node_pattern_list:
+ if not node.symbolic_name and has_scalar_map(node.property_maps):
+ node.symbolic_name, node_index = next_variable("with_corr_n", node_index)
+ for edge in path.edge_pattern_list:
+ if not edge.symbolic_name and has_scalar_map(edge.property_maps):
+ edge.symbolic_name, edge_index = next_variable("with_corr_e", edge_index)
+
+ def _extract_property_map_scalar_correlations(
+ self,
+ variable: str,
+ property_maps: List[Tuple[str, str]],
+ second_declared_variables: set[str],
+ scalar_aliases: set[str],
+ ) -> List[Tuple[str, str, str, str]]:
+ if variable not in second_declared_variables:
+ return []
+ correlations: List[Tuple[str, str, str, str]] = []
+ retained: List[Tuple[str, str]] = []
+ for property_name, property_value in property_maps:
+ alias = OracleNameSanitizer.alias(str(property_value or "").strip())
+ if alias in scalar_aliases:
+ correlations.append((variable, property_name, "=", alias))
+ else:
+ retained.append((property_name, property_value))
+ property_maps[:] = retained
+ return correlations
+
+ def _reverse_operator(self, operator: str) -> str:
+ return {
+ "<": ">",
+ ">": "<",
+ "<=": ">=",
+ ">=": "<=",
+ }.get(operator, operator)
+
+ def _declared_variables_in_match_clauses(
+ self,
+ match_clauses: List[MatchClause],
+ ) -> set[str]:
+ variables = set()
+ for clause in match_clauses:
+ paths = (
+ clause.path_pattern
+ if isinstance(clause.path_pattern, list)
+ else [clause.path_pattern]
+ )
+ for path in paths:
+ for node in path.node_pattern_list:
+ if node.symbolic_name:
+ variables.add(node.symbolic_name)
+ for edge in path.edge_pattern_list:
+ if edge.symbolic_name:
+ variables.add(edge.symbolic_name)
+ return variables
+
def _stage_one_join_alias(self, variable: str) -> str:
    """Return the sanitized ``stage_1_``-prefixed element alias of ``variable``."""
    element_alias = self._element_projection_alias(variable)
    return OracleNameSanitizer.alias(f"stage_1_{element_alias}")
+
def _with_stage_return_body(
    self,
    with_body: ReturnBody,
    final_body: ReturnBody,
    with_clause: WithClause,
) -> ReturnBody:
    """Extend the WITH projection with the properties later stages read.

    Two groups of items are appended to a copy of the WITH items, skipping
    any whose alias is already present:

    1. properties of carried variables needed by the final RETURN or by the
       WITH ... WHERE predicates;
    2. properties read through a WITH alias of a carried graph variable
       (``WITH n AS m ... RETURN m.name``); these are projected from the
       source variable under the alias-based stage alias.

    Sort, skip and limit are reused from ``with_body``.
    """
    carried = self._with_carried_variables(with_body)
    augmented = list(with_body.return_item_list)
    taken = {
        OracleNameSanitizer.alias(self._return_alias(entry, self._return_expression(entry)))
        for entry in augmented
    }

    def project(source: str, property_name: str, stage_alias: str) -> None:
        # Append one property projection unless its alias already exists.
        if stage_alias in taken:
            return
        augmented.append(
            ReturnItem(
                symbolic_name=source,
                property=property_name,
                alias=stage_alias,
                function_name="",
                expression=f"{source}.{property_name}",
            )
        )
        taken.add(stage_alias)

    for variable, property_name in self._with_stage_needed_properties(final_body, with_clause):
        if variable in carried:
            project(
                variable,
                property_name,
                self._with_property_stage_alias(variable, property_name),
            )

    graph_aliases = self._with_graph_aliases(with_body)
    alias_references: List[Tuple[str, str]] = []
    for entry in [*final_body.return_item_list, *final_body.sort_item_list]:
        if entry.property:
            alias_references.append((entry.symbolic_name, entry.property))
        alias_references.extend(self._all_property_references(entry.expression))

    for alias_variable, raw_property in dict.fromkeys(alias_references):
        source_variable = graph_aliases.get(alias_variable)
        if not source_variable or source_variable not in carried:
            continue
        canonical = self._canonical_property_name(source_variable, raw_property)
        project(
            source_variable,
            canonical,
            self._with_property_stage_alias(alias_variable, canonical),
        )

    return ReturnBody(
        augmented,
        with_body.sort_item_list,
        with_body.skip,
        with_body.limit,
    )
+
+ def _with_graph_aliases(self, with_body: ReturnBody) -> Dict[str, str]:
+ aliases: Dict[str, str] = {}
+ for item in with_body.return_item_list:
+ if (
+ item.alias
+ and not item.property
+ and not item.function_name
+ and (not item.expression or item.expression == item.symbolic_name)
+ and item.symbolic_name in self._var_kinds
+ ):
+ aliases[item.alias] = item.symbolic_name
+ return aliases
+
def _with_stage_needed_properties(
    self,
    final_body: ReturnBody,
    with_clause: WithClause,
) -> List[Tuple[str, str]]:
    """List the distinct (variable, property) pairs read after the WITH.

    Sources are the final RETURN body and the WITH ... WHERE predicates
    (parsed from ``raw_expression`` when present, else taken from the
    predicate's own variable/property). First-seen order is preserved.
    """
    needed = self._final_property_references(final_body)
    for condition in self._as_compare_list(with_clause.compare_expression_list):
        raw_text = getattr(condition, "raw_expression", "")
        if raw_text:
            needed.extend(self._property_references(raw_text))
        elif condition.property:
            needed.append((condition.symbolic_name, condition.property))
    pairs = ((variable, property_name) for variable, property_name in needed)
    return list(dict.fromkeys(pairs))
+
+ def _with_carried_variables(self, with_body: ReturnBody) -> set[str]:
+ carried = set()
+ for item in with_body.return_item_list:
+ if (
+ not item.property
+ and not item.function_name
+ and (not item.expression or item.expression == item.symbolic_name)
+ and item.symbolic_name in self._var_kinds
+ ):
+ carried.add(item.symbolic_name)
+ return carried
+
+ def _with_passthrough_variable_names(self, with_body: ReturnBody) -> set[str]:
+ carried = set()
+ for item in with_body.return_item_list:
+ if (
+ not item.property
+ and not item.function_name
+ and (not item.expression or item.expression == item.symbolic_name)
+ and item.symbolic_name
+ ):
+ carried.add(item.symbolic_name)
+ return carried
+
+ def _with_scalar_aliases(self, with_body: ReturnBody) -> set[str]:
+ aliases = set()
+ for item in with_body.return_item_list:
+ if (
+ not item.property
+ and not item.function_name
+ and (
+ not item.expression
+ or (
+ item.expression == item.symbolic_name
+ and re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", item.symbolic_name or "")
+ )
+ )
+ ):
+ continue
+ alias = item.alias or item.property or item.symbolic_name
+ if alias:
+ aliases.add(OracleNameSanitizer.alias(alias))
+ aliases.add(alias)
+ return aliases
+
+ def _final_property_references(self, return_body: ReturnBody) -> List[Tuple[str, str]]:
+ references: List[Tuple[str, str]] = []
+ for item in return_body.return_item_list:
+ if item.property:
+ references.append((item.symbolic_name, item.property))
+ references.extend(self._property_references(item.expression))
+ for sort_item in return_body.sort_item_list:
+ if sort_item.property:
+ references.append((sort_item.symbolic_name, sort_item.property))
+ references.extend(self._property_references(sort_item.expression))
+ unique: List[Tuple[str, str]] = []
+ for variable, property_name in references:
+ if variable in self._var_kinds and (variable, property_name) not in unique:
+ unique.append((variable, property_name))
+ return unique
+
def _with_property_stage_alias(self, variable: str, property_name: str) -> str:
    """Return the stage alias ``<variable>_<property>`` for a property.

    A ``_PROP`` suffix is added when the plain alias would equal the
    variable's element projection alias, compared case-insensitively.
    """
    plain = OracleNameSanitizer.alias(f"{variable}_{property_name}")
    element_alias = self._element_projection_alias(variable)
    if plain.upper() != element_alias.upper():
        return plain
    return OracleNameSanitizer.alias(f"{variable}_{property_name}_PROP")
+
def _outer_select_for_with_stage(
    self,
    return_body: ReturnBody,
    distinct: bool,
    carried_variables: set[str],
) -> str:
    """Build the outer ``SELECT [DISTINCT] ...`` list over a WITH stage.

    Each return item becomes ``<expression> AS <sanitized resolved alias>``,
    where the expression is phrased in terms of the columns the WITH stage
    exposes.
    """
    plain_identifier = r"[A-Za-z_][A-Za-z0-9_]*"

    def render(entry: ReturnItem) -> str:
        if self._is_complex_aggregate_item(entry):
            return self._with_stage_aggregate_sql_expression(
                entry.expression,
                carried_variables,
            )
        if self._is_aggregate_item(entry):
            return self._with_stage_aggregate_item_expression(entry, carried_variables)
        if entry.property:
            # Same alias whether or not the variable is carried.
            return self._with_property_stage_alias(entry.symbolic_name, entry.property)
        if entry.symbolic_name in carried_variables:
            return self._element_projection_alias(entry.symbolic_name)
        text = entry.expression
        if text and (
            text != entry.symbolic_name or not re.fullmatch(plain_identifier, text)
        ):
            return self._with_stage_expression(
                self._translate_sql_expression(text),
                carried_variables,
            )
        plain = OracleNameSanitizer.alias(entry.symbolic_name)
        if entry.function_name:
            return f"{entry.function_name.upper()}({plain})"
        return plain

    resolved_aliases = self._resolved_return_aliases(return_body)
    columns = []
    for entry, resolved in zip(return_body.return_item_list, resolved_aliases, strict=True):
        columns.append(f"{render(entry)} AS {OracleNameSanitizer.alias(resolved)}")
    keyword = "SELECT DISTINCT" if distinct else "SELECT"
    return f"{keyword} " + ", ".join(columns)
+
+ def _with_stage_aggregate_item_expression(
+ self,
+ item: ReturnItem | SortItem,
+ carried_variables: set[str],
+ ) -> str:
+ function_name = item.function_name.upper()
+ if function_name == "COUNT" and item.symbolic_name == "*":
+ return "COUNT(*)"
+ self._reject_temporal_numeric_aggregate(item)
+ symbolic_name = self._strip_distinct_prefix(item.symbolic_name)
+ expression = self._aggregate_argument_expression_text(
+ item,
+ symbolic_name,
+ )
+ expression = self._translate_sql_expression(expression)
+ expression = self._with_stage_expression(expression, carried_variables)
+ if expression in carried_variables:
+ expression = self._element_projection_alias(expression)
+ distinct = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+ return self._aggregate_sql_call(function_name, distinct, expression)
+
def _with_stage_aggregate_sql_expression(
    self,
    expression: str,
    carried_variables: set[str],
) -> str:
    """Translate a complex aggregate expression for use over a WITH stage.

    Applies, in order: SQL translation, the stage-alias rewrite, and
    ``_coalesce_sum_calls``.
    """
    sql_text = self._translate_sql_expression(expression)
    staged = self._with_stage_expression(sql_text, carried_variables)
    return self._coalesce_sum_calls(staged)
+
+    def _with_stage_expression(self, expression: str, carried_variables: set[str]) -> str:
+        """Rewrite references to carried variables so they use WITH-stage aliases.
+
+        Three groups of rewrites run in order: property references
+        (``variable.property``), element-level calls over a bare variable
+        (``VERTEX_ID``/``EDGE_ID`` and ``COUNT``/``MIN``/``MAX``), and
+        ``VERTEX_EQUAL``/``EDGE_EQUAL`` comparisons between carried variables.
+        """
+        for variable, raw_property in self._property_references(expression):
+            if variable in carried_variables:
+                canonical = self._canonical_property_name(variable, raw_property)
+                sql_name = self._var_sql_names.get(variable, variable)
+                quoted = OracleNameSanitizer.quote(canonical, fallback="PROP")
+                # First the already-translated spelling (SQL name + quoted property) ...
+                expression = expression.replace(
+                    f"{sql_name}.{quoted}",
+                    self._with_property_stage_alias(variable, canonical),
+                )
+                # ... then any remaining source-level ``variable.property`` spelling.
+                expression = re.sub(
+                    rf"\b{re.escape(variable)}\.{re.escape(canonical)}\b",
+                    self._with_property_stage_alias(variable, canonical),
+                    expression,
+                )
+
+        for variable in carried_variables:
+            alias = self._element_projection_alias(variable)
+            name = re.escape(variable)
+            element_rewrites = (
+                (rf"\b(VERTEX_ID|EDGE_ID)\s*\(\s*{name}\s*\)", alias),
+                (rf"\b(COUNT)\s*\(\s*DISTINCT\s+{name}\s*\)", rf"\1(DISTINCT {alias})"),
+                (rf"\b(COUNT|MIN|MAX)\s*\(\s*{name}\s*\)", rf"\1({alias})"),
+            )
+            for pattern, replacement in element_rewrites:
+                expression = re.sub(pattern, replacement, expression, flags=re.IGNORECASE)
+
+        # Longest names first, matching the original iteration order.
+        by_length = sorted(carried_variables, key=len, reverse=True)
+        for left in by_length:
+            for right in by_length:
+                left_alias = self._element_projection_alias(left)
+                right_alias = self._element_projection_alias(right)
+                equal = f"{left_alias} = {right_alias}"
+                not_equal = f"{left_alias} <> {right_alias}"
+                arguments = rf"\s*\(\s*{re.escape(left)}\s*,\s*{re.escape(right)}\s*\)"
+                # Negated forms are rewritten before the plain ones so that the
+                # leading NOT is consumed together with the call.
+                comparison_rewrites = (
+                    (r"\bNOT\s+VERTEX_EQUAL", not_equal),
+                    (r"\bNOT\s+EDGE_EQUAL", not_equal),
+                    (r"\bVERTEX_EQUAL", equal),
+                    (r"\bEDGE_EQUAL", equal),
+                )
+                for prefix, replacement in comparison_rewrites:
+                    expression = re.sub(
+                        prefix + arguments,
+                        replacement,
+                        expression,
+                        flags=re.IGNORECASE,
+                    )
+        return expression
+
+    def _with_filters(self, with_clause: WithClause, carried_variables: set[str]) -> List[str]:
+        """Translate the WHERE filters attached to a WITH clause.
+
+        Filters that carry a ``raw_expression`` are translated as SQL text and
+        rewritten for the WITH stage; the others go through
+        ``_translate_compare_expression``.
+
+        Raises:
+            ValueError: If any raw expression calls AVG, SUM, COUNT, MIN or MAX.
+        """
+        aggregate_call = re.compile(
+            r"\b(?:AVG|SUM|COUNT|MIN|MAX)\s*\(",
+            flags=re.IGNORECASE,
+        )
+        filters = self._as_compare_list(with_clause.compare_expression_list)
+
+        # Validate every filter before translating any of them.
+        for candidate in filters:
+            raw_text = getattr(candidate, "raw_expression", "")
+            if raw_text and aggregate_call.search(raw_text):
+                raise ValueError("WITH WHERE aggregate filters must be projected first.")
+
+        translated: List[str] = []
+        for candidate in filters:
+            if getattr(candidate, "raw_expression", ""):
+                sql = self._translate_sql_expression(candidate.raw_expression)
+                translated.append(self._with_stage_expression(sql, carried_variables))
+            else:
+                translated.append(self._translate_compare_expression(candidate))
+        return translated
+
+    def _outer_group_order_and_paging_for_with_stage(
+        self,
+        return_body: ReturnBody,
+        carried_variables: set[str],
+    ) -> str:
+        """Build the GROUP BY / ORDER BY / OFFSET / FETCH suffix of a staged query.
+
+        Every emitted clause starts with a newline; an empty string is returned
+        when no clause applies. A ``skip`` or ``limit`` of ``-1`` emits nothing.
+        """
+        clauses: List[str] = []
+        if self._has_aggregate(return_body):
+            grouping: List[str] = []
+            for item in return_body.return_item_list:
+                if self._is_aggregate_item(item):
+                    continue
+                name = item.symbolic_name
+                if name not in carried_variables:
+                    grouping.append(OracleNameSanitizer.alias(name))
+                elif item.property:
+                    grouping.append(self._with_property_stage_alias(name, item.property))
+                else:
+                    grouping.append(self._element_projection_alias(name))
+            if grouping:
+                # dict.fromkeys removes duplicates while keeping first-seen order.
+                clauses.append("\nGROUP BY " + ", ".join(dict.fromkeys(grouping)))
+        if return_body.sort_item_list:
+            order_terms = (
+                self._translate_sort_item_for_with_stage(
+                    sort_item,
+                    return_body,
+                    carried_variables,
+                )
+                for sort_item in return_body.sort_item_list
+            )
+            clauses.append("\nORDER BY " + ", ".join(order_terms))
+        if return_body.skip != -1:
+            clauses.append(f"\nOFFSET {return_body.skip} ROWS")
+        if return_body.limit != -1:
+            clauses.append(f"\nFETCH FIRST {return_body.limit} ROWS ONLY")
+        return "".join(clauses)
+
+    def _indent_sql(self, sql: str) -> str:
+        """Prefix every line of ``sql`` with the indent string and re-join with newlines."""
+        indented_lines = [f" {line}" for line in sql.splitlines()]
+        return "\n".join(indented_lines)
+
+    def _translate_with_projection_columns(self, with_clause: WithClause) -> str:
+        """Render the comma-separated projection list of a WITH clause.
+
+        Raises:
+            ValueError: If an item has no alias, if an item uses a function, or
+                if the clause projects nothing.
+        """
+        columns: List[str] = []
+        for item in with_clause.return_body.return_item_list:
+            if not item.alias:
+                raise ValueError("WITH projection requires aliases for Oracle SQL output.")
+            if item.function_name:
+                raise ValueError("Aggregate WITH projection requires staged SQL CTE support.")
+            expression = self._return_expression(item)
+            columns.append(f"{expression} AS {OracleNameSanitizer.alias(item.alias)}")
+        if columns:
+            return ", ".join(columns)
+        raise ValueError("WITH projection cannot be empty.")
+
+    def _outer_select_for_with(self, return_body: ReturnBody, distinct: bool) -> str:
+        """Build the outer SELECT list that reads from a WITH projection.
+
+        Each return item is referenced by the sanitized alias of its symbolic
+        name; items with a function name are wrapped in that aggregate call.
+        """
+        resolved_aliases = self._resolved_return_aliases(return_body)
+        columns = []
+        for item, resolved_alias in zip(
+            return_body.return_item_list,
+            resolved_aliases,
+            strict=True,
+        ):
+            column = OracleNameSanitizer.alias(item.symbolic_name)
+            if item.function_name:
+                column = self._aggregate_sql_call(item.function_name.upper(), "", column)
+            columns.append(f"{column} AS {OracleNameSanitizer.alias(resolved_alias)}")
+        select_keyword = "SELECT DISTINCT" if distinct else "SELECT"
+        return f"{select_keyword} " + ", ".join(columns)
+
+    def _outer_order_and_paging_for_with(
+        self,
+        return_body: ReturnBody,
+        include_group_by: bool = True,
+    ) -> str:
+        """Build the GROUP BY / ORDER BY / OFFSET / FETCH suffix for a WITH query.
+
+        GROUP BY is only considered when ``include_group_by`` is true and the
+        return body contains an aggregate. Each clause starts with a newline.
+        """
+        clauses: List[str] = []
+        if include_group_by and self._has_aggregate(return_body):
+            resolved_aliases = self._resolved_return_aliases(return_body)
+            grouping: List[str] = []
+            for item, alias in zip(
+                return_body.return_item_list,
+                resolved_aliases,
+                strict=True,
+            ):
+                if not self._is_aggregate_item(item):
+                    grouping.append(OracleNameSanitizer.alias(alias))
+            if grouping:
+                # dict.fromkeys removes duplicates while keeping first-seen order.
+                clauses.append("\nGROUP BY " + ", ".join(dict.fromkeys(grouping)))
+        if return_body.sort_item_list:
+            order_terms = (
+                self._translate_sort_item(sort_item, return_body)
+                for sort_item in return_body.sort_item_list
+            )
+            clauses.append("\nORDER BY " + ", ".join(order_terms))
+        if return_body.skip != -1:
+            clauses.append(f"\nOFFSET {return_body.skip} ROWS")
+        if return_body.limit != -1:
+            clauses.append(f"\nFETCH FIRST {return_body.limit} ROWS ONLY")
+        return "".join(clauses)
+
+    def _reset(self) -> None:
+        """Clear all per-translation state so the instance can be reused."""
+        # Counters used when generating names for anonymous nodes and edges.
+        self._auto_node_index = 0
+        self._auto_edge_index = 0
+        # Per-variable bookkeeping.
+        self._var_kinds = {}
+        self._var_sql_names = {}
+        self._var_labels = {}
+        self._var_property_redirects = {}
+        self._reserved_variables = set()
+        # Path-variable registries.
+        self._path_variables = {}
+        self._path_variable_has_quantifier = {}
+        self._path_variable_quantified_edges = {}
+        # WHERE fragments collected from inline property maps.
+        self._pattern_where_expressions = []
+
+    def _as_compare_list(self, value) -> List[CompareExpression]:
+        """Normalize ``value`` to a list of ``CompareExpression`` objects.
+
+        A list is filtered down to its ``CompareExpression`` members, a single
+        ``CompareExpression`` is wrapped in a list, and anything else
+        (including ``None``) yields an empty list.
+        """
+        if isinstance(value, list):
+            return [entry for entry in value if isinstance(entry, CompareExpression)]
+        return [value] if isinstance(value, CompareExpression) else []
+
+    def _translate_match_clause(self, match_clause: MatchClause) -> str:
+        """Translate a MATCH clause; several path patterns are joined with ``", "``."""
+        pattern = match_clause.path_pattern
+        patterns = pattern if isinstance(pattern, list) else [pattern]
+        return ", ".join(self._translate_path_pattern(entry) for entry in patterns)
+
+    def _translate_path_pattern(self, path_pattern: PathPattern) -> str:
+        """Translate one path pattern into its node/edge pattern text.
+
+        Named elements are added to the reserved-variable set first. A path
+        whose edges all point left is delegated to
+        ``_translate_reversed_left_path``. Otherwise nodes and edges are
+        emitted in order, after which label inference, edge-property redirects
+        and the path variable are registered.
+
+        Raises:
+            ValueError: If the pattern has no node patterns.
+        """
+        nodes = path_pattern.node_pattern_list
+        edges = path_pattern.edge_pattern_list
+        self._reserved_variables.update(node.symbolic_name for node in nodes if node.symbolic_name)
+        self._reserved_variables.update(edge.symbolic_name for edge in edges if edge.symbolic_name)
+
+        if edges and all(edge.direction == "left" for edge in edges):
+            return self._translate_reversed_left_path(path_pattern)
+        if not nodes:
+            raise ValueError("PathPattern must include at least one node pattern.")
+
+        fragments: List[str] = [self._translate_node_pattern(nodes[0])]
+        for position, edge in enumerate(edges, start=1):
+            fragments.append(self._translate_edge_pattern(edge))
+            fragments.append(self._translate_node_pattern(nodes[position]))
+
+        self._infer_node_labels_from_edge_labels(path_pattern)
+        self._register_adjacent_edge_property_redirects(path_pattern)
+        self._register_path_variable(path_pattern)
+        return "".join(fragments)
+
+    def _translate_reversed_left_path(self, path_pattern: PathPattern) -> str:
+        """Translate an all-left-directed path by emitting it reversed, pointing right.
+
+        Nodes and edges are walked in reverse order and every edge is re-created
+        with direction ``"right"``. Label inference and edge-property redirects
+        use the reversed pattern; the path variable is registered from the
+        pattern as written in the query.
+
+        Raises:
+            ValueError: If the pattern has no node patterns (same check and
+                message as ``_translate_path_pattern``).
+        """
+        if not path_pattern.node_pattern_list:
+            raise ValueError("PathPattern must include at least one node pattern.")
+        reversed_nodes = list(reversed(path_pattern.node_pattern_list))
+        reversed_edges = list(reversed(path_pattern.edge_pattern_list))
+        parts: List[str] = [self._translate_node_pattern(reversed_nodes[0])]
+        translated_edges: List[EdgePattern] = []
+        for index, edge in enumerate(reversed_edges):
+            translated_edge = EdgePattern(
+                edge.symbolic_name,
+                edge.label,
+                edge.property_maps,
+                "right",
+                edge.hop_range,
+            )
+            parts.append(self._translate_edge_pattern(translated_edge))
+            # _translate_edge_pattern assigns a generated name to an anonymous edge,
+            # but here it only sees the copy. Mirror the name onto the source edge so
+            # _register_path_variable(path_pattern) records the edge, exactly as it
+            # does for non-reversed paths where the source edge is named in place.
+            edge.symbolic_name = translated_edge.symbolic_name
+            translated_edges.append(translated_edge)
+            parts.append(self._translate_node_pattern(reversed_nodes[index + 1]))
+        translated_path = PathPattern(reversed_nodes, translated_edges, path_pattern.path_variable)
+        self._infer_node_labels_from_edge_labels(translated_path)
+        self._register_adjacent_edge_property_redirects(translated_path)
+        # Registered from the original pattern so elements keep the query's order.
+        self._register_path_variable(path_pattern)
+        return "".join(parts)
+
+    def _infer_node_labels_from_edge_labels(self, path_pattern: PathPattern) -> None:
+        """Infer labels for unlabeled nodes from the endpoints of adjacent edge labels.
+
+        Only directed, labeled edges whose endpoint labels can be determined
+        contribute. A node receives a label only when all contributing edges
+        agree on exactly one candidate.
+        """
+        if not self.edge_label_map:
+            return
+        nodes = path_pattern.node_pattern_list
+        candidates: Dict[str, set[str]] = {}
+        for position, edge in enumerate(path_pattern.edge_pattern_list):
+            if not (edge.direction in {"left", "right"} and edge.label):
+                continue
+            endpoints = self._edge_endpoint_labels(edge.label)
+            if not endpoints:
+                continue
+            source_label, target_label = endpoints
+            if edge.direction == "right":
+                source_node, target_node = nodes[position], nodes[position + 1]
+            else:
+                source_node, target_node = nodes[position + 1], nodes[position]
+            for node, graph_label in ((source_node, source_label), (target_node, target_label)):
+                name = node.symbolic_name
+                # Nodes that already have a label are left untouched.
+                if name and not self._var_labels.get(name):
+                    candidates.setdefault(name, set()).add(
+                        self._source_label_for_graph_label(graph_label)
+                    )
+        for name, labels in candidates.items():
+            if len(labels) == 1:
+                (self._var_labels[name],) = labels
+
+    def _source_label_for_graph_label(self, graph_label: str) -> str:
+        """Map a graph label back to the one source label that targets it.
+
+        Source labels that differ from ``graph_label`` only by case are ignored,
+        and candidates are de-duplicated case-insensitively. ``graph_label``
+        itself is returned unless exactly one candidate remains.
+        """
+        wanted = str(graph_label).lower()
+        matches = {}
+        for source, targets in self.node_label_map.items():
+            lowered = str(source).lower()
+            if graph_label in targets and lowered != wanted:
+                matches[lowered] = source
+        if len(matches) != 1:
+            return graph_label
+        (source_label,) = matches.values()
+        return source_label
+
+    def _edge_endpoint_labels(self, edge_label: str) -> Tuple[str, str] | None:
+        """Return the (source, destination) labels encoded in an edge's graph label.
+
+        Graph labels are split on the first ``_<relation>_`` occurrence, where
+        ``<relation>`` is the cleaned edge label. ``None`` is returned unless
+        exactly one distinct endpoint pair is found.
+        """
+        separator = f"_{OracleNameSanitizer.clean(edge_label)}_"
+        pairs: set[Tuple[str, str]] = set()
+        for graph_label in self._label_map_targets(edge_label, self.edge_label_map):
+            source_label, found, target_label = graph_label.partition(separator)
+            if found and source_label and target_label:
+                pairs.add((source_label, target_label))
+        return next(iter(pairs)) if len(pairs) == 1 else None
+
+    def _register_adjacent_edge_property_redirects(self, path_pattern: PathPattern) -> None:
+        """Register node properties that should be read from an adjacent edge instead.
+
+        For every named edge with known properties, each property that has no
+        type on a neighbouring named node becomes a redirect candidate for
+        that node. A redirect is stored only when exactly one edge qualifies
+        for a given (node, lower-cased property) key.
+        """
+        if not self.property_type_map:
+            return
+        nodes = path_pattern.node_pattern_list
+        owners: Dict[Tuple[str, str], set[str]] = {}
+        for position, edge in enumerate(path_pattern.edge_pattern_list):
+            edge_name = edge.symbolic_name
+            if not edge_name:
+                continue
+            edge_properties = set()
+            for graph_label in self._possible_graph_labels(edge_name):
+                edge_properties.update(self.property_type_map.get(graph_label, {}))
+            if not edge_properties:
+                continue
+            for node in (nodes[position], nodes[position + 1]):
+                node_name = node.symbolic_name
+                if not node_name:
+                    continue
+                for property_name in edge_properties:
+                    # Skip properties the node itself already defines.
+                    if not self._property_type_on_variable(node_name, property_name):
+                        owners.setdefault((node_name, property_name.lower()), set()).add(
+                            edge_name
+                        )
+        for key, edge_names in owners.items():
+            if len(edge_names) == 1:
+                (self._var_property_redirects[key],) = edge_names
+
+ def _register_path_variable(self, path_pattern: PathPattern) -> None:
+ """Record the named elements of a path pattern under its path variable.
+
+ Patterns without a path variable are ignored. Three registries are filled,
+ all keyed by the path variable: the ordered ("node" | "edge", name) pairs,
+ a flag telling whether any edge carries a hop range, and the names of the
+ named edges that do.
+ """
+ if not path_pattern.path_variable:
+ return
+ elements: List[Tuple[str, str]] = []
+ for index, node in enumerate(path_pattern.node_pattern_list):
+ if node.symbolic_name:
+ elements.append(("node", node.symbolic_name))
+ # Interleave the edge that follows this node, when there is one.
+ if index < len(path_pattern.edge_pattern_list):
+ edge = path_pattern.edge_pattern_list[index]
+ if edge.symbolic_name:
+ elements.append(("edge", edge.symbolic_name))
+ self._path_variables[path_pattern.path_variable] = elements
+ # A hop range other than (-1, -1) marks a quantified edge.
+ self._path_variable_has_quantifier[path_pattern.path_variable] = any(
+ edge.hop_range != (-1, -1) for edge in path_pattern.edge_pattern_list
+ )
+ self._path_variable_quantified_edges[path_pattern.path_variable] = {
+ edge.symbolic_name
+ for edge in path_pattern.edge_pattern_list
+ if edge.symbolic_name and edge.hop_range != (-1, -1)
+ }
+
+    def _translate_node_pattern(self, node_pattern: NodePattern) -> str:
+        """Translate a node pattern to ``(var)`` or ``(var IS label)``.
+
+        An anonymous node is given a generated name, which is written back to
+        ``node_pattern``. Inline property maps are turned into WHERE fragments
+        and collected in ``_pattern_where_expressions``.
+        """
+        name = node_pattern.symbolic_name or self._next_node_var()
+        node_pattern.symbolic_name = name
+        sql_name = self._declare_variable(name, "node", node_pattern.label)
+        pieces = [sql_name]
+        if node_pattern.label:
+            label_sql = self._label_expression(node_pattern.label, self.node_label_map, "NODE")
+            pieces.append(f" IS {label_sql}")
+        condition = self._property_maps_to_where(sql_name, node_pattern.property_maps)
+        if condition:
+            self._pattern_where_expressions.append(condition)
+        return "(" + "".join(pieces) + ")"
+
+    def _translate_edge_pattern(self, edge_pattern: EdgePattern) -> str:
+        """Translate an edge pattern, including direction arrows and hop quantifier.
+
+        An anonymous edge is given a generated name, which is written back to
+        ``edge_pattern``. Inline property maps are collected as WHERE fragments.
+
+        Raises:
+            ValueError: If a quantified edge also carries property maps.
+        """
+        name = edge_pattern.symbolic_name or self._next_edge_var()
+        edge_pattern.symbolic_name = name
+        quantified = edge_pattern.hop_range != (-1, -1)
+        if quantified and edge_pattern.property_maps:
+            raise ValueError(
+                "Property maps on quantified relationships are not supported by Oracle SQL/PGQ."
+            )
+        sql_name = self._declare_variable(name, "edge", edge_pattern.label)
+        body = sql_name
+        if edge_pattern.label and self._should_emit_label(edge_pattern.label, self.edge_label_map):
+            label_sql = self._label_expression(edge_pattern.label, self.edge_label_map, "EDGE")
+            body = f"{body} IS {label_sql}"
+        condition = self._property_maps_to_where(sql_name, edge_pattern.property_maps)
+        if condition:
+            self._pattern_where_expressions.append(condition)
+        # Any direction other than "left"/"right" is rendered as undirected.
+        arrows = {"left": ("<-", "-"), "right": ("-", "->")}
+        opening, closing = arrows.get(edge_pattern.direction, ("-", "-"))
+        rendered = f"{opening}[{body}]{closing}"
+        if quantified:
+            rendered += self._hop_quantifier(edge_pattern.hop_range)
+        return rendered
+
+    def _should_emit_label(self, source_label: str, label_map: Dict[str, List[str]]) -> bool:
+        """Decide whether a label expression should be emitted for ``source_label``.
+
+        An empty label is never emitted. Outside strict validation every label
+        is emitted. In strict mode the label is emitted only when at least one
+        ``|``-separated part resolves to something other than its own cleaned name.
+        """
+        if not source_label:
+            return False
+        if not self.strict_property_validation:
+            return True
+        for part in source_label.split("|"):
+            part = part.strip()
+            if not part:
+                continue
+            if self._label_map_targets(part, label_map) != [OracleNameSanitizer.clean(part)]:
+                return True
+        return False
+
+    def _property_maps_to_where(
+        self, variable: str, property_maps: Iterable[Tuple[str, str]]
+    ) -> str:
+        """Turn inline ``{property: value}`` maps into an ``AND``-joined condition.
+
+        Values in operator form (see ``_property_map_comparison``) use their own
+        operator; every other value is compared with ``=`` after literal coercion.
+        Returns an empty string when there are no property maps.
+        """
+        conditions = []
+        for raw_name, raw_value in property_maps:
+            name = self._canonical_property_name(variable, raw_name)
+            column = f"{variable}.{OracleNameSanitizer.quote(name, fallback='PROP')}"
+            comparison = self._property_map_comparison(variable, name, raw_value)
+            if comparison:
+                operator, operand = comparison
+            else:
+                operator = "="
+                operand = self._coerce_literal_for_property(variable, name, raw_value)
+            conditions.append(f"{column} {operator} {operand}")
+        return " AND ".join(conditions)
+
+    def _property_map_comparison(
+        self,
+        variable: str,
+        property_name: str,
+        value: str,
+    ) -> Tuple[str, str] | None:
+        """Parse an operator-style property-map value such as ``{gte: 5}``.
+
+        Args:
+            variable: Variable the property belongs to (used for literal coercion).
+            property_name: Property being compared.
+            value: Raw property-map value.
+
+        Returns:
+            ``(sql_operator, coerced_value)`` when ``value`` has the form
+            ``{lt|lte|gt|gte|eq|neq: <literal>}`` (case-insensitive), else ``None``.
+        """
+        # The named groups must be spelled out: the code below reads
+        # match.group("operator") and match.group("value").
+        match = re.fullmatch(
+            r"\{\s*(?P<operator>lt|lte|gt|gte|eq|neq)\s*:\s*(?P<value>.+?)\s*\}",
+            str(value or "").strip(),
+            flags=re.IGNORECASE,
+        )
+        if not match:
+            return None
+        operator = {
+            "lt": "<",
+            "lte": "<=",
+            "gt": ">",
+            "gte": ">=",
+            "eq": "=",
+            "neq": "<>",
+        }[match.group("operator").lower()]
+        comparison_value = self._coerce_literal_for_property(
+            variable,
+            property_name,
+            match.group("value"),
+        )
+        return operator, comparison_value
+
+    def _coerce_literal_for_property(
+        self,
+        variable: str,
+        property_name: str,
+        value: str,
+    ) -> str:
+        """Adapt a literal so it compares cleanly against the property's declared type.
+
+        Only string-typed properties are adjusted: boolean words and plain
+        numbers are quoted, and ``DATE '...'`` literals are reduced to their
+        quoted text. Everything else is returned as translated SQL.
+        """
+        column_type = self._property_type(variable, property_name)
+        is_string = self._is_string_type(column_type)
+        text = str(value).strip()
+        if is_string and text.lower() in {"true", "false"}:
+            return f"'{text.lower()}'"
+        translated = self._translate_sql_expression(text)
+        if not is_string:
+            return translated
+        if re.fullmatch(r"-?\d+(?:\.\d+)?", translated):
+            return f"'{translated}'"
+        date_literal = re.fullmatch(r"DATE\s+('[^']*')", translated, flags=re.IGNORECASE)
+        return date_literal.group(1) if date_literal else translated
+
+    def _property_type(self, variable: str, property_name: str) -> str:
+        """Return the property's type, following a registered redirect if there is one."""
+        owner = self._property_redirect_target(variable, property_name) or variable
+        return self._property_type_on_variable(owner, property_name)
+
+    def _property_type_on_variable(self, variable: str, property_name: str) -> str:
+        """Look up a property's declared type on the variable's possible graph labels.
+
+        For each label the exact name is tried first, then the cleaned and
+        snake_case spellings, and finally a case-insensitive match against any
+        of the three. Returns ``""`` when nothing matches.
+        """
+        source = self._source_variable(variable)
+        for label in self._possible_graph_labels(source):
+            properties = self.property_type_map.get(label, {})
+            if property_name in properties:
+                return properties[property_name]
+            spellings = (
+                OracleNameSanitizer.clean(property_name),
+                self._camel_to_snake(property_name),
+            )
+            for spelling in spellings:
+                if spelling in properties:
+                    return properties[spelling]
+            lowered = {property_name.lower(), *(spelling.lower() for spelling in spellings)}
+            for candidate, property_type in properties.items():
+                if candidate.lower() in lowered:
+                    return property_type
+        return ""
+
+    def _property_redirect_target(self, variable: str, property_name: str) -> str:
+        """Return the variable a property is redirected to, or ``""`` when there is none.
+
+        The lookup tries the lower-cased property name first and then the
+        lower-cased cleaned name.
+        """
+        source = self._source_variable(variable)
+        cleaned = OracleNameSanitizer.clean(property_name)
+        for spelling in (property_name, cleaned):
+            target = self._var_property_redirects.get((source, spelling.lower()))
+            if target:
+                return target
+        return ""
+
+    def _property_exists_for_variable_kind(self, variable: str, property_name: str) -> bool:
+        """Check whether any label of the variable's kind defines the property.
+
+        Returns ``True`` without checking when the variable's kind is unknown or
+        no property types are configured. Names are compared exactly, then in
+        cleaned and snake_case form, then case-insensitively.
+        """
+        kind = self._var_kinds.get(self._source_variable(variable))
+        if not kind or not self.property_type_map:
+            return True
+        spellings = (
+            property_name,
+            OracleNameSanitizer.clean(property_name),
+            self._camel_to_snake(property_name),
+        )
+        lowered = {spelling.lower() for spelling in spellings}
+        for label in self._property_labels_for_kind(kind):
+            properties = self.property_type_map.get(label, {})
+            if any(spelling in properties for spelling in spellings):
+                return True
+            if any(candidate.lower() in lowered for candidate in properties):
+                return True
+        return False
+
+    def _property_labels_for_kind(self, kind: str) -> set[str]:
+        """Return the labels with known property types for ``"edge"`` or node variables.
+
+        With a label map for the kind, its keys and targets are intersected with
+        the typed labels. Without one, edges get the typed edge-map targets and
+        nodes get every typed label that is not an edge-map target.
+        """
+        typed_labels = set(self.property_type_map)
+        label_map = self.edge_label_map if kind == "edge" else self.node_label_map
+        if label_map:
+            mapped = set(label_map)
+            for targets in label_map.values():
+                mapped |= set(targets)
+            return mapped & typed_labels
+        edge_targets = {
+            label for targets in self.edge_label_map.values() for label in targets
+        }
+        if kind == "edge":
+            return edge_targets & typed_labels
+        return typed_labels - edge_targets
+
+ def _canonical_property_name(self, variable: str, property_name: str) -> str:
+ """Resolve a property name to the spelling declared in ``property_type_map``.
+
+ Resolution order: pseudo-properties, then a registered redirect to another
+ variable, then (per possible graph label) the exact, cleaned, snake_case and
+ case-insensitive spellings. When nothing matches, the name is returned
+ unchanged unless strict validation rejects it.
+
+ Raises:
+ ValueError: Under ``strict_property_validation``, when a pseudo-property
+ cannot be resolved, the property is not defined for the variable, or
+ the variable's label has no mapped graph label.
+ """
+ resolved = self._resolve_pseudo_property(variable, property_name)
+ if resolved:
+ return resolved
+ # "id"/"identity" that is neither a pseudo-property nor a typed property.
+ if (
+ self.strict_property_validation
+ and self.property_type_map
+ and property_name.lower() in {"identity", "id"}
+ and not self._property_type(variable, property_name)
+ ):
+ raise ValueError(f'Cannot resolve pseudo-property "{property_name}" for "{variable}".')
+ # From here on the lookup runs against the redirect target, if any.
+ redirect_variable = self._property_redirect_target(variable, property_name)
+ if redirect_variable:
+ variable = redirect_variable
+ if not self.property_type_map:
+ return property_name
+ possible_labels = self._possible_graph_labels(variable)
+ for label in possible_labels:
+ properties = self.property_type_map.get(label, {})
+ if property_name in properties:
+ return property_name
+ clean_property = OracleNameSanitizer.clean(property_name)
+ if clean_property in properties:
+ return clean_property
+ snake_property = self._camel_to_snake(property_name)
+ if snake_property in properties:
+ return snake_property
+ for candidate in properties:
+ if candidate.lower() in {
+ property_name.lower(),
+ clean_property.lower(),
+ snake_property.lower(),
+ }:
+ return candidate
+ # Strict mode: labels are known but none of them defines the property.
+ if self.strict_property_validation and self._possible_graph_labels(variable):
+ raise ValueError(
+ f'Property "{property_name}" is not defined for variable "{variable}".'
+ )
+ # Strict mode: the variable has a label, yet it produced no graph labels.
+ if self.strict_property_validation and self._var_labels.get(
+ self._source_variable(variable)
+ ):
+ raise ValueError(
+ f'Label for variable "{variable}" is not mapped to an Oracle graph label.'
+ )
+ # Strict mode, unlabeled variable: id-like names are let through; anything
+ # else must exist on some label of the variable's kind.
+ if (
+ self.strict_property_validation
+ and self._var_kinds.get(self._source_variable(variable))
+ and not possible_labels
+ and not property_name.lower().endswith("_id")
+ and property_name.lower() not in {"id", "identity"}
+ and not self._property_exists_for_variable_kind(variable, property_name)
+ ):
+ raise ValueError(
+ f'Property "{property_name}" is not defined for variable "{variable}".'
+ )
+ return property_name
+
+    def _camel_to_snake(self, value: str) -> str:
+        """Convert a camelCase identifier to snake_case.
+
+        NOTE(review): the original pattern was lost when markup-like text was
+        stripped from this patch. This is the conventional "underscore before
+        every non-leading capital" form - confirm against the upstream file.
+        """
+        return re.sub(r"(?<!^)(?=[A-Z])", "_", value).lower()
+
+    def _resolve_pseudo_property(self, variable: str, property_name: str) -> str:
+        """Map an id-like pseudo-property to the variable's primary-key property.
+
+        Pseudo names are ``id``, ``identity`` and, when the variable has a
+        label, ``<label>_id`` / ``<cleaned label>_id`` (case-insensitive).
+        Returns ``""`` when no property types are configured, the name is not a
+        pseudo name, no primary key is known, or the variable really has a
+        property with that name.
+        """
+        if not self.property_type_map:
+            return ""
+        source_variable = self._source_variable(variable)
+        source_label = self._var_labels.get(source_variable, "")
+        clean_label = OracleNameSanitizer.clean(source_label)
+        property_lower = property_name.lower()
+        pseudo_names = {"identity", "id"}
+        if source_label:
+            pseudo_names.add(f"{source_label}_id".lower())
+            pseudo_names.add(f"{clean_label}_id".lower())
+        if property_lower not in pseudo_names:
+            return ""
+        primary_key = self._primary_key_for_variable(variable)
+        if not primary_key:
+            return ""
+        # A real property with this name takes precedence over the pseudo mapping.
+        if self._property_type(variable, property_name):
+            return ""
+        return primary_key
+
+ def _primary_key_for_variable(self, variable: str) -> str:
+ """Return the primary-key property name for a labeled variable, or ``""``.
+
+ Lookup order: the configured primary-key map (exact, then case-insensitive,
+ for the source label and each possible graph label); a property named
+ ``<label>_id``; and finally a label whose properties contain exactly one
+ name ending in ``_id``.
+ """
+ source_variable = self._source_variable(variable)
+ label = self._var_labels.get(source_variable, "")
+ if not label:
+ return ""
+ key_map = (
+ self.edge_primary_key_map
+ if self._var_kinds.get(source_variable) == "edge"
+ else self.node_primary_key_map
+ )
+ # 1) Explicitly configured primary keys.
+ for candidate in [label, *self._possible_graph_labels(source_variable)]:
+ if candidate in key_map:
+ return key_map[candidate]
+ for key_label, primary_key in key_map.items():
+ if key_label.lower() == candidate.lower():
+ return primary_key
+ properties_by_label = [
+ self.property_type_map.get(candidate, {})
+ for candidate in self._possible_graph_labels(source_variable)
+ ]
+ # 2) A property literally named after the label: "<label>_id".
+ for properties in properties_by_label:
+ for property_name in properties:
+ if property_name.lower() in {
+ f"{label}_id".lower(),
+ f"{OracleNameSanitizer.clean(label)}_id".lower(),
+ }:
+ return property_name
+ # 3) The only "*_id" property of a label.
+ for properties in properties_by_label:
+ id_like = [name for name in properties if name.lower().endswith("_id")]
+ if len(id_like) == 1:
+ return id_like[0]
+ return ""
+
+    def _possible_graph_labels(self, variable: str) -> List[str]:
+        """Return the de-duplicated graph labels a variable may carry.
+
+        Each ``|``-separated part of the variable's label is expanded through
+        the node or edge label map. Under strict validation with a non-empty
+        label map, parts without an explicit mapping are skipped. An unlabeled
+        variable yields an empty list.
+        """
+        source = self._source_variable(variable)
+        declared = self._var_labels.get(source, "")
+        if not declared:
+            return []
+        if self._var_kinds.get(source) == "edge":
+            label_map = self.edge_label_map
+        else:
+            label_map = self.node_label_map
+        resolved: Dict[str, None] = {}
+        for part in declared.split("|"):
+            part = part.strip()
+            must_be_explicit = self.strict_property_validation and label_map
+            if must_be_explicit and not self._label_map_has_explicit_target(part, label_map):
+                continue
+            # A dict keeps first-seen order while dropping duplicates.
+            resolved.update(dict.fromkeys(self._label_map_targets(part, label_map)))
+        return list(resolved)
+
+    def _label_map_targets(self, source_label: str, label_map: Dict[str, List[str]]) -> List[str]:
+        """Resolve a source label to its graph labels.
+
+        Keys are tried in this order: the label as given, lower-cased, cleaned,
+        and cleaned lower-cased. The first non-empty mapping wins; otherwise the
+        cleaned label itself is returned as the only target.
+        """
+        cleaned = OracleNameSanitizer.clean(source_label)
+        for key in (source_label, source_label.lower(), cleaned, cleaned.lower()):
+            targets = label_map.get(key)
+            if targets:
+                return targets
+        return [cleaned]
+
+    def _label_map_has_explicit_target(
+        self,
+        source_label: str,
+        label_map: Dict[str, List[str]],
+    ) -> bool:
+        """Tell whether ``source_label`` appears in ``label_map`` as a key or a target.
+
+        Keys are matched as given, lower-cased, cleaned and cleaned lower-cased;
+        targets are matched as given or cleaned.
+        """
+        cleaned = OracleNameSanitizer.clean(source_label)
+        key_spellings = (source_label, source_label.lower(), cleaned, cleaned.lower())
+        if any(spelling in label_map for spelling in key_spellings):
+            return True
+        known_targets = {target for targets in label_map.values() for target in targets}
+        return source_label in known_targets or cleaned in known_targets
+
+    def _is_string_type(self, property_type: str) -> bool:
+        """Return True for type names containing CHAR or CLOB, or equal to STRING."""
+        type_name = str(property_type or "").upper()
+        if type_name == "STRING":
+            return True
+        return any(marker in type_name for marker in ("CHAR", "CLOB"))
+
+    def _is_temporal_type(self, property_type: str) -> bool:
+        """Return True for type names containing DATE or TIMESTAMP (case-insensitive)."""
+        type_name = str(property_type or "").upper()
+        return any(marker in type_name for marker in ("DATE", "TIMESTAMP"))
+
+    def _source_variable(self, variable: str) -> str:
+        """Map a source or SQL variable name back to the source name it was declared with.
+
+        The first declared variable whose source name or SQL name equals
+        ``variable`` wins; unknown names are returned unchanged.
+        """
+        return next(
+            (
+                source
+                for source, sql_name in self._var_sql_names.items()
+                if variable in (source, sql_name)
+            ),
+            variable,
+        )
+
+    def _hop_quantifier(self, hop_range: Tuple[int, int]) -> str:
+        """Render a hop range as a ``{n}`` or ``{m,n}`` quantifier.
+
+        Equal bounds give ``{n}``; a lower bound of ``-1`` is rendered as 1.
+
+        Raises:
+            ValueError: If only the upper bound is ``-1`` (open-ended path).
+        """
+        lower, upper = hop_range
+        if lower == upper:
+            bounds = f"{lower}"
+        elif lower == -1:
+            bounds = f"1,{upper}"
+        elif upper == -1:
+            raise ValueError(
+                "Open-ended variable-length paths are not supported by Oracle SQL/PGQ."
+            )
+        else:
+            bounds = f"{lower},{upper}"
+        return "{" + bounds + "}"
+
+    def _label_expression(
+        self,
+        source_label: str,
+        label_map: Dict[str, List[str]],
+        fallback: str,
+    ) -> str:
+        """Render a label expression: quoted graph labels joined with ``" | "``.
+
+        Each non-empty ``|``-separated part of ``source_label`` is expanded via
+        ``_label_map_targets``. If that yields nothing, ``source_label`` itself
+        is quoted.
+        """
+        stripped_parts = (piece.strip() for piece in source_label.split("|"))
+        targets: List[str] = [
+            target
+            for part in stripped_parts
+            if part
+            for target in self._label_map_targets(part, label_map)
+        ]
+        if not targets:
+            targets = [source_label]
+        quoted = (OracleNameSanitizer.quote(target, fallback=fallback) for target in targets)
+        return " | ".join(quoted)
+
+    def _declare_variable(self, variable: str, kind: str, label: str = "") -> str:
+        """Register a variable's kind (and label, if given) and return its SQL name.
+
+        The SQL name is fixed on first declaration; later declarations of the
+        same variable reuse it.
+        """
+        self._var_kinds[variable] = kind
+        if label:
+            self._var_labels[variable] = label
+        sql_name = self._sql_variable_name(variable)
+        return self._var_sql_names.setdefault(variable, sql_name)
+
+    def _sql_variable_name(self, variable: str) -> str:
+        """Return the SQL identifier used for a source variable.
+
+        ``USER`` (any case) is renamed to ``user_var``; presumably because USER
+        is a reserved name in Oracle - confirm. Everything else goes through
+        ``OracleNameSanitizer.alias``.
+        """
+        if str(variable or "").upper() != "USER":
+            return OracleNameSanitizer.alias(variable)
+        return "user_var"
+
+    def _where_parts(self, where_expressions: List[CompareExpression]) -> List[str]:
+        """Return all WHERE fragments: inline pattern filters first, then translated ones.
+
+        Raises:
+            ValueError: If any expression's raw text contains an aggregate function.
+        """
+        for candidate in where_expressions:
+            raw_text = getattr(candidate, "raw_expression", "")
+            if self._contains_aggregate_function(raw_text):
+                raise ValueError("Aggregate functions in MATCH WHERE are not supported.")
+        translated = [
+            self._translate_compare_expression(candidate) for candidate in where_expressions
+        ]
+        return [*self._pattern_where_expressions, *translated]
+
+ def _translate_columns(
+ self,
+ return_body: ReturnBody | None,
+ aggregate_query: bool = False,
+ ) -> str:
+ """Build the comma-separated column list projected for a RETURN body.
+
+ Without return items, an element-id column (aliased ``<var>_ID``) is emitted
+ for every node variable, or for every declared variable when there are no
+ nodes. Aggregate queries are delegated to ``_translate_aggregate_columns``.
+ Otherwise each return item is projected (path variables expand to several
+ columns) and non-aggregate sort expressions that are not already projected
+ are appended as extra columns.
+ """
+ if return_body is None or not return_body.return_item_list:
+ variables = [var for var, kind in self._var_kinds.items() if kind == "node"]
+ if not variables:
+ variables = list(self._var_kinds.keys())
+ return ", ".join(
+ f"{self._element_id_expression(var)} AS {OracleNameSanitizer.alias(var + '_ID')}"
+ for var in variables
+ )
+
+ if aggregate_query:
+ return self._translate_aggregate_columns(return_body)
+
+ return_aliases = self._resolved_return_aliases(return_body)
+ projections = []
+ for item, alias in zip(return_body.return_item_list, return_aliases, strict=True):
+ # A bare path variable (no property access) expands to its elements.
+ if item.symbolic_name in self._path_variables and not item.property:
+ projections.extend(self._translate_path_return_item(item, alias))
+ else:
+ projections.append(self._translate_return_item(item, alias))
+ projected_aliases = {OracleNameSanitizer.alias(alias) for alias in return_aliases}
+ # Sort keys that are not part of the RETURN list still need a column to sort on.
+ for sort_item in return_body.sort_item_list:
+ if self._is_aggregate_sort(sort_item):
+ continue
+ alias = OracleNameSanitizer.alias(self._sort_alias(sort_item, return_body))
+ if alias in projected_aliases:
+ continue
+ if sort_item.expression:
+ projections.append(
+ f"{self._translate_sql_expression(sort_item.expression)} AS {alias}"
+ )
+ projected_aliases.add(alias)
+ return ", ".join(projections)
+
+    def _translate_path_return_item(self, return_item: ReturnItem, alias: str) -> List[str]:
+        """Expand a returned path variable into one id column per registered element.
+
+        Quantified edges are projected as ``JSON_ARRAYAGG(EDGE_ID(...))`` with an
+        ``_IDS`` alias suffix; every other element uses ``VERTEX_ID``/``EDGE_ID``
+        with an ``_ID`` suffix.
+
+        Raises:
+            ValueError: If the path variable has no registered elements.
+        """
+        path_name = return_item.symbolic_name
+        prefix = OracleNameSanitizer.alias(alias or path_name)
+        quantified = self._path_variable_quantified_edges.get(path_name, set())
+        columns = []
+        for kind, variable in self._path_variables.get(path_name, []):
+            sql_name = self._var_sql_names.get(variable, variable)
+            if kind == "edge" and variable in quantified:
+                column_alias = OracleNameSanitizer.alias(f"{prefix}_{variable}_IDS")
+                columns.append(f"JSON_ARRAYAGG(EDGE_ID({sql_name})) AS {column_alias}")
+                continue
+            id_function = "EDGE_ID" if kind == "edge" else "VERTEX_ID"
+            column_alias = OracleNameSanitizer.alias(f"{prefix}_{variable}_ID")
+            columns.append(f"{id_function}({sql_name}) AS {column_alias}")
+        if not columns:
+            raise ValueError(f"Path variable {path_name} is not declared.")
+        return columns
+
+    def _translate_return_item(
+        self,
+        return_item: ReturnItem,
+        alias: str | None = None,
+    ) -> str:
+        """Render one return item as ``<expression> AS <alias>``.
+
+        Items whose function is listed in ``AGGREGATE_FUNCTIONS`` are wrapped in
+        the aggregate call. When no alias is supplied it is derived with
+        ``_return_alias``.
+        """
+        expression = self._return_expression(return_item)
+        function_name = return_item.function_name.upper()
+        if function_name in self.AGGREGATE_FUNCTIONS:
+            expression = self._aggregate_sql_call(function_name, "", expression)
+        if not alias:
+            alias = self._return_alias(return_item, expression)
+        return f"{expression} AS {OracleNameSanitizer.alias(alias)}"
+
+ def _translate_aggregate_columns(self, return_body: ReturnBody) -> str:
+ """Build the column list projected underneath an aggregating outer SELECT.
+
+ Aggregate items contribute their argument expressions (not the aggregate
+ call itself), plain items contribute their own expression, and sort items
+ that are not already returned add whatever columns they need. Columns are
+ de-duplicated by sanitized alias, keeping the first expression seen.
+ """
+ projections: List[Tuple[str, str]] = []
+ return_aliases = self._resolved_return_aliases(return_body)
+ for item, alias in zip(return_body.return_item_list, return_aliases, strict=True):
+ if self._is_complex_aggregate_item(item):
+ for expression, expression_alias in self._aggregate_expression_projections(
+ item.expression
+ ):
+ projections.append((expression, expression_alias))
+ elif self._is_aggregate_item(item):
+ argument_expression, argument_alias = self._aggregate_argument_projection(item)
+ # An empty argument expression means there is nothing to project.
+ if argument_expression:
+ projections.append((argument_expression, argument_alias))
+ else:
+ expression = self._return_expression(item)
+ projections.append((expression, alias))
+
+ outer_aliases = {OracleNameSanitizer.alias(alias) for alias in return_aliases}
+ for sort_item in return_body.sort_item_list:
+ sort_alias = OracleNameSanitizer.alias(self._sort_alias(sort_item, return_body))
+ # Sorting on a returned column needs no extra projection.
+ if sort_alias in outer_aliases:
+ continue
+ if self._is_complex_aggregate_item(sort_item):
+ for expression, expression_alias in self._aggregate_expression_projections(
+ sort_item.expression
+ ):
+ projections.append((expression, expression_alias))
+ elif self._is_aggregate_sort(sort_item):
+ argument_expression, argument_alias = self._aggregate_argument_projection(sort_item)
+ if argument_expression:
+ projections.append((argument_expression, argument_alias))
+ elif sort_item.expression:
+ if self._sort_expression_uses_return_aliases(sort_item, return_body):
+ continue
+ projections.append(
+ (
+ self._translate_sql_expression(sort_item.expression),
+ self._sort_alias(sort_item, return_body),
+ )
+ )
+
+ unique: Dict[str, str] = {}
+ for expression, alias in projections:
+ unique.setdefault(OracleNameSanitizer.alias(alias), expression)
+ # Fall back to a constant column so the column list is never empty.
+ if not unique:
+ unique["dummy_value"] = "1"
+ return ", ".join(f"{expression} AS {alias}" for alias, expression in unique.items())
+
+    def _outer_select(
+        self,
+        return_body: ReturnBody | None,
+        distinct: bool,
+        aggregate_query: bool,
+        hidden_sort_aliases: List[str] | None = None,
+    ) -> str:
+        """Build the outer ``SELECT`` clause.
+
+        Non-aggregate queries select ``*`` unless hidden sort aliases exist, in
+        which case only the resolved return aliases are listed. Aggregate
+        queries re-apply each aggregate over the projected columns and pass the
+        remaining columns through by alias.
+        """
+        keyword = "SELECT DISTINCT" if distinct else "SELECT"
+        if return_body is None or not aggregate_query:
+            if hidden_sort_aliases and return_body is not None:
+                names = [
+                    OracleNameSanitizer.alias(alias)
+                    for alias in self._resolved_return_aliases(return_body)
+                ]
+                return f"{keyword} " + ", ".join(names)
+            return f"{keyword} *"
+
+        columns = []
+        resolved_aliases = self._resolved_return_aliases(return_body)
+        for item, resolved_alias in zip(
+            return_body.return_item_list,
+            resolved_aliases,
+            strict=True,
+        ):
+            if self._is_complex_aggregate_item(item):
+                expression = self._outer_aggregate_sql_expression(item.expression)
+            elif self._is_aggregate_item(item):
+                expression = self._outer_aggregate_expression(item)
+            else:
+                columns.append(OracleNameSanitizer.alias(resolved_alias))
+                continue
+            alias = self._return_alias(item, expression)
+            columns.append(f"{expression} AS {OracleNameSanitizer.alias(alias)}")
+        return f"{keyword} " + ", ".join(columns)
+
+    def _outer_aggregate_expression(self, item: ReturnItem | SortItem) -> str:
+        """Render the aggregate call for ``item`` as evaluated in the outer SELECT.
+
+        ``COUNT(*)`` is returned verbatim. Any other aggregate is applied to the
+        sanitized alias under which its argument is projected, with a
+        ``DISTINCT`` keyword when the symbolic name carries a distinct prefix.
+        """
+        aggregate_name = item.function_name.upper()
+        if aggregate_name == "COUNT" and item.symbolic_name == "*":
+            return "COUNT(*)"
+
+        self._reject_temporal_numeric_aggregate(item)
+        _projected_expression, projected_alias = self._aggregate_argument_projection(item)
+        distinct_keyword = "DISTINCT " if self._has_distinct_prefix(item.symbolic_name) else ""
+        return self._aggregate_sql_call(
+            aggregate_name,
+            distinct_keyword,
+            OracleNameSanitizer.alias(projected_alias),
+        )
+
+    def _aggregate_sql_call(self, function_name: str, distinct: str, expression: str) -> str:
+        """Format ``function_name(distinct + expression)``.
+
+        SUM calls (matched case-insensitively) are wrapped in
+        ``COALESCE(..., 0)``.
+        """
+        call = f"{function_name}({distinct}{expression})"
+        return f"COALESCE({call}, 0)" if function_name.upper() == "SUM" else call
+
+    def _coalesce_sum_calls(self, expression: str) -> str:
+        """Wrap every ``SUM(...)`` call in ``COALESCE(..., 0)``.
+
+        Calls that already sit directly after ``COALESCE(`` are left alone.
+        When no closing parenthesis is found for a ``SUM(``, the remainder of
+        the expression is copied through unchanged.
+        """
+        pieces: List[str] = []
+        cursor = 0
+        while cursor < len(expression):
+            remainder = expression[cursor:]
+            found = re.search(r"\bSUM\s*\(", remainder, flags=re.IGNORECASE)
+            if found is None:
+                pieces.append(remainder)
+                break
+            sum_start = cursor + found.start()
+            paren_at = cursor + found.end() - 1
+            if self._inside_coalesce_call(expression, sum_start):
+                # Already guarded: copy up to and including "SUM(" and move on.
+                pieces.append(expression[cursor : paren_at + 1])
+                cursor = paren_at + 1
+                continue
+            closing = self._matching_paren_index(expression, paren_at)
+            if closing == -1:
+                pieces.append(remainder)
+                break
+            pieces.append(expression[cursor:sum_start])
+            pieces.append(f"COALESCE({expression[sum_start : closing + 1]}, 0)")
+            cursor = closing + 1
+        return "".join(pieces)
+
+    def _inside_coalesce_call(self, expression: str, start: int) -> bool:
+        """Return True when the text just before ``start`` ends with ``COALESCE(``.
+
+        Only the 20 characters preceding ``start`` are inspected.
+        """
+        window_start = max(0, start - 20)
+        preceding = expression[window_start:start]
+        return re.search(r"COALESCE\s*\(\s*$", preceding, flags=re.IGNORECASE) is not None
+
+    def _matching_paren_index(self, expression: str, open_paren: int) -> int:
+        """Return the index of the ``)`` that closes the ``(`` at ``open_paren``.
+
+        Parentheses inside single-quoted text are ignored. Returns -1 when the
+        parenthesis is never closed.
+
+        NOTE(review): a second ``_matching_paren_index`` is defined further
+        down in this chunk without the quote handling. If both live in the
+        same class, the later definition replaces this one at class-creation
+        time -- confirm which behavior is intended.
+        """
+        nesting = 0
+        quoted = False
+        for position in range(open_paren, len(expression)):
+            symbol = expression[position]
+            if symbol == "'":
+                quoted = not quoted
+            elif quoted:
+                pass
+            elif symbol == "(":
+                nesting += 1
+            elif symbol == ")":
+                nesting -= 1
+                if nesting == 0:
+                    return position
+        return -1
+
+    def _aggregate_argument_projection(self, item: ReturnItem | SortItem) -> Tuple[str, str]:
+        """Return ``(expression, alias)`` for the argument of an aggregate item.
+
+        ``COUNT(*)`` has no argument and yields ``("", "dummy_value")``.
+        Otherwise the argument is rendered as if it were a plain return item,
+        and the alias is the item's property or a default derived from the
+        rendered expression.
+
+        Raises:
+            ValueError: Propagated from the temporal/string aggregate checks.
+        """
+        if item.function_name.upper() == "COUNT" and item.symbolic_name == "*":
+            return "", "dummy_value"
+
+        self._reject_temporal_numeric_aggregate(item)
+        self._reject_temporal_numeric_aggregate_expression(item.expression)
+
+        bare_name = self._strip_distinct_prefix(item.symbolic_name)
+        argument_text = self._aggregate_argument_expression_text(item, bare_name)
+        # Re-use the plain-item renderer by stripping the aggregate wrapper.
+        plain_item = ReturnItem(
+            symbolic_name=bare_name,
+            property=item.property,
+            alias="",
+            function_name="",
+            expression=argument_text,
+        )
+        rendered = self._return_expression(plain_item)
+        alias = item.property or self._default_expression_alias(bare_name, rendered)
+        return rendered, alias
+
+    def _reject_temporal_numeric_aggregate(self, item: ReturnItem | SortItem) -> None:
+        """Refuse AVG/SUM over a property whose declared type is temporal or string.
+
+        Items that use another aggregate, or that have no property, pass.
+
+        Raises:
+            ValueError: When the aggregated property is temporal or string typed.
+        """
+        aggregate_name = item.function_name.upper()
+        if aggregate_name not in {"AVG", "SUM"} or not item.property:
+            return
+
+        property_type = self._property_type(item.symbolic_name, item.property)
+        if self._is_temporal_type(property_type):
+            kind = "temporal"
+        elif self._is_string_type(property_type):
+            kind = "string"
+        else:
+            return
+        raise ValueError(
+            f"{aggregate_name} over {kind} property "
+            f"{item.symbolic_name}.{item.property} requires explicit numeric conversion."
+        )
+
+    def _aggregate_expression_projections(self, expression: str) -> List[Tuple[str, str]]:
+        """List the ``(expression, alias)`` columns a complex aggregate needs.
+
+        Property references come first, then whole-element references used
+        inside COUNT/MIN/MAX. Each alias is emitted once. When nothing is
+        referenced a constant ``("1", "dummy_value")`` column is returned.
+
+        Raises:
+            ValueError: Propagated from the temporal aggregate check.
+        """
+        self._reject_temporal_numeric_aggregate_expression(expression)
+        rewritten = self._translate_collect_size_aggregate_expression(expression)
+        alias_by_reference = self._aggregate_property_aliases(rewritten)
+
+        columns: List[Tuple[str, str]] = []
+        emitted = set()
+        for variable, raw_property in self._property_references(rewritten):
+            canonical = self._canonical_property_name(variable, raw_property)
+            alias = alias_by_reference[(variable, canonical)]
+            if alias not in emitted:
+                emitted.add(alias)
+                quoted = OracleNameSanitizer.quote(canonical, fallback="PROP")
+                owner = self._var_sql_names.get(variable, variable)
+                columns.append((f"{owner}.{quoted}", alias))
+        for variable in self._aggregate_element_references(rewritten):
+            alias = self._element_projection_alias(variable)
+            if alias not in emitted:
+                emitted.add(alias)
+                columns.append((self._element_id_expression(variable), alias))
+
+        return columns if columns else [("1", "dummy_value")]
+
+    def _outer_aggregate_sql_expression(self, expression: str) -> str:
+        """Translate a complex aggregate expression for the outer SELECT.
+
+        After translation to SQL, property references and element references
+        inside COUNT/MIN/MAX are replaced by their sanitized projection
+        aliases, and SUM calls are COALESCE-guarded.
+
+        Raises:
+            ValueError: Propagated from the temporal aggregate check.
+        """
+        self._reject_temporal_numeric_aggregate_expression(expression)
+        rewritten = self._translate_collect_size_aggregate_expression(expression)
+        sql = self._translate_sql_expression(rewritten)
+
+        alias_by_reference = self._aggregate_property_aliases(rewritten)
+        for variable, raw_property in self._property_references(rewritten):
+            canonical = self._canonical_property_name(variable, raw_property)
+            owner = self._var_sql_names.get(variable, variable)
+            quoted = OracleNameSanitizer.quote(canonical, fallback="PROP")
+            column = OracleNameSanitizer.alias(alias_by_reference[(variable, canonical)])
+            sql = sql.replace(f"{owner}.{quoted}", column)
+
+        for variable in self._aggregate_element_references(rewritten):
+            column = OracleNameSanitizer.alias(self._element_projection_alias(variable))
+            # The DISTINCT form is rewritten first, then the plain form.
+            substitutions = (
+                (
+                    rf"\b(COUNT)\s*\(\s*DISTINCT\s+{re.escape(variable)}\s*\)",
+                    rf"\1(DISTINCT {column})",
+                ),
+                (
+                    rf"\b(COUNT|MIN|MAX)\s*\(\s*{re.escape(variable)}\s*\)",
+                    rf"\1({column})",
+                ),
+            )
+            for pattern, replacement in substitutions:
+                sql = re.sub(pattern, replacement, sql, flags=re.IGNORECASE)
+        return self._coalesce_sum_calls(sql)
+
+    def _reject_temporal_numeric_aggregate_expression(self, expression: str) -> None:
+        """Refuse aggregate expressions that do arithmetic on temporal values.
+
+        Expressions without an aggregate function pass. ``duration.between``
+        inside an aggregate is always rejected. Otherwise, an expression that
+        contains ``-``, no ``date(`` / ``tofloat(`` / ``tointeger(`` call, and
+        a reference to a temporal-typed property is rejected.
+
+        Raises:
+            ValueError: When the expression is rejected.
+        """
+        text = expression or ""
+        if not self._contains_aggregate_function(text):
+            return
+        if re.search(r"\bduration\s*\.\s*between\s*\(", text, flags=re.IGNORECASE):
+            raise ValueError("duration.between aggregate requires explicit numeric conversion.")
+
+        lowered = text.lower()
+        has_conversion = any(call in lowered for call in ("date(", "tofloat(", "tointeger("))
+        if has_conversion or "-" not in text:
+            return
+
+        for variable, property_name in self._property_references(text):
+            if self._is_temporal_type(self._property_type(variable, property_name)):
+                raise ValueError(
+                    "Temporal arithmetic aggregate requires explicit numeric conversion."
+                )
+
+    def _translate_collect_size_aggregate_expression(self, expression: str) -> str:
+        """Rewrite ``size(collect(...))`` forms into ``COUNT(...)``.
+
+        * ``size(apoc.coll.toSet(collect(x)))`` becomes ``COUNT(DISTINCT x)``.
+        * ``size(collect(x))`` becomes ``COUNT(x)``.
+        * ``size(collect(distinct x))`` becomes ``COUNT(DISTINCT x)``.
+
+        Only arguments without nested parentheses are matched. A falsy
+        ``expression`` is treated as the empty string.
+        """
+        text = expression or ""
+        # The named groups must be spelled out: the replacement callbacks read
+        # them via match.group('body') / match.group('distinct').
+        text = re.sub(
+            r"\bsize\s*\(\s*apoc\.coll\.toSet\s*\(\s*collect\s*\(\s*"
+            r"(?P<body>[^()]+?)\s*\)\s*\)\s*\)",
+            lambda match: f"COUNT(DISTINCT {match.group('body').strip()})",
+            text,
+            flags=re.IGNORECASE,
+        )
+        return re.sub(
+            r"\bsize\s*\(\s*collect\s*\(\s*(?P<distinct>distinct\s+)?"
+            r"(?P<body>[^()]+?)\s*\)\s*\)",
+            lambda match: (
+                "COUNT("
+                + ("DISTINCT " if match.group("distinct") else "")
+                + match.group("body").strip()
+                + ")"
+            ),
+            text,
+            flags=re.IGNORECASE,
+        )
+
+    def _aggregate_property_aliases(self, expression: str) -> Dict[Tuple[str, str], str]:
+        """Map each ``(variable, canonical property)`` in ``expression`` to an alias.
+
+        The alias is ``"<variable>_<canonical property>"``.
+        """
+        canonical_references = (
+            (variable, self._canonical_property_name(variable, raw_property))
+            for variable, raw_property in self._property_references(expression)
+        )
+        return {
+            (variable, canonical): f"{variable}_{canonical}"
+            for variable, canonical in canonical_references
+        }
+
+    def _aggregate_element_references(self, expression: str) -> List[str]:
+        """Return names aggregated as whole elements that are known variables.
+
+        Candidates come from ``_aggregate_element_reference_names``; only those
+        present in ``self._var_kinds`` are kept, in their original order.
+        """
+        known: List[str] = []
+        for candidate in self._aggregate_element_reference_names(expression):
+            if candidate in self._var_kinds:
+                known.append(candidate)
+        return known
+
+    def _aggregate_element_reference_names(self, expression: str) -> List[str]:
+        """Collect bare identifiers used as the sole argument of COUNT/MIN/MAX.
+
+        String literals are masked first so their content is not scanned.
+        An optional ``DISTINCT`` before the identifier is accepted. Names are
+        returned once each, in order of first appearance.
+        """
+        protected, _ = self._protect_string_literals(expression or "")
+        # The group must be named "variable": it is read back by that name.
+        pattern = (
+            r"\b(?:COUNT|MIN|MAX)\s*\(\s*(?:DISTINCT\s+)?"
+            r"(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\s*\)"
+        )
+        names = [
+            match.group("variable")
+            for match in re.finditer(pattern, protected, flags=re.IGNORECASE)
+        ]
+        # dict.fromkeys de-duplicates while keeping first-seen order.
+        return list(dict.fromkeys(names))
+
+    def _element_projection_alias(self, variable: str) -> str:
+        """Return the alias under which a whole element is projected: ``<variable>_VALUE``."""
+        return "{}_VALUE".format(variable)
+
+    def _property_references(self, expression: str) -> List[Tuple[str, str]]:
+        """Find ``variable.property`` references to known variables.
+
+        String literals are masked first. The property may be a bare
+        identifier or double-quoted. References whose variable is not in
+        ``self._var_kinds`` are dropped. Duplicates are kept.
+
+        Returns:
+            ``(variable, property_name)`` tuples in order of appearance.
+        """
+        protected, _ = self._protect_string_literals(expression or "")
+        # Group names are required: they are read back by name below.
+        pattern = (
+            r"\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+            r"(?:\"(?P<quoted_property>[^\"]+)\"|"
+            r"(?P<property>[A-Za-z_][A-Za-z0-9_$#-]*)\b)"
+        )
+        references: List[Tuple[str, str]] = []
+        for match in re.finditer(pattern, protected):
+            variable = match.group("variable")
+            property_name = match.group("quoted_property") or match.group("property")
+            if variable in self._var_kinds:
+                references.append((variable, property_name))
+        return references
+
+    def _property_projection_alias(self, variable: str, property_name: str) -> str:
+        """Return the projection alias for a property; ``variable`` is not used."""
+        alias = property_name
+        return alias
+
+    def _aggregate_argument_expression_text(
+        self,
+        item: ReturnItem | SortItem,
+        symbolic_name: str,
+    ) -> str:
+        """Return the text of an aggregate item's argument.
+
+        Without an expression the given ``symbolic_name`` is returned. When the
+        stripped expression starts with ``<FUNCTION>(`` (case-insensitive) and
+        ends with ``)``, that wrapper is removed. A distinct prefix is stripped
+        from the result.
+        """
+        if not item.expression:
+            return symbolic_name
+
+        text = item.expression.strip()
+        opener = f"{item.function_name.upper()}("
+        if text.upper().startswith(opener) and text.endswith(")"):
+            # Drop "<FUNCTION>(" from the front and ")" from the back.
+            text = text[len(item.function_name) + 1 : -1].strip()
+        return self._strip_distinct_prefix(text)
+
+    def _translate_sort_item(self, sort_item: SortItem, return_body: ReturnBody) -> str:
+        """Render one ORDER BY term.
+
+        A sort expression written in terms of return aliases is translated as
+        SQL. Any other item is ordered by its sanitized alias. The sort order
+        suffix is appended in both cases.
+        """
+        written_over_aliases = sort_item.expression and self._sort_expression_uses_return_aliases(
+            sort_item,
+            return_body,
+        )
+        if written_over_aliases:
+            translated = self._translate_sql_expression(sort_item.expression)
+            return f"{translated}{self._sql_sort_order(sort_item.order)}"
+
+        sort_alias = self._sort_alias(sort_item, return_body)
+        direction = self._sql_sort_order(sort_item.order)
+        return f"{OracleNameSanitizer.alias(sort_alias)}{direction}"
+
+    def _translate_sort_item_for_with_stage(
+        self,
+        sort_item: SortItem,
+        return_body: ReturnBody,
+        carried_variables: set[str],
+    ) -> str:
+        """Render one ORDER BY term for a WITH stage.
+
+        Expression items are translated and passed through
+        ``_with_stage_expression``. A property of a carried variable uses its
+        WITH-stage property alias. Anything else uses the regular sort alias.
+        """
+        direction_of = self._sql_sort_order
+
+        if sort_item.expression:
+            staged = self._with_stage_expression(
+                self._translate_sql_expression(sort_item.expression),
+                carried_variables,
+            )
+            return f"{staged}{direction_of(sort_item.order)}"
+
+        carried_property = sort_item.property and sort_item.symbolic_name in carried_variables
+        if carried_property:
+            canonical = self._canonical_property_name(sort_item.symbolic_name, sort_item.property)
+            alias = self._with_property_stage_alias(sort_item.symbolic_name, canonical)
+        else:
+            alias = self._sort_alias(sort_item, return_body)
+        return f"{OracleNameSanitizer.alias(alias)}{direction_of(sort_item.order)}"
+
+    def _sql_sort_order(self, order: str) -> str:
+        """Return the ORDER BY direction suffix with a leading space, or ``""``.
+
+        ``DESCENDING`` / ``ASCENDING`` are shortened to ``DESC`` / ``ASC``.
+        Any other non-empty value is upper-cased and passed through.
+        """
+        keyword = str(order or "").upper()
+        keyword = {"DESCENDING": "DESC", "ASCENDING": "ASC"}.get(keyword, keyword)
+        if not keyword:
+            return ""
+        return f" {keyword}"
+
+    def _sort_expression_uses_return_aliases(
+        self,
+        sort_item: SortItem,
+        return_body: ReturnBody,
+    ) -> bool:
+        """Tell whether a sort expression is written purely over return aliases.
+
+        True only when the expression is non-empty, contains no property
+        reference to a known variable, and mentions at least one sanitized
+        return alias as a whole word outside string literals.
+        """
+        expression = sort_item.expression or ""
+        if not expression or self._property_references(expression):
+            return False
+
+        aliases = {
+            OracleNameSanitizer.alias(name) for name in self._resolved_return_aliases(return_body)
+        }
+        if not aliases:
+            return False
+
+        protected, _ = self._protect_string_literals(expression)
+        for alias in aliases:
+            if re.search(rf"\b{re.escape(alias)}\b", protected):
+                return True
+        return False
+
+    def _sort_alias(self, sort_item: SortItem, return_body: ReturnBody) -> str:
+        """Resolve the alias an ORDER BY item refers to.
+
+        Return items are checked in order. For each one the first rule that
+        applies wins:
+
+        1. Same symbolic name, property and function: that item's return alias.
+        2. Sort expression equal to the item's alias: that alias.
+        3. Sort expression equal to the item's expression: its alias, property
+           or a default expression alias.
+
+        With no match, the alias is the sort item's property or symbolic name,
+        prefixed with ``<function_name>_`` when a function is present.
+        """
+        sort_expression = sort_item.expression
+        for candidate in return_body.return_item_list:
+            same_target = (
+                candidate.symbolic_name == sort_item.symbolic_name
+                and candidate.property == sort_item.property
+                and candidate.function_name == sort_item.function_name
+            )
+            if same_target:
+                return self._return_alias(
+                    candidate,
+                    candidate.expression or candidate.symbolic_name,
+                )
+            if not sort_expression:
+                continue
+            if sort_expression == candidate.alias:
+                return candidate.alias
+            if sort_expression == candidate.expression:
+                return (
+                    candidate.alias
+                    or candidate.property
+                    or self._default_expression_alias(candidate.symbolic_name, sort_expression)
+                )
+
+        fallback = sort_item.property or sort_item.symbolic_name
+        if sort_item.function_name:
+            return f"{sort_item.function_name}_{fallback}"
+        return fallback
+
+    def _return_alias(self, return_item: ReturnItem, expression: str) -> str:
+        """Pick the output alias for a return item.
+
+        Preference order: explicit alias, then property name, then a default
+        alias derived from the item's own expression (or ``expression`` when
+        the item has none).
+        """
+        if return_item.alias:
+            return return_item.alias
+        if return_item.property:
+            return return_item.property
+        return self._default_expression_alias(
+            return_item.symbolic_name,
+            return_item.expression or expression,
+        )
+
+    def _return_expression(self, return_item: ReturnItem) -> str:
+        """Render the SQL expression for a return item.
+
+        * No expression text: the variable/property value expression.
+        * Expression equal to a known variable name: that element's id.
+        * Aggregate function item: the value expression of its argument.
+        * Anything else: the translated expression text.
+        """
+        if not return_item.expression:
+            return self._value_expression(return_item.symbolic_name, return_item.property)
+
+        expression = return_item.expression.strip()
+        name = return_item.symbolic_name
+        if name and expression == name and name in self._var_kinds:
+            return self._element_id_expression(name)
+
+        function_name = return_item.function_name
+        if function_name and function_name.upper() in self.AGGREGATE_FUNCTIONS:
+            return self._value_expression(name, return_item.property)
+        return self._translate_sql_expression(expression)
+
+    def _value_expression(self, symbolic_name: str, property_name: str) -> str:
+        """Render a variable, or one of its properties, as a SQL value.
+
+        ``*`` is returned as-is and a variable without a property becomes its
+        element id. For a property:
+
+        * an undeclared ``label`` property becomes a label expression;
+        * an edge identity property becomes ``EDGE_ID(...)``;
+        * otherwise it becomes ``<sql variable>."<canonical property>"``,
+          qualified by the redirect target variable when one exists.
+        """
+        variable = self._strip_distinct_prefix(symbolic_name)
+        if variable == "*":
+            return "*"
+
+        owner = self._var_sql_names.get(variable, variable)
+        if not property_name:
+            return self._element_id_expression(variable)
+
+        undeclared_label = property_name.lower() == "label" and not self._property_type(
+            variable, property_name
+        )
+        if undeclared_label:
+            return self._label_value_expression(variable)
+        if self._is_edge_identity_property(variable, property_name):
+            return f"EDGE_ID({owner})"
+
+        # The redirect target is resolved from the name as written, before
+        # the property name is canonicalized.
+        redirect = self._property_redirect_target(variable, property_name)
+        canonical = self._canonical_property_name(variable, property_name)
+        quoted = OracleNameSanitizer.quote(canonical, fallback="PROP")
+        if redirect:
+            owner = self._var_sql_names.get(redirect, redirect)
+        return f"{owner}.{quoted}"
+
+    def _is_edge_identity_property(self, variable: str, property_name: str) -> bool:
+        """Tell whether ``variable.property_name`` denotes an edge's own identity.
+
+        True when the source variable is an edge, the property is ``identity``
+        or ``id`` (case-insensitive), and the variable's primary key is
+        ``EDGE_ID``.
+        """
+        source = self._source_variable(variable)
+        return (
+            self._var_kinds.get(source) == "edge"
+            and property_name.lower() in {"identity", "id"}
+            and self._primary_key_for_variable(source).upper() == "EDGE_ID"
+        )
+
+    def _label_value_expression(self, variable: str) -> str:
+        """Render the pseudo-property ``label`` of ``variable`` as a SQL value.
+
+        With one possible label the result is that label as a string literal.
+        With several it is a ``CASE`` over ``IS LABELED`` predicates. In both
+        cases single quotes in the label are doubled, so the literal stays
+        well-formed.
+
+        Raises:
+            ValueError: When no label can be resolved for ``variable``.
+        """
+        labels = self._possible_graph_labels(variable)
+        if not labels:
+            raise ValueError(f'Cannot resolve pseudo-property "label" for "{variable}".')
+        if len(labels) == 1:
+            # Escape here too, matching the multi-label branch below.
+            return f"'{self._sql_string_value(labels[0])}'"
+
+        sql_variable = self._var_sql_names.get(variable, variable)
+        branches = [
+            f"WHEN {sql_variable} IS LABELED {OracleNameSanitizer.quote(label, fallback='LABEL')} "
+            f"THEN '{self._sql_string_value(label)}'"
+            for label in labels
+        ]
+        return "CASE " + " ".join(branches) + " END"
+
+    def _sql_string_value(self, value: str) -> str:
+        """Return ``value`` as text with every single quote doubled."""
+        text = str(value)
+        return "''".join(text.split("'"))
+
+    def _element_id_expression(self, variable: str) -> str:
+        """Return ``EDGE_ID(x)`` for edge variables and ``VERTEX_ID(x)`` otherwise.
+
+        ``x`` is the SQL name registered for ``variable``, or the variable
+        itself when none is registered.
+        """
+        sql_variable = self._var_sql_names.get(variable, variable)
+        id_function = "EDGE_ID" if self._var_kinds.get(variable) == "edge" else "VERTEX_ID"
+        return f"{id_function}({sql_variable})"
+
+    def _translate_compare_expression(self, compare_expression: CompareExpression) -> str:
+        """Render a comparison as SQL.
+
+        When the comparison carries a ``raw_expression`` that text is
+        translated directly. Otherwise the result is
+        ``<variable>[."property"] <operator> <value>``. An unknown
+        comparison type falls back to ``=``.
+        """
+        raw_expression = getattr(compare_expression, "raw_expression", "")
+        if raw_expression:
+            return self._translate_sql_expression(raw_expression)
+
+        variable = compare_expression.symbolic_name
+        property_suffix = ""
+        if compare_expression.property:
+            canonical = self._canonical_property_name(variable, compare_expression.property)
+            property_suffix = f".{OracleNameSanitizer.quote(canonical, fallback='PROP')}"
+
+        operators = {
+            "equal": "=",
+            "neq": "<>",
+            "less": "<",
+            "greater": ">",
+            "leq": "<=",
+            "geq": ">=",
+        }
+        operator = operators.get(compare_expression.comparison_type, "=")
+        sql_variable = self._var_sql_names.get(variable, variable)
+        return f"{sql_variable}{property_suffix} {operator} {compare_expression.comparison_value}"
+
+ def _translate_sql_expression(self, expression: str) -> str:
+ protected, literals = self._protect_string_literals(expression)
+ protected = re.sub(
+ r"\bdate\s*\(\s*__SQL_LITERAL_(\d+)__\s*\)",
+ lambda match: f"DATE __SQL_LITERAL_{match.group(1)}__",
+ protected,
+ flags=re.IGNORECASE,
+ )
+ protected = re.sub(
+ r"\bdate\s*\(\s*\)",
+ "TRUNC(CURRENT_DATE)",
+ protected,
+ flags=re.IGNORECASE,
+ )
+ if re.search(r"\b(?:NOT\s+)?EXISTS\s*\(\s*\(", protected, flags=re.IGNORECASE):
+ raise ValueError("Cypher pattern predicates require SQL CTE rewrite support.")
+ protected = self._translate_date_property_extractors(protected)
+ protected = self._translate_property_extractors(protected)
+ protected = self._translate_label_predicates(protected)
+ protected = self._translate_date_function_property_calls(protected)
+ protected = re.sub(
+ r"\b([A-Za-z_][A-Za-z0-9_]*)\.([A-Za-z_][A-Za-z0-9_$#-]*)\b",
+ self._translate_property_reference_match,
+ protected,
+ )
+ protected = self._translate_chained_comparisons(protected)
+ protected = self._translate_string_predicates(protected)
+ protected = re.sub(
+ r"\bdate\s*\(([^()]+)\)",
+ r"CAST(\1 AS DATE)",
+ protected,
+ flags=re.IGNORECASE,
+ )
+ protected = self._coerce_typed_property_comparisons(protected)
+ protected = self._coerce_date_function_property_comparisons(protected)
+ protected = self._translate_element_id_comparisons(protected)
+ protected = self._translate_id_function_calls(protected)
+ protected = self._translate_element_comparisons(protected)
+ protected = self._translate_split_size(protected, literals)
+ protected = self._translate_modulo(protected)
+ protected = re.sub(r"(? str:
+ operand = (
+ r'(?:[A-Za-z_][A-Za-z0-9_]*\."[^"]+"|'
+ r"DATE\s+__SQL_LITERAL_\d+__|"
+ r"__SQL_LITERAL_\d+__|"
+ r"-?\d+(?:\.\d+)?)"
+ )
+ return re.sub(
+ rf"(?P{operand})\s*(?P<=|<|>=|>)\s*"
+ rf"(?P{operand})\s*(?P<=|<|>=|>)\s*"
+ rf"(?P{operand})",
+ lambda match: (
+ f"({match.group('left')} {match.group('left_op')} {match.group('middle')} "
+ f"AND {match.group('middle')} {match.group('right_op')} {match.group('right')})"
+ ),
+ expression,
+ )
+
+    def _translate_property_reference_match(self, match: re.Match) -> str:
+        """``re.sub`` callback that rewrites one ``variable.property`` match.
+
+        Group 1 is the variable and group 2 the property. Matches on unknown
+        variables are returned untouched. Otherwise the same rules as for
+        value expressions apply: label pseudo-property, edge identity, then
+        the quoted canonical property on the (possibly redirected) variable.
+        """
+        variable, raw_property = match.group(1), match.group(2)
+        if variable not in self._var_kinds:
+            return match.group(0)
+
+        undeclared_label = raw_property.lower() == "label" and not self._property_type(
+            variable, raw_property
+        )
+        if undeclared_label:
+            return self._label_value_expression(variable)
+        if self._is_edge_identity_property(variable, raw_property):
+            return f"EDGE_ID({self._var_sql_names.get(variable, variable)})"
+
+        redirect = self._property_redirect_target(variable, raw_property)
+        canonical = self._canonical_property_name(variable, raw_property)
+        owner = redirect if redirect else variable
+        return (
+            f"{self._var_sql_names.get(owner, owner)}."
+            f"{OracleNameSanitizer.quote(canonical, fallback='PROP')}"
+        )
+
+ def _translate_label_predicates(self, expression: str) -> str:
+ def replace(match: re.Match) -> str:
+ variable = match.group("variable")
+ label = match.group("label")
+ if variable not in self._var_kinds:
+ return match.group(0)
+ label_map = (
+ self.edge_label_map
+ if self._var_kinds.get(variable) == "edge"
+ else self.node_label_map
+ )
+ labels = self._label_map_targets(label, label_map)
+ declared_labels = self._possible_graph_labels(variable)
+ if declared_labels and not (set(labels) & set(declared_labels)):
+ return "(1 = 0)"
+ sql_variable = self._var_sql_names.get(variable, variable)
+ predicates = [
+ f"{sql_variable} IS LABELED {OracleNameSanitizer.quote(item, fallback='LABEL')}"
+ for item in labels
+ ]
+ return "(" + " OR ".join(predicates) + ")"
+
+ return re.sub(
+ r"(?[A-Za-z_][A-Za-z0-9_]*)\s*:"
+ r"\s*`?(?P[A-Za-z_][A-Za-z0-9_$#-]*)`?",
+ replace,
+ expression,
+ )
+
+    def _translate_date_function_property_calls(self, expression: str) -> str:
+        """Rewrite ``date(variable.property)`` calls.
+
+        String-typed properties are parsed with ``TO_DATE``; every other
+        property is wrapped in ``CAST(... AS DATE)``.
+        """
+
+        def replace(match: re.Match) -> str:
+            variable = match.group("variable")
+            property_name = self._canonical_property_name(variable, match.group("property"))
+            property_ref = (
+                f"{self._var_sql_names.get(variable, variable)}."
+                f"{OracleNameSanitizer.quote(property_name, fallback='PROP')}"
+            )
+            if self._is_string_type(self._property_type(variable, property_name)):
+                return self._to_date_expression(property_ref)
+            return f"CAST({property_ref} AS DATE)"
+
+        # Group names are required: replace() reads them by name.
+        return re.sub(
+            r"\bdate\s*\(\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+            r"(?P<property>[A-Za-z_][A-Za-z0-9_$#-]*)\s*\)",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _to_date_expression(self, expression: str) -> str:
+        """Wrap ``expression`` in a ``TO_DATE`` call that yields NULL on bad input.
+
+        The format mask is ``YYYY-MM-DD``.
+        """
+        return "TO_DATE(" + f"{expression}" + " DEFAULT NULL ON CONVERSION ERROR, 'YYYY-MM-DD')"
+
+    def _translate_string_predicates(self, expression: str) -> str:
+        """Translate ``STARTS WITH`` / ``ENDS WITH`` / ``CONTAINS`` predicates.
+
+        * ``x STARTS WITH lit`` becomes ``x LIKE lit || '%'``
+        * ``x ENDS WITH lit`` becomes ``x LIKE '%' || lit``
+        * ``x CONTAINS lit`` becomes ``INSTR(x, lit) > 0``
+
+        The left operand may be a quoted property reference, a bare
+        identifier, or a ``TO_DATE`` / ``LOWER`` / ``UPPER`` call. The right
+        operand must be a literal placeholder or a quoted string.
+        """
+        operand = (
+            r'(?:[A-Za-z_][A-Za-z0-9_]*\."[^"]+"|'
+            r"[A-Za-z_][A-Za-z0-9_]*|"
+            r"TO_DATE\([^)]+\)|"
+            r"LOWER\([^)]+\)|"
+            r"UPPER\([^)]+\))"
+        )
+        literal = r"(?:__SQL_LITERAL_\d+__|'[^']*')"
+        # Group names are required: the callbacks read 'left' and 'right'.
+        expression = re.sub(
+            rf"(?P<left>{operand})\s+STARTS\s+WITH\s+(?P<right>{literal})",
+            lambda match: f"{match.group('left')} LIKE {match.group('right')} || '%'",
+            expression,
+            flags=re.IGNORECASE,
+        )
+        expression = re.sub(
+            rf"(?P<left>{operand})\s+ENDS\s+WITH\s+(?P<right>{literal})",
+            lambda match: f"{match.group('left')} LIKE '%' || {match.group('right')}",
+            expression,
+            flags=re.IGNORECASE,
+        )
+        return re.sub(
+            rf"(?P<left>{operand})\s+CONTAINS\s+(?P<right>{literal})",
+            lambda match: f"INSTR({match.group('left')}, {match.group('right')}) > 0",
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _translate_to_integer(self, expression: str) -> str:
+        """Replace every ``toInteger(<body>)`` with ``CAST(<body> AS INTEGER)``.
+
+        Calls are rewritten one at a time from the left until none remain.
+        Nested calls are handled on later passes. If a call has no matching
+        closing parenthesis the expression is returned as it stands.
+        """
+        opener = re.compile(r"\btoInteger\s*\(", flags=re.IGNORECASE)
+        while (found := opener.search(expression)) is not None:
+            body_start = found.end()
+            closing = self._matching_paren_index(expression, body_start - 1)
+            if closing == -1:
+                break
+            head = expression[: found.start()]
+            tail = expression[closing + 1 :]
+            expression = f"{head}CAST({expression[body_start:closing]} AS INTEGER){tail}"
+        return expression
+
+    def _translate_substring(self, expression: str) -> str:
+        """Rewrite three-argument ``substring(value, start, length)`` as ``SUBSTR``.
+
+        The start index is shifted by one. An integer literal is incremented
+        in place; any other start expression becomes ``(<start>) + 1``.
+        """
+
+        def replace(match: re.Match) -> str:
+            # Strip so a literal followed by whitespace ("0 ,") is still
+            # recognised as an integer literal.
+            start = match.group("start").strip()
+            if re.fullmatch(r"-?\d+", start):
+                start = str(int(start) + 1)
+            else:
+                start = f"({start}) + 1"
+            length = match.group("length")
+            return f"SUBSTR({match.group('value')}, {start}, {length})"
+
+        # Group names are required: replace() reads them by name.
+        return re.sub(
+            r"\bsubstring\s*\(\s*(?P<value>[^,]+)\s*,\s*"
+            r"(?P<start>[^,]+)\s*,\s*(?P<length>[^()]+)\)",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _translate_left_function(self, expression: str) -> str:
+        """Rewrite ``left(value, length)`` as ``SUBSTR(value, 1, length)``."""
+
+        def replace(match: re.Match) -> str:
+            return f"SUBSTR({match.group('value')}, 1, {match.group('length')})"
+
+        # Group names are required: replace() reads them by name.
+        return re.sub(
+            r"\bleft\s*\(\s*(?P<value>[^,]+)\s*,\s*(?P<length>[^()]+)\)",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _guard_numeric_division(self, expression: str) -> str:
+        """Wrap ``TO_NUMBER(...)`` divisors in ``NULLIF(..., 0)``.
+
+        Only a ``TO_NUMBER`` call that directly follows ``/`` is guarded.
+        """
+        divisor = re.compile(r"/\s*(TO_NUMBER\([^)]+\))", flags=re.IGNORECASE)
+
+        def guard(found: re.Match) -> str:
+            return "/ NULLIF(" + found.group(1) + ", 0)"
+
+        return divisor.sub(guard, expression)
+
+    def _translate_string_concatenation(self, expression: str) -> str:
+        """Turn ``+`` into ``||`` where an operand is clearly a string.
+
+        A ``+`` is rewritten when either side is a string-literal placeholder
+        or a ``TO_CHAR(...)`` call. The substitution repeats until the
+        expression stops changing, so chains are handled one link per pass.
+        """
+        operand = r"(?:__SQL_LITERAL_\d+__|TO_CHAR\([^)]+\)|[A-Za-z_][A-Za-z0-9_]*)"
+
+        def is_string_operand(text: str) -> bool:
+            return text.startswith("__SQL_LITERAL_") or text.upper().startswith("TO_CHAR(")
+
+        def replace(match: re.Match) -> str:
+            left, right = match.group("left"), match.group("right")
+            if is_string_operand(left) or is_string_operand(right):
+                return f"{left} || {right}"
+            return match.group(0)
+
+        # Group names are required: replace() reads 'left' and 'right'.
+        pattern = re.compile(
+            rf"(?P<left>{operand})\s*\+\s*(?P<right>{operand})",
+            flags=re.IGNORECASE,
+        )
+        previous = None
+        while previous != expression:
+            previous = expression
+            expression = pattern.sub(replace, expression)
+        return expression
+
+    def _matching_paren_index(self, expression: str, open_index: int) -> int:
+        """Return the index of the ``)`` that closes the ``(`` at ``open_index``.
+
+        Parentheses inside single-quoted text are ignored, so a literal such
+        as ``')'`` does not end the scan early. This matches the earlier
+        definition of the same name in this chunk, which a later same-named
+        ``def`` would replace.
+
+        Returns:
+            The index of the matching ``)``, or -1 when there is none.
+        """
+        depth = 0
+        in_single = False
+        for index in range(open_index, len(expression)):
+            char = expression[index]
+            if char == "'":
+                # A doubled quote toggles twice and so leaves the state as is.
+                in_single = not in_single
+                continue
+            if in_single:
+                continue
+            if char == "(":
+                depth += 1
+            elif char == ")":
+                depth -= 1
+                if depth == 0:
+                    return index
+        return -1
+
+    def _translate_date_property_extractors(self, expression: str) -> str:
+        """Rewrite ``date(variable.property).<field>`` accessors.
+
+        ``year`` / ``month`` / ``day`` become ``EXTRACT(<FIELD> FROM ...)``.
+        ``weekday`` becomes the day offset from the ISO week start (``'IW'``
+        truncation) and ``dayOfWeek`` that offset plus one. String-typed
+        properties are first parsed with ``TO_DATE``.
+        """
+
+        def replace(match: re.Match) -> str:
+            variable = match.group("variable")
+            property_name = match.group("property")
+            raw_field = match.group("field")
+            field = raw_field.upper()
+            property_ref = (
+                f"{self._var_sql_names.get(variable, variable)}."
+                f"{OracleNameSanitizer.quote(property_name, fallback='PROP')}"
+            )
+            if self._is_string_type(self._property_type(variable, property_name)):
+                property_ref = self._to_date_expression(property_ref)
+            if raw_field.lower() == "weekday":
+                return f"(TRUNC({property_ref}) - TRUNC({property_ref}, 'IW'))"
+            if raw_field.lower() == "dayofweek":
+                return f"(TRUNC({property_ref}) - TRUNC({property_ref}, 'IW') + 1)"
+            return f"EXTRACT({field} FROM {property_ref})"
+
+        # Group names are required: replace() reads them by name.
+        return re.sub(
+            r"\bdate\s*\(\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+            r"(?P<property>[A-Za-z_][A-Za-z0-9_$#-]*)\s*\)\."
+            r"(?P<field>year|month|day|weekday|dayOfWeek)\b",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _translate_property_extractors(self, expression: str) -> str:
+        """Rewrite ``variable.property.year|month|day`` as ``EXTRACT(...)``.
+
+        String-typed properties are first parsed with ``TO_DATE``.
+        """
+
+        def replace(match: re.Match) -> str:
+            variable = match.group("variable")
+            property_name = match.group("property")
+            field = match.group("field").upper()
+            property_ref = (
+                f"{self._var_sql_names.get(variable, variable)}."
+                f"{OracleNameSanitizer.quote(property_name, fallback='PROP')}"
+            )
+            if field in {"YEAR", "MONTH", "DAY"}:
+                if self._is_string_type(self._property_type(variable, property_name)):
+                    property_ref = self._to_date_expression(property_ref)
+                return f"EXTRACT({field} FROM {property_ref})"
+            return match.group(0)
+
+        # Group names are required: replace() reads them by name.
+        return re.sub(
+            r"\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+            r"(?P<property>[A-Za-z_][A-Za-z0-9_$#-]*)\."
+            r"(?P<field>year|month|day)\b",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _translate_split_size(self, expression: str, literals: List[str] | None = None) -> str:
+        """Rewrite ``size(split(body, <literal>))`` as a ``REGEXP_COUNT`` expression.
+
+        Args:
+            expression: Expression whose string literals are already replaced
+                by ``__SQL_LITERAL_<n>__`` placeholders.
+            literals: The literal texts, indexed by placeholder number.
+
+        Returns:
+            For a comma delimiter, a CASE that yields 0 for NULL or empty
+            input and the comma count plus one otherwise. For any other
+            delimiter, or one that cannot be looked up, the count of
+            non-whitespace runs.
+        """
+        literals = literals or []
+
+        def replace(match: re.Match) -> str:
+            body = match.group("body").strip()
+            literal_index = int(match.group("literal_index"))
+            delimiter = ""
+            if literal_index < len(literals):
+                delimiter = literals[literal_index].strip("'")
+            if delimiter == ",":
+                return (
+                    f"CASE WHEN {body} IS NULL OR {body} = '' THEN 0 "
+                    f"ELSE REGEXP_COUNT({body}, ',') + 1 END"
+                )
+            return f"REGEXP_COUNT({body}, '\\S+')"
+
+        # Group names are required: replace() reads them by name.
+        return re.sub(
+            r"\bsize\s*\(\s*split\s*\((?P<body>[^,]+),\s*"
+            r"__SQL_LITERAL_(?P<literal_index>\d+)__\s*\)\s*\)",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _translate_modulo(self, expression: str) -> str:
+        """Rewrite ``<left> % <number>`` as ``MOD(<left>, <number>)``.
+
+        Three left-operand shapes are handled, in this order:
+        ``CAST(... AS INTEGER)``, ``EXTRACT(...)``, and a plain or qualified
+        identifier. The right operand must be a numeric literal.
+        """
+
+        def to_mod(match: re.Match) -> str:
+            return f"MOD({match.group('left')}, {match.group('right')})"
+
+        # Group names are required: to_mod() reads 'left' and 'right'.
+        expression = re.sub(
+            r"(?P<left>CAST\([^)]+\s+AS\s+INTEGER\))\s*%\s*"
+            r"(?P<right>-?\d+(?:\.\d+)?)",
+            to_mod,
+            expression,
+            flags=re.IGNORECASE,
+        )
+        expression = re.sub(
+            r"(?P<left>EXTRACT\(.+\))\s*%\s*(?P<right>-?\d+(?:\.\d+)?)",
+            to_mod,
+            expression,
+        )
+        return re.sub(
+            r"(?P<left>[A-Za-z_][A-Za-z0-9_.$#\"]*)\s*%\s*"
+            r"(?P<right>-?\d+(?:\.\d+)?)",
+            to_mod,
+            expression,
+        )
+
+    def _coerce_typed_property_comparisons(self, expression: str) -> str:
+        """Make literals compared against string-typed properties strings too.
+
+        For a quoted property reference whose declared type is string:
+
+        * ``prop <op> DATE <literal>`` drops the ``DATE`` keyword;
+        * ``prop <op> true|false`` quotes the lower-cased boolean;
+        * ``prop <op> <number>`` quotes the number.
+
+        Comparisons on other property types are left untouched.
+        """
+
+        def replace_date(match: re.Match) -> str:
+            variable = match.group("variable")
+            property_name = match.group("property")
+            if self._is_string_type(self._property_type(variable, property_name)):
+                return (
+                    f'{variable}."{property_name}" {match.group("operator")} '
+                    f"__SQL_LITERAL_{match.group('literal')}__"
+                )
+            return match.group(0)
+
+        # Group names are required throughout: the callbacks read them by name.
+        expression = re.sub(
+            r'\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."(?P<property>[^"]+)"\s*'
+            r"(?P<operator><=|>=|<>|=|<|>)\s*DATE\s+__SQL_LITERAL_(?P<literal>\d+)__",
+            replace_date,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+        def replace_boolean(match: re.Match) -> str:
+            variable = match.group("variable")
+            property_name = match.group("property")
+            boolean = match.group("boolean").lower()
+            if self._is_string_type(self._property_type(variable, property_name)):
+                return f"{variable}.\"{property_name}\" {match.group('operator')} '{boolean}'"
+            return match.group(0)
+
+        expression = re.sub(
+            r'\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."(?P<property>[^"]+)"\s*'
+            r"(?P<operator><=|>=|<>|=|<|>)\s*(?P<boolean>true|false)\b",
+            replace_boolean,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+        def replace_number(match: re.Match) -> str:
+            variable = match.group("variable")
+            property_name = match.group("property")
+            if self._is_string_type(self._property_type(variable, property_name)):
+                return (
+                    f'{variable}."{property_name}" {match.group("operator")} '
+                    f"'{match.group('number')}'"
+                )
+            return match.group(0)
+
+        return re.sub(
+            r'\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."(?P<property>[^"]+)"\s*'
+            r"(?P<operator><=|>=|<>|=|<|>)\s*(?P<number>-?\d+(?:\.\d+)?)\b",
+            replace_number,
+            expression,
+        )
+
+    def _coerce_date_function_property_comparisons(self, expression: str) -> str:
+        """Parse a string property that is compared against a ``TO_DATE`` value.
+
+        When one side of a comparison is a ``TO_DATE(... 'YYYY-MM-DD')``
+        expression and the other is a quoted property reference of string
+        type, the property side is wrapped in ``TO_DATE`` as well. The
+        right-hand side is handled first, then the left.
+        """
+        date_expr = (
+            r"TO_DATE\([A-Za-z_][A-Za-z0-9_]*\.\"[^\"]+\" "
+            r"DEFAULT NULL ON CONVERSION ERROR, 'YYYY-MM-DD'\)"
+        )
+        # {prefix} is filled in per side, so both sides can appear in one regex.
+        property_ref = (
+            r"(?P<{prefix}_variable>[A-Za-z_][A-Za-z0-9_]*)"
+            r"\.\"(?P<{prefix}_property>[^\"]+)\""
+        )
+
+        def property_type(match: re.Match, prefix: str) -> str:
+            return self._property_type(
+                match.group(f"{prefix}_variable"),
+                match.group(f"{prefix}_property"),
+            )
+
+        def replace_right(match: re.Match) -> str:
+            right = match.group("right")
+            if self._is_string_type(property_type(match, "right")):
+                right = self._to_date_expression(right)
+            return f"{match.group('left')} {match.group('operator')} {right}"
+
+        # Group names 'left' / 'operator' / 'right' are read by the callbacks.
+        expression = re.sub(
+            rf"(?P<left>{date_expr})\s*(?P<operator><=|>=|<>|=|<|>)\s*"
+            rf"(?P<right>{property_ref.format(prefix='right')})",
+            replace_right,
+            expression,
+        )
+
+        def replace_left(match: re.Match) -> str:
+            left = match.group("left")
+            if self._is_string_type(property_type(match, "left")):
+                left = self._to_date_expression(left)
+            return f"{left} {match.group('operator')} {match.group('right')}"
+
+        return re.sub(
+            rf"(?P<left>{property_ref.format(prefix='left')})\s*"
+            rf"(?P<operator><=|>=|<>|=|<|>)\s*(?P<right>{date_expr})",
+            replace_left,
+            expression,
+        )
+
+    def _translate_element_comparisons(self, expression: str) -> str:
+        """Rewrite ``a = b`` / ``a <> b`` between two graph elements.
+
+        When both identifiers resolve to known variables of the same kind the
+        comparison becomes ``EDGE_EQUAL(a, b)`` or ``VERTEX_EQUAL(a, b)``,
+        negated with ``NOT`` for ``<>``. Everything else is left untouched.
+        """
+
+        def replace(match: re.Match) -> str:
+            left = match.group("left")
+            right = match.group("right")
+            operator = match.group("operator")
+            left_source = self._source_variable(left)
+            right_source = self._source_variable(right)
+            if left_source not in self._var_kinds or right_source not in self._var_kinds:
+                return match.group(0)
+            if self._var_kinds[left_source] != self._var_kinds[right_source]:
+                return match.group(0)
+            function_name = (
+                "EDGE_EQUAL" if self._var_kinds[left_source] == "edge" else "VERTEX_EQUAL"
+            )
+            comparison = f"{function_name}({left}, {right})"
+            return comparison if operator == "=" else f"NOT {comparison}"
+
+        # Group names are required: replace() reads them by name.
+        return re.sub(
+            r"\b(?P<left>[A-Za-z_][A-Za-z0-9_]*)\s*(?P<operator>=|<>)\s*"
+            r"(?P<right>[A-Za-z_][A-Za-z0-9_]*)\b",
+            replace,
+            expression,
+        )
+
+    def _translate_element_id_comparisons(self, expression: str) -> str:
+        """Rewrite ``id(a) = id(b)`` / ``id(a) <> id(b)`` as element-equality calls.
+
+        The rewrite applies only when both variables resolve, through
+        ``_source_variable``, to entries of ``_var_kinds`` with the same kind.
+        The ``id`` function name is matched case-insensitively.
+
+        Returns:
+            ``expression`` with qualifying comparisons replaced by
+            ``EDGE_EQUAL(a, b)`` or ``VERTEX_EQUAL(a, b)``, prefixed with
+            ``NOT`` when the operator is ``<>``.
+        """
+
+        def replace(match: re.Match) -> str:
+            left = match.group("left")
+            right = match.group("right")
+            operator = match.group("operator")
+            left_source = self._source_variable(left)
+            right_source = self._source_variable(right)
+            if left_source not in self._var_kinds or right_source not in self._var_kinds:
+                return match.group(0)
+            # Mixed vertex/edge comparisons have no element-equality form.
+            if self._var_kinds[left_source] != self._var_kinds[right_source]:
+                return match.group(0)
+            function_name = (
+                "EDGE_EQUAL" if self._var_kinds[left_source] == "edge" else "VERTEX_EQUAL"
+            )
+            comparison = f"{function_name}({left}, {right})"
+            return comparison if operator == "=" else f"NOT {comparison}"
+
+        # The group names must match the match.group() lookups in replace().
+        return re.sub(
+            r"\bid\s*\(\s*(?P<left>[A-Za-z_][A-Za-z0-9_]*)\s*\)\s*"
+            r"(?P<operator>=|<>)\s*"
+            r"id\s*\(\s*(?P<right>[A-Za-z_][A-Za-z0-9_]*)\s*\)",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+    def _translate_id_function_calls(self, expression: str) -> str:
+        """Replace ``id(variable)`` calls with the element-id expression.
+
+        Only variables present in ``_var_kinds`` are rewritten, using
+        ``_element_id_expression``; calls on any other name are left as
+        written. The ``id`` function name is matched case-insensitively.
+        """
+
+        def replace(match: re.Match) -> str:
+            variable = match.group("variable")
+            if variable not in self._var_kinds:
+                return match.group(0)
+            return self._element_id_expression(variable)
+
+        # The group name must match the match.group() lookup in replace().
+        return re.sub(
+            r"\bid\s*\(\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\s*\)",
+            replace,
+            expression,
+            flags=re.IGNORECASE,
+        )
+
+ def _protect_string_literals(self, expression: str) -> tuple[str, List[str]]:
+ """Swap single-quoted literals in ``expression`` for indexed placeholders.
+
+ Returns the rewritten text together with the captured literals; the
+ placeholder ``__SQL_LITERAL_N__`` stands for ``literals[N]``.
+ """
+ literals: List[str] = []
+ result: List[str] = []
+ i = 0
+ while i < len(expression):
+ # Text outside a literal is copied through one character at a time.
+ if expression[i] != "'":
+ result.append(expression[i])
+ i += 1
+ continue
+ # Opening quote found: scan forward for the closing quote.
+ start = i
+ i += 1
+ while i < len(expression):
+ # A backslash-escaped quote stays inside the literal.
+ if expression[i] == "\\" and i + 1 < len(expression) and expression[i + 1] == "'":
+ i += 2
+ continue
+ # A doubled quote is an escaped quote, not a terminator.
+ if expression[i] == "'" and i + 1 < len(expression) and expression[i + 1] == "'":
+ i += 2
+ continue
+ if expression[i] == "'":
+ i += 1
+ break
+ i += 1
+ # The scan above ends on the closing quote or at end of input.
+ placeholder = f"__SQL_LITERAL_{len(literals)}__"
+ # The stored literal uses doubled quotes in place of backslash escapes.
+ literals.append(expression[start:i].replace("\\'", "''"))
+ result.append(placeholder)
+ return "".join(result), literals
+
+    def _restore_string_literals(self, expression: str, literals: List[str]) -> str:
+        """Put protected string literals back in place of their placeholders.
+
+        Every ``__SQL_LITERAL_N__`` token whose index ``N`` is valid for
+        ``literals`` is replaced in a single pass, so placeholder-looking text
+        inside a restored literal is never substituted a second time. Tokens
+        with an out-of-range index are left untouched.
+        """
+
+        def restore(match: re.Match) -> str:
+            digits = match.group(1)
+            index = int(digits)
+            # Only the canonical spelling produced by f"{index}" is a placeholder.
+            if index < len(literals) and digits == str(index):
+                return literals[index]
+            return match.group(0)
+
+        return re.sub(r"__SQL_LITERAL_([0-9]+)__", restore, expression)
+
+    def _default_expression_alias(self, symbolic_name: str, expression: str) -> str:
+        """Derive a fallback column alias for a projected expression.
+
+        ``*`` maps to ``COUNT_VALUE`` and path variables keep their own name.
+        A graph variable projected as itself or as its element-id expression
+        gets a ``_VALUE`` suffix. Anything else is reduced to
+        ``[A-Za-z0-9_]`` characters, falling back to ``VALUE`` when nothing
+        remains.
+        """
+        if symbolic_name == "*":
+            return "COUNT_VALUE"
+        if symbolic_name in self._path_variables:
+            return symbolic_name
+
+        name = symbolic_name
+        if self._has_distinct_prefix(name):
+            name = self._strip_distinct_prefix(name)
+
+        if name in self._var_kinds:
+            bare_forms = (name, self._element_id_expression(name))
+            if expression.strip() in bare_forms:
+                return f"{name}_VALUE"
+
+        sanitized = re.sub(r"[^A-Za-z0-9_]+", "_", name or expression)
+        return sanitized.strip("_") or "VALUE"
+
+ def _outer_group_order_and_paging(
+ self,
+ return_body: ReturnBody,
+ aggregate_query: bool = False,
+ ) -> str:
+ """Build the trailing GROUP BY / ORDER BY / OFFSET / FETCH clauses.
+
+ Each clause that applies is appended on its own line (prefixed with a
+ newline); the result is an empty string when none applies.
+ """
+ suffix = ""
+ if aggregate_query:
+ group_aliases = []
+ return_aliases = self._resolved_return_aliases(return_body)
+ projected_aliases = {OracleNameSanitizer.alias(alias) for alias in return_aliases}
+ # Every non-aggregate projection becomes a grouping key.
+ for item, resolved_alias in zip(
+ return_body.return_item_list,
+ return_aliases,
+ strict=True,
+ ):
+ if not self._is_aggregate_item(item):
+ group_aliases.append(OracleNameSanitizer.alias(resolved_alias))
+ # Sort keys are grouped too, unless they are already projected, are
+ # aggregates themselves, or are written in terms of return aliases.
+ for sort_item in return_body.sort_item_list:
+ sort_alias = OracleNameSanitizer.alias(self._sort_alias(sort_item, return_body))
+ if (
+ sort_alias in projected_aliases
+ or self._is_aggregate_sort(sort_item)
+ or self._sort_expression_uses_return_aliases(sort_item, return_body)
+ ):
+ continue
+ group_aliases.append(sort_alias)
+ if group_aliases:
+ # dict.fromkeys drops duplicate aliases while keeping first-seen order.
+ suffix += "\nGROUP BY " + ", ".join(dict.fromkeys(group_aliases))
+ if return_body.sort_item_list:
+ suffix += "\nORDER BY " + ", ".join(
+ self._translate_sort_item(item, return_body) for item in return_body.sort_item_list
+ )
+ # -1 is treated as "not set" for both skip and limit.
+ if return_body.skip != -1:
+ suffix += f"\nOFFSET {return_body.skip} ROWS"
+ if return_body.limit != -1:
+ suffix += f"\nFETCH FIRST {return_body.limit} ROWS ONLY"
+ return suffix
+
+    def _has_distinct_prefix(self, value: str) -> bool:
+        """Tell whether ``value`` begins with a ``DISTINCT`` keyword and whitespace.
+
+        Matching ignores case and leading whitespace; ``None`` counts as empty.
+        """
+        text = value or ""
+        return re.match(r"^\s*DISTINCT\s+", text, flags=re.IGNORECASE) is not None
+
+    def _strip_distinct_prefix(self, value: str) -> str:
+        """Remove one leading ``DISTINCT`` keyword (any case) from ``value``.
+
+        ``None`` is treated as an empty string; text without the prefix is
+        returned unchanged.
+        """
+        prefix = re.compile(r"^\s*DISTINCT\s+", flags=re.IGNORECASE)
+        return prefix.sub("", value or "", count=1)
+
+    def _hidden_sort_aliases(
+        self,
+        return_body: ReturnBody | None,
+        aggregate_query: bool,
+    ) -> List[str]:
+        """List sort aliases that are not among the projected return aliases.
+
+        Returns an empty list when there is no return body or the query
+        aggregates. Order follows ``sort_item_list`` and duplicates are kept.
+        """
+        if aggregate_query or return_body is None:
+            return []
+        visible = set(self._resolved_return_aliases(return_body))
+        sort_aliases = (
+            OracleNameSanitizer.alias(self._sort_alias(sort_item, return_body))
+            for sort_item in return_body.sort_item_list
+        )
+        return [alias for alias in sort_aliases if alias not in visible]
+
+ def _resolved_return_aliases(self, return_body: ReturnBody) -> List[str]:
+ """Return one sanitized alias per return item, de-duplicating clashes.
+
+ Aliases that are unique after sanitizing are kept. Items whose alias
+ clashes fall back to their default expression alias; the first use of a
+ fallback is bare and later uses get a ``_2``, ``_3``, ... suffix.
+ """
+ raw_aliases = [
+ self._return_alias(item, self._return_expression(item))
+ for item in return_body.return_item_list
+ ]
+ # Clashes are counted on the sanitized form, which is what gets emitted.
+ alias_counts = Counter(OracleNameSanitizer.alias(alias) for alias in raw_aliases)
+ # Number of times each fallback base has been handed out so far.
+ seen: Dict[str, int] = {}
+ resolved = []
+ for item, raw_alias in zip(return_body.return_item_list, raw_aliases, strict=True):
+ alias = OracleNameSanitizer.alias(raw_alias)
+ if alias_counts[alias] == 1:
+ resolved.append(alias)
+ continue
+ expression = self._return_expression(item)
+ base = OracleNameSanitizer.alias(
+ self._default_expression_alias(item.symbolic_name, expression)
+ )
+ index = seen.get(base, 0)
+ seen[base] = index + 1
+ # NOTE(review): the suffixed name is not checked against aliases
+ # already placed in `resolved` - confirm a collision cannot occur.
+ resolved.append(base if index == 0 else f"{base}_{index + 1}")
+ return resolved
+
+    def _has_aggregate(self, return_body: ReturnBody | None) -> bool:
+        """Tell whether any return item of ``return_body`` is an aggregate.
+
+        A missing return body has no aggregates.
+        """
+        if return_body is None:
+            return False
+        for item in return_body.return_item_list:
+            if self._is_aggregate_item(item):
+                return True
+        return False
+
+    def _is_aggregate_item(self, item: ReturnItem | SortItem) -> bool:
+        """Tell whether ``item`` aggregates.
+
+        True when its function name (upper-cased) is in
+        ``AGGREGATE_FUNCTIONS`` or when it is a complex aggregate expression.
+        """
+        if item.function_name.upper() in self.AGGREGATE_FUNCTIONS:
+            return True
+        return self._is_complex_aggregate_item(item)
+
+    def _is_complex_aggregate_item(self, item: ReturnItem | SortItem) -> bool:
+        """Tell whether ``item`` contains an aggregate call without being a simple one.
+
+        Requires a non-empty expression that contains an aggregate function
+        call and is not a simple aggregate expression.
+        """
+        if not item.expression:
+            return False
+        if not self._contains_aggregate_function(item.expression):
+            return False
+        return not self._is_simple_aggregate_expression(item)
+
+    def _is_aggregate_sort(self, item: SortItem) -> bool:
+        """Tell whether a sort item aggregates; uses the same test as return items."""
+        aggregates = self._is_aggregate_item(item)
+        return aggregates
+
+    def _contains_aggregate_function(self, expression: str) -> bool:
+        """Tell whether ``expression`` contains a call to an aggregate function.
+
+        Looks for COUNT, AVG, SUM, MIN, MAX or COLLECT (any case) as a whole
+        word followed by an opening parenthesis. ``None`` counts as empty.
+        """
+        text = expression or ""
+        found = re.search(
+            r"\b(?:COUNT|AVG|SUM|MIN|MAX|COLLECT)\s*\(",
+            text,
+            flags=re.IGNORECASE,
+        )
+        return found is not None
+
+ def _is_simple_aggregate_expression(self, item: ReturnItem | SortItem) -> bool:
+ """Tell whether ``item`` is a single aggregate call on a plain argument.
+
+ The expression must be exactly ``function_name(...)`` with an argument
+ that, after an optional DISTINCT prefix, is ``*``, a variable, or a
+ ``variable.property`` reference.
+ """
+ if item.function_name.upper() not in self.AGGREGATE_FUNCTIONS:
+ return False
+ expression = (item.expression or "").strip()
+ # The whole expression must look like `function_name( ... )`.
+ if not re.fullmatch(
+ rf"\s*{re.escape(item.function_name)}\s*\(.+\)\s*",
+ expression,
+ flags=re.IGNORECASE | re.DOTALL,
+ ):
+ return False
+ open_paren = expression.find("(")
+ # Reject shapes like `f(a) + f(b)`, where the first "(" closes early.
+ # Presumably _matching_paren_index returns the index of the ")" that
+ # closes the paren at `open_paren` - verify against its definition.
+ if self._matching_paren_index(expression, open_paren) != len(expression) - 1:
+ return False
+ argument = expression[open_paren + 1 : -1].strip()
+ argument = self._strip_distinct_prefix(argument)
+ return bool(
+ re.fullmatch(
+ r"\*|[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_$#-]*)?",
+ argument,
+ )
+ )
+
+    def _next_node_var(self) -> str:
+        """Return the next unused auto-generated node variable (``n1``, ``n2``, ...).
+
+        Advances ``_auto_node_index`` past every name that is already in
+        ``_var_kinds`` or ``_reserved_variables``.
+        """
+        while True:
+            self._auto_node_index += 1
+            candidate = f"n{self._auto_node_index}"
+            taken = candidate in self._var_kinds or candidate in self._reserved_variables
+            if not taken:
+                return candidate
+
+    def _next_edge_var(self) -> str:
+        """Return the next unused auto-generated edge variable (``e1``, ``e2``, ...).
+
+        Advances ``_auto_edge_index`` past every name that is already in
+        ``_var_kinds`` or ``_reserved_variables``.
+        """
+        while True:
+            self._auto_edge_index += 1
+            candidate = f"e{self._auto_edge_index}"
+            taken = candidate in self._var_kinds or candidate in self._reserved_variables
+            if not taken:
+                return candidate
diff --git a/app/impl/oracle_sqlpgq/utils/__init__.py b/app/impl/oracle_sqlpgq/utils/__init__.py
new file mode 100644
index 0000000..9c29ae2
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/utils/__init__.py
@@ -0,0 +1,2 @@
+"""Oracle SQL/PGQ helper utilities."""
+
diff --git a/app/impl/oracle_sqlpgq/utils/sqlpgq.py b/app/impl/oracle_sqlpgq/utils/sqlpgq.py
new file mode 100644
index 0000000..225fbf9
--- /dev/null
+++ b/app/impl/oracle_sqlpgq/utils/sqlpgq.py
@@ -0,0 +1,183 @@
+import re
+from typing import Any, Dict, Iterable, List, Tuple
+
+# Upper-case words that OracleNameSanitizer.alias will not emit verbatim:
+# an alias whose upper-cased form is listed here gets a "_VALUE" suffix.
+RESERVED_WORDS = {
+ "ACCESS",
+ "ADD",
+ "ALL",
+ "ALTER",
+ "AND",
+ "AS",
+ "ASC",
+ "BY",
+ "CHECK",
+ "COLUMN",
+ "COMMENT",
+ "CREATE",
+ "DATE",
+ "DELETE",
+ "DESC",
+ "DISTINCT",
+ "DROP",
+ "EDGE",
+ "ENFORCED",
+ "FETCH",
+ "FROM",
+ "GRAPH",
+ "GROUP",
+ "KEY",
+ "LABEL",
+ "MATCH",
+ "NODE",
+ "NOT",
+ "NULL",
+ "OFFSET",
+ "OR",
+ "ORDER",
+ "PRIMARY",
+ "PROPERTY",
+ "REFERENCES",
+ "RESOURCE",
+ "SELECT",
+ "SIZE",
+ "TABLE",
+ "USER",
+ "VERTEX",
+ "WHERE",
+}
+
+
+class OracleNameSanitizer:
+    """Turn framework labels and property names into Oracle-safe SQL identifiers."""
+
+    _invalid = re.compile(r"[^A-Za-z0-9_$#]")
+
+    @classmethod
+    def clean(cls, name: str, fallback: str = "X") -> str:
+        """Return ``name`` limited to ``[A-Za-z0-9_$#]`` and at most 128 characters.
+
+        Disallowed characters become underscores and runs of underscores are
+        collapsed. An empty result is replaced by ``fallback``, and a result
+        starting with a digit is prefixed with ``fallback`` and an underscore.
+        """
+        raw = str(name or "").strip()
+        collapsed = re.sub(r"_+", "_", cls._invalid.sub("_", raw))
+        candidate = collapsed or fallback
+        if candidate[0].isdigit():
+            candidate = f"{fallback}_{candidate}"
+        return candidate[:128]
+
+    @classmethod
+    def quote(cls, name: str, fallback: str = "X") -> str:
+        """Return the cleaned name wrapped in double quotes."""
+        cleaned = cls.clean(name, fallback=fallback)
+        escaped = cleaned.replace('"', '""')
+        return '"' + escaped + '"'
+
+    @classmethod
+    def alias(cls, name: str, fallback: str = "COL") -> str:
+        """Return the cleaned name, suffixed with ``_VALUE`` if it is a reserved word."""
+        cleaned = cls.clean(name, fallback=fallback)
+        suffix = "_VALUE" if cleaned.upper() in RESERVED_WORDS else ""
+        return f"{cleaned}{suffix}"[:128]
+
+
+class OracleTypeMapper:
+    """Translate framework/TuGraph-like scalar type names into Oracle column types."""
+
+    TYPE_MAP = {
+        "BOOL": "NUMBER(1)",
+        "BOOLEAN": "NUMBER(1)",
+        "INT8": "NUMBER(3)",
+        "INT16": "NUMBER(5)",
+        "INT32": "NUMBER(10)",
+        "INT64": "NUMBER(19)",
+        "INTEGER": "NUMBER(10)",
+        "LONG": "NUMBER(19)",
+        "FLOAT": "BINARY_FLOAT",
+        "DOUBLE": "BINARY_DOUBLE",
+        "DATE": "DATE",
+        "DATETIME": "TIMESTAMP",
+        "TIMESTAMP": "TIMESTAMP",
+        "STRING": "VARCHAR2(4000)",
+        "TEXT": "CLOB",
+        "BLOB": "BLOB",
+    }
+
+    @classmethod
+    def to_oracle(cls, type_name: str) -> str:
+        """Return the Oracle type for ``type_name`` (case-insensitive).
+
+        Empty or unknown names map to ``VARCHAR2(4000)``.
+        """
+        key = str(type_name or "STRING").upper()
+        return cls.TYPE_MAP.get(key, "VARCHAR2(4000)")
+
+
+def split_sql_statements(script: str) -> List[str]:
+    """Split a simple SQL script on semicolons that sit outside quoted text.
+
+    Single- and double-quoted spans are tracked, and a doubled single quote
+    is consumed as a pair without changing the quote state. Statements are
+    returned stripped, without the terminating semicolon; empty statements
+    are dropped.
+    """
+    pieces: List[str] = []
+    buffer: List[str] = []
+    single_open = False
+    double_open = False
+    pos, end = 0, len(script)
+    while pos < end:
+        char = script[pos]
+        pos += 1
+        buffer.append(char)
+        if char == "'" and not double_open:
+            if pos < end and script[pos] == "'":
+                # Doubled quote: keep both characters, quote state unchanged.
+                buffer.append(script[pos])
+                pos += 1
+            else:
+                single_open = not single_open
+        elif char == '"' and not single_open:
+            double_open = not double_open
+        elif char == ";" and not (single_open or double_open):
+            finished = "".join(buffer).strip().rstrip(";").strip()
+            if finished:
+                pieces.append(finished)
+            buffer = []
+
+    remainder = "".join(buffer).strip()
+    if remainder:
+        pieces.append(remainder)
+    return pieces
+
+
+def is_balanced_sql(text: str) -> bool:
+    """Check that parentheses and quotes in ``text`` are balanced.
+
+    Parentheses inside single- or double-quoted spans are ignored, and a
+    doubled single quote is skipped as a pair. Returns False on a closing
+    parenthesis with no opener, an unclosed parenthesis, or an unclosed quote.
+    """
+    open_parens = 0
+    single_open = False
+    double_open = False
+    pos, end = 0, len(text)
+    while pos < end:
+        char = text[pos]
+        pos += 1
+        if char == "'" and not double_open:
+            if pos < end and text[pos] == "'":
+                pos += 1  # doubled quote: skip the pair, state unchanged
+            else:
+                single_open = not single_open
+        elif char == '"' and not single_open:
+            double_open = not double_open
+        elif single_open or double_open:
+            continue
+        elif char == "(":
+            open_parens += 1
+        elif char == ")":
+            if open_parens == 0:
+                return False
+            open_parens -= 1
+    return open_parens == 0 and not (single_open or double_open)
+
+
+def validate_graph_table_query(query: str) -> bool:
+    """Run a shallow structural check on a GRAPH_TABLE query.
+
+    After collapsing whitespace and upper-casing, the query must start with
+    ``SELECT`` or ``WITH`` and contain ``FROM GRAPH_TABLE``, `` MATCH `` and
+    `` COLUMNS ``. The original text must also be balanced per
+    ``is_balanced_sql``.
+    """
+    collapsed = " ".join(str(query or "").strip().split()).upper()
+    if not collapsed.startswith(("SELECT ", "WITH ")):
+        return False
+    for token in ("FROM GRAPH_TABLE", " MATCH ", " COLUMNS "):
+        if token not in collapsed:
+            return False
+    return is_balanced_sql(query)
+
+
+def validate_property_graph_ddl(ddl: str) -> bool:
+    """Run a shallow structural check on a CREATE PROPERTY GRAPH statement.
+
+    After collapsing whitespace and upper-casing, the text must contain
+    ``CREATE``, ``PROPERTY GRAPH``, ``VERTEX TABLES`` and ``EDGE TABLES``.
+    The original text must also be balanced per ``is_balanced_sql``.
+    """
+    collapsed = " ".join(str(ddl or "").strip().split()).upper()
+    for token in ("CREATE", "PROPERTY GRAPH", "VERTEX TABLES", "EDGE TABLES"):
+        if token not in collapsed:
+            return False
+    return is_balanced_sql(ddl)
+
+
+def property_list(properties: Iterable[Dict[str, Any]]) -> List[Tuple[str, str]]:
+    """Return ``(name, type)`` string pairs for each property mapping.
+
+    A missing ``name`` becomes ``""`` and a missing ``type`` becomes
+    ``"STRING"``.
+    """
+    pairs: List[Tuple[str, str]] = []
+    for entry in properties:
+        name = str(entry.get("name", ""))
+        type_name = str(entry.get("type", "STRING"))
+        pairs.append((name, type_name))
+    return pairs
diff --git a/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_ast_visitor.py b/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_ast_visitor.py
index 7924082..6a412a8 100644
--- a/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_ast_visitor.py
+++ b/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_ast_visitor.py
@@ -1,3 +1,4 @@
+import re
import traceback
from typing import List, Tuple
@@ -41,16 +42,16 @@ def visitOC_SinglePartQuery(self, ctx: LcypherParser.OC_SinglePartQueryContext):
def visitOC_MultiPartQuery(self, ctx: LcypherParser.OC_MultiPartQueryContext):
clause_list = []
- for ctx in ctx.getChildren():
+ for child_ctx in ctx.getChildren():
# add clause list from reading clause
- if isinstance(ctx, LcypherParser.OC_ReadingClauseContext):
- clause_list.append(self.visitOC_ReadingClause(ctx))
+ if isinstance(child_ctx, LcypherParser.OC_ReadingClauseContext):
+ clause_list += self.visitOC_ReadingClause(child_ctx)
# add with clause
- if isinstance(ctx, LcypherParser.OC_WithContext):
- clause_list.append(self.visitOC_With(ctx))
+ if isinstance(child_ctx, LcypherParser.OC_WithContext):
+ clause_list.append(self.visitOC_With(child_ctx))
# add clause list from single part query
- if isinstance(ctx, LcypherParser.OC_SinglePartQueryContext):
- clause_list += self.visitOC_SinglePartQuery(ctx)
+ if isinstance(child_ctx, LcypherParser.OC_SinglePartQueryContext):
+ clause_list += self.visitOC_SinglePartQuery(child_ctx)
# return clause list
return clause_list
@@ -62,7 +63,7 @@ def visitOC_Match(self, ctx: LcypherParser.OC_MatchContext):
# add match clause
path_pattern_list = self.visitOC_Pattern(ctx.oC_Pattern())
# only use the first path pattern
- match_clause = MatchClause(path_pattern_list)
+ match_clause = MatchClause(path_pattern_list, optional=ctx.OPTIONAL_() is not None)
# add match clause to clause list
clause_list.append(match_clause)
# add where clause to clause list
@@ -70,6 +71,25 @@ def visitOC_Match(self, ctx: LcypherParser.OC_MatchContext):
clause_list.append(self.visitOC_Where(ctx.oC_Where()))
return clause_list
+    def visitOC_Pattern(self, ctx: LcypherParser.OC_PatternContext):
+        """Collect the path patterns of every pattern part into one flat list."""
+        return [
+            path_pattern
+            for pattern_part in ctx.oC_PatternPart()
+            for path_pattern in self.visitOC_PatternPart(pattern_part)
+        ]
+
+    def visitOC_PatternPart(self, ctx: LcypherParser.OC_PatternPartContext):
+        """Visit one pattern part and tag its path patterns with the path variable.
+
+        When the part names a variable, that name is stored as
+        ``path_variable`` on every path pattern it produced.
+        """
+        patterns = self.visitOC_AnonymousPatternPart(ctx.oC_AnonymousPatternPart())
+        if not ctx.oC_Variable():
+            return patterns
+        name = self._symbolic_name(ctx.oC_Variable().oC_SymbolicName())
+        for pattern in patterns:
+            pattern.path_variable = name
+        return patterns
+
+    def visitOC_AnonymousPatternPart(
+        self, ctx: LcypherParser.OC_AnonymousPatternPartContext
+    ):
+        """Visit the pattern element wrapped by an anonymous pattern part."""
+        element_ctx = ctx.oC_PatternElement()
+        return self.visitOC_PatternElement(element_ctx)
+
def visitOC_PatternElement(self, ctx: LcypherParser.OC_PatternElementContext):
node_pattern_list = []
edge_pattern_list = []
@@ -86,11 +106,13 @@ def visitOC_NodePattern(self, ctx: LcypherParser.OC_NodePatternContext):
label = ""
property_maps = []
if ctx.oC_Variable():
- symbolic_name = ctx.oC_Variable().oC_SymbolicName().getText()
+ symbolic_name = self._symbolic_name(ctx.oC_Variable().oC_SymbolicName())
if ctx.oC_NodeLabels():
# only get the first node label for now
# TODO: support getting node label list
- label = ctx.oC_NodeLabels().oC_NodeLabel(0).oC_LabelName().getText()
+ label = self._symbolic_name(
+ ctx.oC_NodeLabels().oC_NodeLabel(0).oC_LabelName().oC_SchemaName().oC_SymbolicName()
+ )
if ctx.oC_Properties():
property_maps = self.visitOC_Properties(ctx.oC_Properties())
return NodePattern(symbolic_name, label, property_maps)
@@ -105,9 +127,14 @@ def visitOC_RelationshipPattern(self, ctx: LcypherParser.OC_RelationshipPatternC
if ctx.oC_RelationshipDetail():
rel_det_ctx = ctx.oC_RelationshipDetail()
if rel_det_ctx.oC_Variable():
- symbolic_name = rel_det_ctx.oC_Variable().oC_SymbolicName().getText()
+ symbolic_name = self._symbolic_name(rel_det_ctx.oC_Variable().oC_SymbolicName())
if rel_det_ctx.oC_RelationshipTypes():
- label = rel_det_ctx.oC_RelationshipTypes().oC_RelTypeName(0).getText()
+ label = "|".join(
+ self._symbolic_name(
+ rel_type.oC_SchemaName().oC_SymbolicName()
+ )
+ for rel_type in rel_det_ctx.oC_RelationshipTypes().oC_RelTypeName()
+ )
if rel_det_ctx.oC_RangeLiteral():
range_ctx = rel_det_ctx.oC_RangeLiteral()
if len(range_ctx.oC_IntegerLiteral()) == 0:
@@ -159,14 +186,23 @@ def visitOC_MapLiteral(self, ctx: LcypherParser.OC_MapLiteralContext):
property_maps = []
count = len(ctx.oC_PropertyKeyName())
for i in range(count):
- property_name = ctx.oC_PropertyKeyName(i).oC_SchemaName().oC_SymbolicName().getText()
- value = ctx.oC_Expression(i).getText()
+ property_name = self._symbolic_name(
+ ctx.oC_PropertyKeyName(i).oC_SchemaName().oC_SymbolicName()
+ )
+ value = self._expression_text(ctx.oC_Expression(i))
property_maps.append([property_name, value])
return property_maps
def visitOC_Where(self, ctx: LcypherParser.OC_WhereContext):
- [compare_expression] = self.visitOC_Expression(ctx.oC_Expression())
- return WhereClause(compare_expression)
+ return WhereClause(
+ CompareExpression(
+ symbolic_name="",
+ property="",
+ comparison_type="raw",
+ comparison_value="",
+ raw_expression=self._expression_text(ctx.oC_Expression()),
+ )
+ )
def visitOC_ComparisonExpression(self, ctx: LcypherParser.OC_ComparisonExpressionContext):
# print(self.visitOC_AddOrSubtractExpression(ctx.oC_AddOrSubtractExpression()))
@@ -206,7 +242,13 @@ def visitOC_With(self, ctx: LcypherParser.OC_WithContext):
return_body = self.visitOC_ReturnBody(ctx.oC_ReturnBody())
where_expression = None
if ctx.oC_Where():
- where_expression = self.visitOC_Expression(ctx.oC_Where().oC_Expression())
+ where_expression = CompareExpression(
+ symbolic_name="",
+ property="",
+ comparison_type="raw",
+ comparison_value="",
+ raw_expression=self._expression_text(ctx.oC_Where().oC_Expression()),
+ )
distinct = ctx.DISTINCT() is not None
return WithClause(return_body, where_expression, distinct)
@@ -243,24 +285,21 @@ def visitOC_ReturnItem(self, ctx: LcypherParser.OC_ReturnItemContext):
property = ""
alias = ""
if ctx.oC_Variable():
- alias = ctx.oC_Variable().oC_SymbolicName().getText()
- [symbolic_name, property, function_name] = self.visitOC_Expression(ctx.oC_Expression())
- return ReturnItem(symbolic_name, property, alias, function_name)
+ alias = self._symbolic_name(ctx.oC_Variable().oC_SymbolicName())
+ expression = self._expression_text(ctx.oC_Expression())
+ symbolic_name, property, function_name = self._parse_value_expression(expression)
+ return ReturnItem(symbolic_name, property, alias, function_name, expression)
def visitOC_PropertyOrLabelsExpression(
self, ctx: LcypherParser.OC_PropertyOrLabelsExpressionContext
):
if ctx.oC_Atom().oC_Variable():
# return symbolic name and property
- symbolic_name = ctx.oC_Atom().oC_Variable().oC_SymbolicName().getText()
+ symbolic_name = self._symbolic_name(ctx.oC_Atom().oC_Variable().oC_SymbolicName())
property = ""
if len(ctx.oC_PropertyLookup()) != 0:
- property = (
- ctx.oC_PropertyLookup(0)
- .oC_PropertyKeyName()
- .oC_SchemaName()
- .oC_SymbolicName()
- .getText()
+ property = self._symbolic_name(
+ ctx.oC_PropertyLookup(0).oC_PropertyKeyName().oC_SchemaName().oC_SymbolicName()
)
return [symbolic_name, property, ""]
if ctx.oC_Atom().oC_FunctionInvocation():
@@ -289,8 +328,88 @@ def visitOC_SortItem(self, ctx: LcypherParser.OC_SortItemContext):
count = ctx.getChildCount()
if count > 1:
order = ctx.getChild(count - 1).getText()
- [symbolic_name, property, function_name] = self.visitOC_Expression(ctx.oC_Expression())
- return SortItem(symbolic_name, property, order, function_name)
+ expression = self._expression_text(ctx.oC_Expression())
+ symbolic_name, property, function_name = self._parse_value_expression(expression)
+ return SortItem(symbolic_name, property, order, function_name, expression)
+
+    def _symbolic_name(self, ctx) -> str:
+        """Return the text of a symbolic-name context without backtick quoting.
+
+        A name wrapped in backticks loses the wrapper and has doubled
+        backticks reduced to one. A missing context yields ``""``.
+        """
+        if ctx is None:
+            return ""
+        text = ctx.getText()
+        quoted = len(text) >= 2 and text.startswith("`") and text.endswith("`")
+        if not quoted:
+            return text
+        return text[1:-1].replace("``", "`")
+
+    def _expression_text(self, ctx) -> str:
+        """Return the context's text with backticks and string literals normalized.
+
+        A missing context is treated as empty text.
+        """
+        raw = "" if ctx is None else ctx.getText()
+        without_backticks = self._normalize_backtick_identifiers(raw)
+        return self._normalize_string_literals(without_backticks)
+
+    def _normalize_backtick_identifiers(self, text: str) -> str:
+        """Remove backtick quoting from identifiers, leaving string literals alone.
+
+        Single- and double-quoted literals are matched first and copied
+        through unchanged, so a backtick inside a string value survives.
+        Inside a backtick-quoted identifier a doubled backtick stands for one
+        literal backtick, matching how ``_symbolic_name`` unquotes names.
+        """
+        token = re.compile(
+            r"'(?:\\.|[^'\\])*'"  # single-quoted literal, backslash escapes allowed
+            r'|"(?:\\.|[^"\\])*"'  # double-quoted literal
+            r"|`((?:[^`]|``)*)`",  # backtick-quoted identifier
+            flags=re.DOTALL,
+        )
+
+        def unquote(match: re.Match) -> str:
+            identifier = match.group(1)
+            if identifier is None:
+                # A string literal matched: keep it byte-for-byte.
+                return match.group(0)
+            return identifier.replace("``", "`")
+
+        return token.sub(unquote, text)
+
+ def _normalize_string_literals(self, text: str) -> str:
+ """Rewrite double-quoted string literals in ``text`` as single-quoted ones.
+
+ Single-quoted literals and all text outside literals are copied
+ through unchanged.
+ """
+ result = []
+ index = 0
+ while index < len(text):
+ char = text[index]
+ if char == "'":
+ # Single-quoted literal: copy verbatim up to its closing quote,
+ # stepping over any backslash-escaped character.
+ start = index
+ index += 1
+ while index < len(text):
+ if text[index] == "\\":
+ index += 2
+ continue
+ if text[index] == "'":
+ index += 1
+ break
+ index += 1
+ result.append(text[start:index])
+ continue
+ if char != '"':
+ result.append(char)
+ index += 1
+ continue
+
+ # Double-quoted literal: collect its content without the quotes.
+ index += 1
+ literal = []
+ while index < len(text):
+ # A backslash escape is reduced to the character that follows it.
+ if text[index] == "\\" and index + 1 < len(text):
+ literal.append(text[index + 1])
+ index += 2
+ continue
+ if text[index] == '"':
+ index += 1
+ break
+ literal.append(text[index])
+ index += 1
+ # Emit single-quoted, doubling any embedded single quote.
+ result.append("'" + "".join(literal).replace("'", "''") + "'")
+ return "".join(result)
+
+    def _parse_value_expression(self, expression: str) -> tuple[str, str, str]:
+        """Split an expression into ``(symbolic_name, property, function_name)``.
+
+        Recognised shapes:
+
+        * ``func(body)``: the body is parsed recursively for the name and
+          property, and ``func`` is returned as the function name;
+        * ``*``: returned as the symbolic name;
+        * ``variable`` or ``variable.property``.
+
+        Any other text is returned whole as the symbolic name with an empty
+        property and function name.
+        """
+        expression = expression.strip()
+        # The group names must match the .group() lookups below.
+        function_match = re.fullmatch(
+            r"(?P<function>[A-Za-z_][A-Za-z0-9_]*)\((?P<body>.*)\)",
+            expression,
+            flags=re.DOTALL,
+        )
+        if function_match:
+            body = function_match.group("body").strip()
+            if body.upper().startswith("DISTINCT "):
+                # Normalise the keyword's case and the spacing after it.
+                body = "DISTINCT " + body[9:].strip()
+            symbolic_name, property_name, _ = self._parse_value_expression(body)
+            return symbolic_name, property_name, function_match.group("function")
+        if expression == "*":
+            return "*", "", ""
+        property_match = re.fullmatch(
+            r"(?P<symbolic>[A-Za-z_][A-Za-z0-9_]*)"
+            r"(?:\.(?P<property>[A-Za-z_][A-Za-z0-9_$#-]*))?",
+            expression,
+        )
+        if property_match:
+            return (
+                property_match.group("symbolic"),
+                property_match.group("property") or "",
+                "",
+            )
+        return expression, "", ""
def aggregateResult(self, aggregate, nextResult):
result = []
diff --git a/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_query_visitor.py b/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_query_visitor.py
index af0c032..0f01be7 100644
--- a/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_query_visitor.py
+++ b/app/impl/tugraph_cypher/ast_visitor/tugraph_cypher_query_visitor.py
@@ -57,7 +57,7 @@ def visitOC_Match(self, ctx: LcypherParser.OC_MatchContext):
# add match clause
path_pattern_list = self.visitOC_Pattern(ctx.oC_Pattern())
# only use the first path pattern
- match_clause = MatchClause(path_pattern_list[0])
+ match_clause = MatchClause(path_pattern_list[0], optional=ctx.OPTIONAL_() is not None)
# add match clause to clause list
clause_list.append(match_clause)
# add where clause to clause list
diff --git a/dataset_prep/README.md b/dataset_prep/README.md
new file mode 100644
index 0000000..54e01be
--- /dev/null
+++ b/dataset_prep/README.md
@@ -0,0 +1,309 @@
+# Dataset Prep
+
+Utilities in this directory prepare the benchmark dataset for Oracle SQL/PGQ work:
+
+- discover dataset query files and graph import configs
+- translate source Cypher/GQL-like records to Oracle SQL/PGQ
+- optionally validate translated SQL/PGQ against a live Oracle Database
+- summarize translation/runtime failures
+- compare Oracle SQL/PGQ results with Neo4j Cypher results on the same dataset
+
+Run commands from the repository root.
+
+## Input Dataset
+
+These utilities are designed for the [Text2GQL-Bench dataset](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/Text2GraphQueryBenchmark/Text2GQL-Bench_dataset.zip), which is also referenced in the main repository README.
+
+By default, the scripts expect the downloaded dataset to be extracted under `dataset/` at the repository root. Use `--dataset-root <path>` if the dataset is stored somewhere else.
+
+The discovery logic expects the benchmark split layout described below, including `train`, `dev`, and `test` directories with query JSON files and graph import configs.
+
+## Dataset Output
+
+The dataset including the Oracle SQL/PGQ translated queries is available at [Dataset-with-SQL/PGQ](https://objectstorage.us-ashburn-1.oraclecloud.com/p/8dIkuVGsfnRQlP3ifxVDjQP0pmidpadEY18ltEbkPC4PrZyLTxjJdqDjbtWIEYUW/n/ogcs/b/Text2GQL-Bench_dataset/o/Text2GQL-Bench_dataset.zip). It includes 19,633 of the 22,407 existing queries.
+
+## Environment
+
+Use the Poetry environment:
+
+```bash
+poetry env use python3.10
+poetry install
+```
+
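+Oracle live validation and Oracle-vs-Neo4j comparison require the following variables (the values shown here and below are examples; substitute your own connection details and credentials):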
+
+```bash
+export ORACLE_DSN="localhost:1521/FREEPDB1"
+export ORACLE_USER="SYSTEM"
+export ORACLE_PASSWORD="tiger"
+```
+
+Neo4j comparison additionally requires:
+
+```bash
+export NEO4J_URI='bolt://localhost:7687'
+export NEO4J_USER='neo4j'
+export NEO4J_PASSWORD='ValidationPass123'
+export NEO4J_DATABASE='neo4j'
+```
+
+On macOS, if Python hits a local `pyexpat` or `libexpat` issue:
+
+```bash
+export DYLD_LIBRARY_PATH="/opt/homebrew/opt/expat/lib:${DYLD_LIBRARY_PATH:-}"
+```
+
+## Dataset Discovery
+
+Discovery is implemented in `dataset_prep/discover.py`.
+
+For `train`, a database unit is counted when both files exist:
+
+```text
+dataset/train/<database>/4_level_results_ek_results.json
+dataset/train/<database>/cypher/*/import_config.json
+```
+
+For `dev` and `test`, a database unit is counted when both files exist:
+
+```text
+dataset/<split>/<database>/Cypher/*_cypher.json
+dataset/<split>/<database>/Cypher/**/import_config.json
+```
+
+The scripts count one top-level JSON record as one query example. They do not count every language field separately.
+
+## Preflight
+
+The translation script runs preflight automatically. It checks imports for `antlr4`, `oracledb`, the Cypher-to-Oracle translator, and Oracle environment variables when live validation is enabled.
+
+To exercise the same checks indirectly:
+
+```bash
+poetry run python dataset_prep/translate_validate.py \
+ --splits train \
+ --databases bluesky \
+ --limit-queries 1 \
+ --skip-live-validation
+```
+
+## Translate To Oracle SQL/PGQ
+
+Translate all discovered dataset records and validate against Oracle:
+
+```bash
+poetry run python dataset_prep/translate_validate.py \
+ --dataset-root dataset \
+ --output-root output/dataset_prep \
+ --splits train dev test
+```
+
+Useful development run:
+
+```bash
+poetry run python dataset_prep/translate_validate.py \
+ --dataset-root dataset \
+ --output-root output/dataset_prep \
+ --splits train \
+ --databases bluesky \
+ --limit-queries 20
+```
+
+Translate without connecting to Oracle:
+
+```bash
+poetry run python dataset_prep/translate_validate.py \
+ --dataset-root dataset \
+ --output-root output/dataset_prep \
+ --splits train dev test \
+ --skip-live-validation
+```
+
+Resume completed database units:
+
+```bash
+poetry run python dataset_prep/translate_validate.py \
+ --dataset-root dataset \
+ --output-root output/dataset_prep \
+ --splits train dev test \
+ --resume
+```
+
+Main options:
+
+- `--splits train dev test`: splits to process.
+- `--databases <name> [<name> ...]`: only process selected database names.
+- `--limit-databases N`: process only the first N discovered database units.
+- `--limit-queries N`: process only the first N records per database.
+- `--skip-live-validation`: translate only, no Oracle load or query execution.
+- `--resume`: reuse completed per-database summaries.
+- `--fail-fast`: stop on the first database-level exception.
+- `--oracle-validation-timeout-ms`: per-query Oracle timeout.
+- `--oracle-validation-fetch-limit`: number of rows fetched for validation.
+
+Outputs:
+
+```text
+output/dataset_prep/global_summary.json
+output/dataset_prep/unsupported_samples.jsonl
+output/dataset_prep/<split>/<database>/summary.json
+output/dataset_prep/<split>/<database>/oracle_sqlpgq_enriched.jsonl
+```
+
+`oracle_sqlpgq_enriched.jsonl` contains the original record plus fields such as:
+
+- `oracle_source_query`
+- `oracle_sqlpgq`
+- `oracle_translation_category`
+- `oracle_validation_status`
+- `oracle_validation_error`
+- `oracle_dataset_meta`
+
+## Failure Analysis
+
+Group failed translation/validation records by likely root cause:
+
+```bash
+poetry run python dataset_prep/analyze_failures.py \
+ --output-root output/dataset_prep \
+ --dataset-root dataset \
+ --splits train dev test \
+ --sample-limit 5
+```
+
+Outputs:
+
+```text
+output/dataset_prep/failure_analysis.json
+output/dataset_prep/failure_analysis.md
+```
+
+The report groups records whose `oracle_validation_status` is one of:
+
+```text
+syntax_error
+runtime_error
+unsupported
+load_error
+```
+
+It also lists databases that were expected from discovery but do not have completed output.
+
+## Oracle SQL/PGQ Vs Neo4j Result Validation
+
+Use this after `translate_validate.py` has produced `oracle_sqlpgq_enriched.jsonl` files.
+
+Run all splits:
+
+```bash
+poetry run python dataset_prep/compare_oracle_neo4j_results.py \
+ --neo4j-uri bolt://localhost:7687 \
+ --neo4j-user neo4j \
+ --neo4j-password 'ValidationPass123' \
+ --splits train dev test
+```
+
+Small smoke test:
+
+```bash
+poetry run python dataset_prep/compare_oracle_neo4j_results.py \
+ --neo4j-uri bolt://localhost:7687 \
+ --neo4j-user neo4j \
+ --neo4j-password 'ValidationPass123' \
+ --splits train \
+ --databases bluesky \
+ --limit-queries 20
+```
+
+Main options:
+
+- `--dataset-output-root`: where translated Oracle JSONL files are read from.
+- `--output-root`: where comparison reports are written.
+- `--oracle-statuses success no_record`: which prior Oracle validation statuses are eligible.
+- `--include-all-translatable`: compare every translatable record even if prior Oracle validation did not succeed.
+- `--neo4j-batch-size`: CSV import and batched clear size.
+- `--keep-loaded`: leave the last loaded Oracle and Neo4j graph in place for debugging.
+
+Outputs:
+
+```text
+output/oracle_neo4j_compare/summary.json
+output/oracle_neo4j_compare/mismatched_or_failed_queries.jsonl
+output/oracle_neo4j_compare/<split>/<database>/summary.json
+```
+
+The comparison script:
+
+- loads the same CSV graph into Oracle and Neo4j
+- runs `oracle_sqlpgq` against Oracle
+- runs `oracle_source_query` against Neo4j
+- normalizes values before comparison, including dates, datetimes, booleans, decimals, floats, lists, maps, nodes, and relationships
+- records failed executions and result mismatches in JSONL
+
+Neo4j is cleared between databases in batches to avoid transaction memory errors from a single large `DETACH DELETE`.
+
+## Export Validated Oracle SQL/PGQ Dataset
+
+Use this after `translate_validate.py` has produced enriched records. The exporter loads each graph into Oracle and Neo4j, reruns the Oracle-vs-Neo4j comparison, and writes only records whose source Cypher and translated SQL/PGQ results match.
+
+```bash
+poetry run python dataset_prep/export_validated_dataset.py \
+ --dataset-root dataset \
+ --dataset-output-root output/dataset_prep \
+ --output-root output/oracle_sqlpgq_dataset \
+ --splits train dev test
+```
+
+The output mirrors the source `dataset/` layout. Query JSON files are filtered to matched records and each exported record gets:
+
+```text
+initial_sql_pgq
+```
+
+By default, the exporter strips internal `oracle_*` enrichment fields and copies the dataset assets beside the filtered query files. Useful options:
+
+- `--sql-pgq-field <name>`: use a different field name for the SQL/PGQ query.
+- `--include-oracle-metadata`: keep the internal `oracle_*` enrichment fields.
+- `--no-copy-assets`: write only filtered query JSON files.
+- `--overwrite`: replace an existing non-empty output directory.
+- `--reuse-loaded`: validate against already-loaded Oracle and Neo4j graphs for a focused run.
+
+## Common Workflows
+
+Full Oracle translation and validation:
+
+```bash
+poetry run python dataset_prep/translate_validate.py --splits train dev test
+poetry run python dataset_prep/analyze_failures.py
+```
+
+Full Oracle-vs-Neo4j validation:
+
+```bash
+poetry run python dataset_prep/translate_validate.py --splits train dev test --resume
+poetry run python dataset_prep/compare_oracle_neo4j_results.py \
+ --neo4j-uri bolt://localhost:7687 \
+ --neo4j-user neo4j \
+ --neo4j-password 'ValidationPass123' \
+ --splits train dev test
+```
+
+Focused database debugging:
+
+```bash
+poetry run python dataset_prep/translate_validate.py \
+ --splits dev \
+ --databases Address \
+ --limit-queries 50 \
+ --fail-fast
+
+poetry run python dataset_prep/compare_oracle_neo4j_results.py \
+ --splits dev \
+ --databases Address \
+ --limit-queries 50 \
+ --neo4j-uri bolt://localhost:7687 \
+ --neo4j-user neo4j \
+ --neo4j-password 'ValidationPass123' \
+ --keep-loaded
+```
diff --git a/dataset_prep/__init__.py b/dataset_prep/__init__.py
new file mode 100644
index 0000000..bd6b813
--- /dev/null
+++ b/dataset_prep/__init__.py
@@ -0,0 +1,2 @@
+"""Dataset preparation tools for Oracle SQL/PGQ translation."""
+
diff --git a/dataset_prep/analyze_failures.py b/dataset_prep/analyze_failures.py
new file mode 100644
index 0000000..dc9ec01
--- /dev/null
+++ b/dataset_prep/analyze_failures.py
@@ -0,0 +1,473 @@
+from __future__ import annotations
+
+import argparse
+from collections import Counter, defaultdict
+import json
+from pathlib import Path
+import re
+from typing import Any, Dict, Iterable, List
+
+from dataset_prep.cypher_schema import CypherSchema
+from dataset_prep.discover import discover_database_units
+from dataset_prep.translate_validate import (
+ has_quantified_relationship_property_map,
+)
+
+FAILURE_STATUSES = {"syntax_error", "runtime_error", "unsupported", "load_error"}
+
+
+def main() -> None:
+    """Write JSON and Markdown failure reports for the configured dataset_prep output root."""
+    options = parse_args()
+    root = Path(options.output_root)
+    dataset_root = Path(options.dataset_root) if options.dataset_root else None
+    failures = [
+        record
+        for record in iter_enriched_records(root)
+        if record.get("oracle_validation_status") in FAILURE_STATUSES
+    ]
+    report = build_report(
+        failures, output_root=root, dataset_root=dataset_root,
+        splits=options.splits, sample_limit=options.sample_limit,
+    )
+    json_path, markdown_path = root / "failure_analysis.json", root / "failure_analysis.md"
+    write_json(json_path, report)
+    write_markdown(markdown_path, report)
+    print(f"Wrote {json_path}")
+    print(f"Wrote {markdown_path}")
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse the failure-analysis command-line options."""
+    summary = "Group dataset_prep translation failures by root-cause signature."
+    cli = argparse.ArgumentParser(description=summary)
+    cli.add_argument("--output-root", default="output/dataset_prep")
+    cli.add_argument("--dataset-root", default="dataset")
+    cli.add_argument("--splits", nargs="+", default=["train", "dev", "test"])
+    cli.add_argument("--sample-limit", type=int, default=5)
+    return cli.parse_args()
+
+
+def iter_enriched_records(output_root: Path) -> Iterable[Dict[str, Any]]:
+    """Yield each record of ``*/*/oracle_sqlpgq_enriched.jsonl``, tagged with its file and line."""
+    for jsonl_path in sorted(output_root.glob("*/*/oracle_sqlpgq_enriched.jsonl")):
+        with jsonl_path.open(encoding="utf-8") as handle:
+            for number, raw in enumerate(handle, start=1):
+                if raw.strip():
+                    entry = json.loads(raw)
+                    origin = entry.setdefault("oracle_dataset_meta", {})
+                    origin.setdefault("output_file", str(jsonl_path))
+                    origin.setdefault("output_line", number)
+                    yield entry
+
+
+def build_report(
+    records: List[Dict[str, Any]],
+    output_root: Path,
+    dataset_root: Path | None,
+    splits: List[str],
+    sample_limit: int,
+) -> Dict[str, Any]:
+    """Aggregate failure records into a JSON-serialisable report.
+
+    Records are grouped by ``(status, signature)`` with up to ``sample_limit`` samples each.
+    When ``dataset_root`` exists, databases without a ``summary.json`` are listed as well.
+    """
+    status_counts = Counter(record.get("oracle_validation_status") for record in records)
+    signature_counts: Counter[tuple[str, str]] = Counter()
+    by_signature: dict[tuple[str, str], list[Dict[str, Any]]] = defaultdict(list)
+    by_database: Counter[str] = Counter()
+    for record in records:
+        key = (record.get("oracle_validation_status", "unknown"), failure_signature(record))
+        signature_counts[key] += 1
+        by_signature[key].append(record)
+        meta = record.get("oracle_dataset_meta") or {}
+        by_database[f"{meta.get('split')}/{meta.get('database')}"] += 1
+
+    missing_outputs = []
+    if dataset_root is not None and dataset_root.exists():
+        # as_posix() keeps "split/database" keys OS-independent (str(path) uses "\\" on Windows).
+        completed = {
+            path.parent.relative_to(output_root).as_posix()
+            for path in output_root.glob("*/*/summary.json")
+        }
+        expected = {
+            f"{unit.split}/{unit.database}"
+            for unit in discover_database_units(dataset_root, splits)
+        }
+        missing_outputs = sorted(expected - completed)
+    top_signatures = [
+        {
+            "status": status,
+            "signature": signature,
+            "count": count,
+            "likely_next_action": likely_next_action(status, signature),
+            "samples": [
+                sample_record(item) for item in by_signature[(status, signature)][:sample_limit]
+            ],
+        }
+        for (status, signature), count in signature_counts.most_common()
+    ]
+    return {
+        "output_root": str(output_root),
+        "total_failures": len(records),
+        "status_counts": dict(status_counts),
+        "top_databases_by_failures": dict(by_database.most_common(30)),
+        "databases_without_completed_output": missing_outputs,
+        "top_signatures": top_signatures,
+    }
+
+
+def failure_signature(record: Dict[str, Any]) -> str:
+    """Return the grouping key for a failed record: a query/schema shape or an error headline."""
+    if record.get("oracle_validation_status", "") == "unsupported":
+        by_query = unsupported_query_signature(record.get("oracle_source_query") or "")
+        takes_priority = {
+            "multiple_with_skipped",
+            "standalone_optional_match",
+            "optional_match_left_join_required",
+        }
+        if by_query in takes_priority:
+            return by_query
+        by_schema = schema_mismatch_signature(record)
+        if by_schema:
+            return by_schema
+        feature_names = record.get("oracle_unsupported_features") or []
+        if feature_names:
+            return ",".join(feature_names)
+        return by_query or record.get("oracle_translation_category") or "unsupported"
+    message = str(record.get("oracle_validation_error") or "").strip()
+    if not message:
+        return "empty_error"
+    headline = message.splitlines()[0]
+    ora_code = re.search(r"\bORA-\d{5}\b", message)
+    if ora_code is None:
+        return headline[:240]
+    return f"{ora_code.group(0)} {headline.split(':', 1)[-1].strip()}"
+
+
+def unsupported_query_signature(query: str) -> str:
+    """Classify an unsupported Cypher query by the first matching shape rule.
+
+    Rules run in priority order on the whitespace-normalised text; "" means no rule matched.
+    """
+    normalized = " ".join(str(query or "").split())
+    if not normalized:
+        return ""
+    if len(re.findall(r"\bWITH\b", normalized, flags=re.IGNORECASE)) > 1:
+        return "multiple_with_skipped"
+    if re.search(
+        r"\.(?:bad_alias|role_type|contradiction_severity|support_strength)\b",
+        normalized,
+        flags=re.IGNORECASE,
+    ) or re.search(
+        r"\bc\.validity_start\b|\bs\.name\b|\bn2\.sensitivity_level\b|\bg\.created_date\b",
+        normalized,
+        flags=re.IGNORECASE,
+    ):
+        return "invalid_schema_property"
+    if re.search(r"\bOPTIONAL\s+MATCH\b", normalized, flags=re.IGNORECASE):
+        if re.match(r"^OPTIONAL\s+MATCH\b", normalized, flags=re.IGNORECASE):
+            return "standalone_optional_match"
+        return "optional_match_left_join_required"
+    temporal_aggregate = r"\b(?:AVG|SUM)\s*\([^)]*(?:date|timestamp)[^)]*\)"
+    if re.search(temporal_aggregate, normalized, flags=re.IGNORECASE):
+        return "temporal_numeric_aggregate"
+    if re.search(r"\bMATCH\s+`?[A-Za-z_][A-Za-z0-9_]*`?\s*=", normalized, flags=re.IGNORECASE):
+        return "path_variable_return"
+    if re.search(
+        r"-\s*\[[^\]]*\*\s*(?:\d+\s*)?\.\.\s*\]\s*(?:->|-)|"
+        r"(?:<-|-)\s*\[[^\]]*\*\s*(?:\d+\s*)?\.\.\s*\]\s*-",
+        normalized,
+        flags=re.IGNORECASE,
+    ):
+        return "open_ended_variable_length_path"
+    if has_quantified_relationship_property_map(normalized):
+        return "quantified_relationship_property_map"
+    if re.search(r"\bWITH\b.+\bMATCH\b", normalized, flags=re.IGNORECASE):
+        return "with_match_pipeline"
+    # Case-insensitive like the rules above, so lower-case "where"/"with"/"return" end the prefix.
+    match_prefix = re.split(
+        r" (?:WHERE|WITH|RETURN) ", normalized, maxsplit=1, flags=re.IGNORECASE
+    )[0]
+    if "," in match_prefix:
+        return "multi_pattern_match"
+    return ""
+
+
+def schema_mismatch_signature(record: Dict[str, Any]) -> str:
+    """Return the first reportable schema issue signature for the record's source query, or ""."""
+    source_query = str(record.get("oracle_source_query") or "")
+    config_location = (record.get("oracle_dataset_meta") or {}).get("import_config")
+    if not (source_query and config_location):
+        return ""
+    config_file = Path(config_location)
+    if not config_file.exists():
+        return ""
+    # A config that cannot be read or parsed is treated as "no schema information".
+    try:
+        schema_config = json.loads(config_file.read_text(encoding="utf-8"))
+    except Exception:
+        return ""
+    reportable = {
+        "invalid_schema_label",
+        "invalid_schema_property",
+        "invalid_schema_direction",
+        "unsafe_numeric_conversion",
+        "unsafe_temporal_numeric_comparison",
+        "unsafe_temporal_arithmetic",
+    }
+    issues = CypherSchema(schema_config).validation_issues(source_query)
+    return next((issue.signature for issue in issues if issue.signature in reportable), "")
+
+
+def _cypher_variable_labels(query: str) -> tuple[dict[str, str], dict[str, str]]:
+    """Return ``(node_labels, edge_labels)`` mapping each labelled variable to its label."""
+    node_labels: dict[str, str] = {}
+    edge_labels: dict[str, str] = {}
+    # The ``var``/``label`` group names are required by the ``match.group`` calls below.
+    patterns = (
+        (r"\(\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*`?(?P<label>[^`(){},\s]+)`?", node_labels),
+        (r"\[\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*`?(?P<label>[^`\]{}\s]+)`?", edge_labels),
+    )
+
+    for pattern, labels in patterns:
+        for match in re.finditer(pattern, query):
+            variable = match.group("var")
+            if variable:
+                labels[variable] = _clean_schema_name(match.group("label"))
+    return node_labels, edge_labels
+
+
+def _cypher_property_references(query: str) -> list[tuple[str, str]]:
+    """Return ``(variable, property)`` for each ``var.prop`` access outside quoted literals."""
+    # Quoted literals are blanked first so dotted text inside strings is not reported.
+    protected = re.sub(r"'(?:\\'|[^'])*'|\"(?:\\\"|[^\"])*\"", "''", query)
+    access = re.compile(
+        r"\b(?P<var>[A-Za-z_][A-Za-z0-9_]*)\."
+        r"(?:`(?P<quoted>[^`]+)`|(?P<bare>[A-Za-z_][A-Za-z0-9_$#-]*))"
+    )
+    return [
+        (match.group("var"), _clean_schema_name(match.group("quoted") or match.group("bare")))
+        for match in access.finditer(protected)
+    ]
+
+
+def _cypher_property_maps(query: str, node: bool) -> list[tuple[str, list[str]]]:
+    """Return ``(label, [property keys])`` for each labelled pattern with an inline ``{...}`` map."""
+    # ``node=True`` scans ``(...)`` patterns, ``node=False`` scans ``[...]`` patterns.
+    open_char, close_char = ("(", ")") if node else ("[", "]")
+    pattern = re.compile(
+        re.escape(open_char)
+        + r"\s*(?:[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*`?(?P<label>[^`(){}\]\[,\s]+)`?"
+        + r"\s*\{(?P<body>[^}]*)\}"
+        + re.escape(close_char)
+    )
+    key_pattern = re.compile(r"`?(?P<prop>[A-Za-z_][A-Za-z0-9_$#-]*)`?\s*:")
+
+    maps = []
+    for match in pattern.finditer(query):
+        properties = [
+            _clean_schema_name(prop.group("prop"))
+            for prop in key_pattern.finditer(match.group("body"))
+        ]
+        maps.append((_clean_schema_name(match.group("label")), properties))
+    return maps
+
+
+def _cypher_edge_triples(query: str) -> list[tuple[str, str, str, str]]:
+    """Return ``(left_label, direction, edge_label, right_label)`` for each directed hop.
+
+    ``direction`` is ``"right"`` for ``(:A)-[:E]->(:B)`` and ``"left"`` for ``(:A)<-[:E]-(:B)``.
+    Every offset is tried, so hops sharing a node in a longer chain are all reported.
+    """
+    # NAME is a placeholder swapped for "left"/"right" so both node groups can coexist.
+    node = (
+        r"\(\s*(?:[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*`?"
+        r"(?P<NAME>[^`(){}\]\[,\s]+)`?(?:\s*\{[^}]*\})?\s*\)"
+    )
+    edge = r"\[\s*(?:[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*`?(?P<edge>[^`\]{}\s]+)`?(?:\s*\{[^}]*\})?\s*\]"
+    right_pattern = re.compile(
+        node.replace("NAME", "left")
+        + r"\s*-\s*"
+        + edge
+        + r"\s*->\s*"
+        + node.replace("NAME", "right")
+    )
+    left_pattern = re.compile(
+        node.replace("NAME", "left")
+        + r"\s*<-\s*"
+        + edge
+        + r"\s*-\s*"
+        + node.replace("NAME", "right")
+    )
+    directed = (("right", right_pattern), ("left", left_pattern))
+
+    triples = []
+    for start in range(len(query)):
+        for direction, pattern in directed:
+            # pattern.match(query, start) anchors at ``start`` without copying a suffix.
+            match = pattern.match(query, start)
+            if match:
+                triples.append(
+                    (
+                        _clean_schema_name(match.group("left")),
+                        direction,
+                        _clean_schema_name(match.group("edge")),
+                        _clean_schema_name(match.group("right")),
+                    )
+                )
+    return triples
+
+
+def _property_exists(property_name: str, properties: set[str]) -> bool:
+ clean = _clean_schema_name(property_name)
+ snake = re.sub(r"(? str:
+ return str(value or "").strip().strip("`").strip('"')
+
+
+def likely_next_action(status: str, signature: str) -> str:
+    """Suggest a follow-up for a failure group; unknown groups get a generic recommendation."""
+    if status == "unsupported":
+        unsupported_actions = {
+            "multiple_with_skipped": (
+                "Skipped by policy: do not spend translator work on queries with more "
+                "than one WITH."
+            ),
+            "path_variable_return": (
+                "Keep unsupported unless grouped path projection semantics are added "
+                "for Oracle SQL/PGQ."
+            ),
+            "temporal_numeric_aggregate": (
+                "Keep unsupported unless the source query explicitly converts temporal "
+                "values to numeric durations."
+            ),
+            "open_ended_variable_length_path": (
+                "Keep unsupported or add an explicit max bound; Oracle SQL/PGQ rejects "
+                "open-ended bounded quantifiers."
+            ),
+            "invalid_schema_property": (
+                "Treat as invalid source/schema mismatch; do not emit SQL for absent properties."
+            ),
+            "optional_match": (
+                "Correlated OPTIONAL MATCH should translate through the Graph IR LEFT "
+                "JOIN path; inspect unsupported samples for unsupported optional shapes."
+            ),
+            "standalone_optional_match": (
+                "Standalone OPTIONAL MATCH differs from MATCH only for empty-match "
+                "null-row semantics; keep unsupported unless that behavior is modeled."
+            ),
+            "optional_match_left_join_required": (
+                "Retry with Graph IR OPTIONAL MATCH support; remaining cases likely "
+                "lack correlation or exceed the v1 optional scope."
+            ),
+            "with_match_pipeline": (
+                "Inspect whether this single-WITH pipeline is covered by staged SQL "
+                "support; add a focused test if fixable."
+            ),
+            "multi_pattern_match": (
+                "Inspect comma-separated path patterns for schema validity and "
+                "shared-variable support."
+            ),
+        }
+        unsupported_default = (
+            "Decide whether this Cypher feature maps to documented Oracle SQL/PGQ; "
+            "otherwise keep classified as unsupported."
+        )
+        return unsupported_actions.get(signature, unsupported_default)
+    alias_hint = (
+        "Inspect generated SELECT/ORDER BY aliases; project any outer SQL "
+        "references through GRAPH_TABLE COLUMNS."
+    )
+    ora_actions = {
+        "ORA-40996": (
+            "Check whether a string literal was emitted as a double-quoted identifier "
+            "or a variable was used without a MATCH declaration."
+        ),
+        "ORA-00936": alias_hint,
+        "ORA-00904": alias_hint,
+        "ORA-42414": (
+            "Fix graph DDL property typing or omit mixed-type properties from colliding labels."
+        ),
+        "ORA-02291": (
+            "Investigate CSV referential integrity, or add an unenforced import mode "
+            "for benchmark data."
+        ),
+        "ORA-12899": (
+            "Use CLOB or wider text columns for long STRING/TEXT properties during dataset loading."
+        ),
+    }
+    for code, action in ora_actions.items():
+        if code in signature:
+            return action
+    return (
+        "Inspect sample source query and generated SQL, then add a focused translator "
+        "or loader test."
+    )
+
+
+def sample_record(record: Dict[str, Any]) -> Dict[str, Any]:
+    """Flatten a failure record into the fields shown for each report sample."""
+    origin = record.get("oracle_dataset_meta", {})
+    # Location fields come from the dataset metadata, query fields from the record itself.
+    sample = {key: origin.get(key) for key in ("split", "database", "record_index")}
+    sample["id"] = record.get("id")
+    sample["source_query"] = record.get("oracle_source_query")
+    sample["oracle_sqlpgq"] = record.get("oracle_sqlpgq")
+    sample["error"] = record.get("oracle_validation_error")
+    for key in ("output_file", "output_line"):
+        sample[key] = origin.get(key)
+
+    return sample
+
+
+def write_json(path: Path, data: Dict[str, Any]) -> None:  # write ``data`` to ``path`` as JSON
+ path.parent.mkdir(parents=True, exist_ok=True)  # create missing parent directories first
+ path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")  # 2-space indent, non-ASCII kept verbatim
+
+
+def write_markdown(path: Path, report: Dict[str, Any]) -> None:  # render ``report`` as Markdown at ``path``
+ lines = [
+ "# Dataset Prep Failure Analysis",
+ "",
+ f"- Output root: `{report['output_root']}`",
+ f"- Total query failures: `{report['total_failures']}`",
+ f"- Status counts: `{json.dumps(report['status_counts'], sort_keys=True)}`",
+ "",
+ "## Top Databases",
+ "",
+ ]
+ for database, count in report["top_databases_by_failures"].items():  # one bullet per database
+ lines.append(f"- `{database}`: `{count}`")
+
+ if report["databases_without_completed_output"]:  # section is omitted when the list is empty
+ lines += ["", "## Databases Without Completed Output", ""]
+ for database in report["databases_without_completed_output"]:
+ lines.append(f"- `{database}`")
+
+ lines += ["", "## Top Failure Signatures", ""]
+ for item in report["top_signatures"]:  # one "###" section per (status, signature) group
+ lines.append(f"### {item['status']} / {item['signature']} / {item['count']}")
+ lines.append("")
+ lines.append(item["likely_next_action"])
+ lines.append("")
+ for sample in item["samples"]:
+ lines.append(
+ f"- `{sample['split']}/{sample['database']}` "
+ f"record `{sample['record_index']}` id `{sample['id']}`"
+ )
+ lines.append(f" - Source: `{sample['source_query']}`")  # NOTE(review): backticks inside a query end this code span early
+ lines.append(f" - SQL: `{sample['oracle_sqlpgq']}`")
+ if sample.get("error"):
+ lines.append(f" - Error: `{sample['error'].splitlines()[0]}`")  # first line of the error only
+ lines.append("")
+ path.write_text("\n".join(lines), encoding="utf-8")  # unlike write_json, the parent directory is not created here
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dataset_prep/compare_oracle_neo4j_results.py b/dataset_prep/compare_oracle_neo4j_results.py
new file mode 100644
index 0000000..fba1821
--- /dev/null
+++ b/dataset_prep/compare_oracle_neo4j_results.py
@@ -0,0 +1,2477 @@
+from __future__ import annotations
+
+# ruff: noqa: E402,I001
+
+import argparse
+import csv
+import json
+import logging
+import os
+from pathlib import Path
+import re
+import sys
+from collections import Counter
+from dataclasses import dataclass
+from datetime import date, datetime, timedelta
+from decimal import Decimal
+from typing import Any, Dict, Iterable, List, Sequence
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+try:
+ from neo4j import GraphDatabase, Query
+except ImportError: # pragma: no cover - depends on local environment.
+ GraphDatabase = None
+ Query = None
+
+from app.core.validator.db_client import QueryStatus
+from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient
+from app.impl.oracle_sqlpgq.utils.sqlpgq import OracleNameSanitizer
+from dataset_prep.cypher_schema import CypherSchema, CypherSchemaIssue
+from dataset_prep.discover import DatabaseUnit, discover_database_units
+from dataset_prep.oracle_loader import DatasetOracleLoader
+from dataset_prep.reporting import append_jsonl, write_json
+
+
+DEFAULT_VALID_ORACLE_STATUSES = {"success", "no_record"}
+logging.getLogger("neo4j.notifications").setLevel(logging.ERROR)
+
+
+@dataclass(frozen=True)
+class StableExecutionQueries:  # immutable pair of query texts, one per engine
+ oracle_sqlpgq: str  # Oracle SQL/PGQ query text
+ cypher: str  # Cypher query text
+ applied: bool = False  # presumably whether a rewrite was applied to the pair — TODO confirm against the producer
+ reason: str = ""  # presumably a free-text explanation for ``applied`` — TODO confirm
+
+
+@dataclass(frozen=True)
+class CypherReturnItem:  # immutable description of one returned item
+ expression: str  # presumably the expression text as written in RETURN — TODO confirm
+ alias: str  # presumably the output column name — TODO confirm
+ order_term: str  # presumably the term to use in ORDER BY for this item — TODO confirm
+
+
+class DatasetNeo4jLoader:
+ def __init__(
+     self,
+     uri: str,
+     user: str,
+     password: str,
+     database: str,
+     import_config_path: Path,
+     csv_root: Path,
+     batch_size: int = 1000,
+ ):
+     """Load the import config, derive schema lookup tables, then open the Neo4j driver."""
+     if GraphDatabase is None:
+         raise RuntimeError("Install the 'neo4j' package before running this script.")
+     self.database = database
+     self.import_config_path = import_config_path
+     self.csv_root = csv_root
+     self.batch_size = batch_size
+     self.clear_batch_size = max(batch_size, 1000)
+     self.config = json.loads(import_config_path.read_text(encoding="utf-8"))
+     self.cypher_schema = CypherSchema(self.config)
+     self.schema = list(self.config.get("schema", []))
+     self.files = list(self.config.get("files", []))
+     self.vertices = [item for item in self.schema if item.get("type") == "VERTEX"]
+     self.edges = [item for item in self.schema if item.get("type") == "EDGE"]
+     self.vertex_by_label = {item["label"]: item for item in self.vertices}
+     self.primary_by_label = {item["label"]: item.get("primary", "_id") for item in self.vertices}
+     self.property_types_by_label = {
+         item["label"]: {
+             prop["name"]: prop.get("type", "STRING")
+             for prop in item.get("properties", [])
+             if prop.get("name")
+         }
+         for item in self.schema
+     }
+     self.vertex_labels = {item["label"] for item in self.vertices}
+     self.edge_labels = {item["label"] for item in self.edges}
+     self.node_label_aliases = self._schema_name_aliases(self.vertex_labels)
+     self.edge_type_aliases = self._schema_name_aliases(self.edge_labels)
+     self.property_aliases_by_label = {
+         label: self._schema_name_aliases(properties)
+         for label, properties in self.property_types_by_label.items()
+     }
+     self.global_property_aliases = self._global_property_aliases()
+     # Opened last so a config or schema error above cannot leak an open driver.
+     self.driver = GraphDatabase.driver(uri, auth=(user, password))
+
+ def close(self) -> None:  # close the driver created in __init__
+ self.driver.close()
+
+ def setup(self, clear: bool = True) -> Dict[str, int]:  # prepare the database and load the dataset
+ with self.driver.session(database=self.database) as session:
+ if clear:
+ self.clear(session)  # reuse this session for the wipe
+ self._create_constraints(session)
+ counts = self._load_vertices()
+ counts.update(self._load_edges())  # edge results are merged into the vertex result dict
+ return counts
+
+ def clear(self, session: Any | None = None) -> None:
+     """Clear the graph through ``session`` if given, otherwise through a short-lived one."""
+     if session is None:
+         with self.driver.session(database=self.database) as owned_session:
+             return self._clear_with_session(owned_session)
+     self._clear_with_session(session)
+
+ def _clear_with_session(self, session: Any) -> None:  # delete relationships first, then nodes
+ rel_delete = "MATCH ()-[r]-() WITH r LIMIT $limit DELETE r RETURN count(r) AS deleted"  # NOTE(review): undirected pattern, so each relationship can appear twice per batch
+ node_delete = "MATCH (n) WITH n LIMIT $limit DELETE n RETURN count(n) AS deleted"
+ self._delete_until_empty(session, rel_delete)
+ self._delete_until_empty(session, node_delete)
+
+ def _delete_until_empty(self, session: Any, query: str) -> None:
+     """Re-run ``query`` with ``limit=clear_batch_size`` until it reports zero deleted rows."""
+     deleted = None
+     while deleted != 0:
+         row = session.run(query, limit=self.clear_batch_size).single()
+         deleted = row["deleted"] if row else 0
+
+ def execute(self, query: str, timeout_s: float | None = None) -> tuple[str, list[dict], str]:
+     """Run ``query`` (after ``prepare_query``) and return ``(status, rows, error)``."""
+     query = self.prepare_query(query)
+     try:
+         with self.driver.session(database=self.database) as session:
+             use_timeout = Query is not None and timeout_s
+             statement = Query(query, timeout=timeout_s) if use_timeout else query
+             rows = [dict(record) for record in session.run(statement)]
+             return "success", rows, ""
+     except Exception as exc:
+         message = str(exc)
+         status = "client_error" if "syntax" in message.lower() else "server_error"
+         return status, [], message
+
+ def prepare_query(self, query: str) -> str:
+     """Apply alias rewriting, then the boolean, numeric and date coercions, in that order."""
+     aliased = self._rewrite_schema_aliases(query)
+     booleans = self._coerce_string_backed_boolean_literals(aliased)
+     numerics = self._coerce_string_backed_numeric_comparisons(booleans)
+     return self._coerce_string_backed_date_comparisons(numerics)
+
+ def source_validation_issues(self, query: str) -> list[CypherSchemaIssue]:
+     """Return the schema's validation issues for ``query``; empty when no schema is attached."""
+     schema = getattr(self, "cypher_schema", None)
+     no_schema = schema is None
+     return [] if no_schema else schema.validation_issues(query)
+
+ def _rewrite_schema_aliases(self, query: str) -> str:
+     """Canonicalise node labels, relationship types, then property accesses.
+
+     Every pass is applied through ``_rewrite_outside_string_literals``.
+     """
+     for label_pass in (self._rewrite_node_labels, self._rewrite_relationship_types):
+         query = _rewrite_outside_string_literals(query, label_pass)
+     node_variables = self._query_variable_labels(query)
+     edge_variables = self._query_edge_variable_labels(query)
+
+     def rewrite_segment(segment: str) -> str:
+         # The whole query, not only this segment, is forwarded as ``full_query``.
+         return self._rewrite_property_accesses(segment, node_variables, edge_variables, query)
+
+     return _rewrite_outside_string_literals(query, rewrite_segment)
+
+ def _rewrite_node_labels(self, query: str) -> str:
+     """Replace the label in each ``(var:Label`` pattern with its canonical schema name."""
+     def replace(match: re.Match) -> str:
+         canonical = self._canonical_node_label(match.group("quoted_label") or match.group("label"))
+         return f"{match.group('prefix')}{_cypher_identifier(canonical)}"
+
+     return re.sub(
+         r"(?P<prefix>\(\s*(?:[A-Za-z_][A-Za-z0-9_]*\s*)?:\s*)"
+         r"(?:`(?P<quoted_label>[^`]+)`|(?P<label>[A-Za-z_][A-Za-z0-9_$#-]*))",
+         replace,
+         query,
+     )
+
+ def _rewrite_relationship_types(self, query: str) -> str:
+     """Replace relationship types in ``[var:TYPE|OTHER`` patterns with canonical edge types."""
+     def rewrite_type(raw_type: str) -> str:
+         stripped = raw_type.strip()
+         if not stripped:
+             return raw_type
+         quoted = stripped.startswith("`") and stripped.endswith("`")
+         return _cypher_identifier(self._canonical_edge_type(stripped[1:-1] if quoted else stripped))
+
+     def replace(match: re.Match) -> str:
+         types = [rewrite_type(item) for item in match.group("types").split("|")]
+         return f"{match.group('prefix')}{'|'.join(types)}"
+
+     # The first type sits in its own group so a back-quoted first type can still be followed
+     # by ``|``-separated alternatives; a bare ``A|B(...)*`` alternation would stop after it.
+     return re.sub(
+         r"(?P<prefix>\[\s*(?:[A-Za-z_][A-Za-z0-9_]*\s*)?:\s*)"
+         r"(?P<types>(?:`[^`]+`|[A-Za-z_][A-Za-z0-9_$#-]*)"
+         r"(?:\s*\|\s*(?:`[^`]+`|[A-Za-z_][A-Za-z0-9_$#-]*))*)",
+         replace,
+         query,
+     )
+
+ def _rewrite_property_accesses(
+     self,
+     query: str,
+     variables: Dict[str, str],
+     edge_variables: Dict[str, str] | None = None,
+     full_query: str | None = None,
+ ) -> str:
+     """Rewrite every ``variable.property`` access in ``query`` to its canonical reference.
+
+     ``full_query`` (default: ``query``) is the text handed to the resolver for each access.
+     """
+     edge_variables = edge_variables or {}
+     full_query = full_query or query
+
+     def replace(match: re.Match) -> str:
+         property_name = match.group("quoted_property") or match.group("property")
+         target_variable, canonical = self._canonical_property_reference(
+             full_query, match.group("variable"), property_name, variables, edge_variables
+         )
+         return f"{target_variable}.{_cypher_identifier(canonical)}"
+
+     # The group names below are the ones replace() reads.
+     return re.sub(
+         r"\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\.\s*"
+         r"(?:`(?P<quoted_property>[^`]+)`|"
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#-]*))",
+         replace,
+         query,
+     )
+
+ def _canonical_node_label(self, label: str) -> str:
+     """Resolve ``label`` through ``node_label_aliases`` (an empty table if the attribute is unset)."""
+     aliases = getattr(self, "node_label_aliases", {})
+     canonical = self._canonical_schema_name(label, aliases)
+     return canonical
+
+ def _canonical_edge_type(self, edge_type: str) -> str:
+     """Resolve ``edge_type`` through ``edge_type_aliases`` (an empty table if the attribute is unset)."""
+     aliases = getattr(self, "edge_type_aliases", {})
+     canonical = self._canonical_schema_name(edge_type, aliases)
+     return canonical
+
+ def _canonical_property_name(
+     self,
+     variable: str,
+     property_name: str,
+     variables: Dict[str, str],
+ ) -> str:
+     """Return the canonical property name for ``variable.property_name``.
+
+     Resolution runs without query text or edge variables; ``property_name`` is
+     returned unchanged when the resolver yields no target variable.
+     """
+     resolved = self._canonical_property_reference("", variable, property_name, variables, {})
+     target_variable, canonical = resolved
+     return canonical if target_variable else property_name
+
+ def _canonical_property_reference(  # resolve ``variable.property_name`` to (target variable, canonical property)
+ self,
+ query: str,
+ variable: str,
+ property_name: str,
+ variables: Dict[str, str],
+ edge_variables: Dict[str, str],
+ ) -> tuple[str, str]:
+ if property_name.lower() in {"identity", "id"}:  # id-style names get special handling first
+ if variable in variables and not self._property_defined_for_variable(
+ variable,
+ property_name,
+ variables,
+ ):
+ primary_key = self.primary_by_label.get(variables[variable])  # node: fall back to the label's primary key
+ if primary_key:
+ return variable, primary_key
+ if variable in edge_variables and not self._property_defined_for_variable(
+ variable,
+ property_name,
+ edge_variables,
+ ):
+ return variable, "EDGE_ID"  # edge without such a property: use the fixed EDGE_ID name
+ schema = getattr(self, "cypher_schema", None)  # may be absent on partially initialised instances
+ if schema is not None and query:
+ redirected_variable, redirected_property = schema.redirected_property_target(
+ query,
+ variable,
+ property_name,
+ )
+ if redirected_variable and redirected_property:  # only a complete redirect is used
+ return redirected_variable, redirected_property
+ label = variables.get(variable, "")
+ aliases_by_label = getattr(self, "property_aliases_by_label", {})
+ if label in aliases_by_label:
+ canonical = self._canonical_schema_name(property_name, aliases_by_label[label])
+ if canonical != property_name:  # a label-specific alias changed the name: done
+ return variable, canonical
+ return variable, self._canonical_schema_name(  # last resort: aliases shared across labels
+ property_name,
+ getattr(self, "global_property_aliases", {}),
+ )
+
+ def _property_defined_for_variable(
+     self,
+     variable: str,
+     property_name: str,
+     variables: Dict[str, str],
+ ) -> bool:
+     """Tell whether the label bound to ``variable`` declares ``property_name`` (alias-aware)."""
+     label = variables.get(variable, "")
+     if label:
+         aliases = getattr(self, "property_aliases_by_label", {}).get(label, {})
+         declared = getattr(self, "property_types_by_label", {}).get(label, {})
+         return self._canonical_schema_name(property_name, aliases) in declared
+     # Variables without a known label never count as declaring the property.
+     return False
+
+ def _canonical_schema_name(self, name: str, aliases: Dict[str, str]) -> str:
+     """Resolve ``name`` through ``aliases``; fall back to ``name`` itself when nothing matches."""
+     cleaned = OracleNameSanitizer.clean(name, fallback=name)
+     # Lookup order: exact, sanitised, lower-cased exact, lower-cased sanitised.
+     for candidate in (name, cleaned, name.lower(), cleaned.lower()):
+         canonical = aliases.get(candidate)
+         if canonical:
+             return canonical
+     return name
+
+ def _schema_name_aliases(self, names: Iterable[str]) -> Dict[str, str]:  # alias -> schema name lookup table
+ aliases: Dict[str, str] = {}
+ for name in names:  # NOTE(review): __init__ passes sets here, so on alias collisions the winner depends on set iteration order
+ cleaned = OracleNameSanitizer.clean(name, fallback=name)
+ for alias in {name, cleaned, name.lower(), cleaned.lower()}:  # exact, sanitised and lower-cased variants
+ aliases.setdefault(alias, name)  # first name to claim an alias keeps it
+ return aliases
+
+ def _global_property_aliases(self) -> Dict[str, str]:  # aliases that identify one property name across all labels
+ candidates: Dict[str, set[str]] = {}  # alias -> every property name that produces it
+ for properties in getattr(self, "property_types_by_label", {}).values():
+ for property_name in properties:
+ cleaned = OracleNameSanitizer.clean(property_name, fallback=property_name)
+ for alias in {property_name, cleaned, property_name.lower(), cleaned.lower()}:
+ candidates.setdefault(alias, set()).add(property_name)
+ return {alias: next(iter(names)) for alias, names in candidates.items() if len(names) == 1}  # ambiguous aliases are dropped
+
+ def _coerce_string_backed_boolean_literals(self, query: str) -> str:
+     """Quote ``true``/``false`` literals used with string-typed properties (case-insensitive)."""
+     variables = self._query_variable_labels(query)
+
+     def replace_not_property(match: re.Match) -> str:
+         variable = match.group("variable")
+         property_name = match.group("property")
+         if not self._is_string_property(variables.get(variable, ""), property_name):
+             return match.group(0)
+         return f"{variable}.{property_name} = 'false'"
+     query = re.sub(
+         r"\bNOT\s+(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\b",
+         replace_not_property,
+         query,
+         flags=re.IGNORECASE,
+     )
+
+     def replace_comparison(match: re.Match) -> str:
+         variable = match.group("variable")
+         property_name = match.group("property")
+         value = match.group("value").lower()
+         if not self._is_string_property(variables.get(variable, ""), property_name):
+             return match.group(0)
+         return f"{variable}.{property_name} {match.group('operator')} '{value}'"
+
+     query = re.sub(
+         r"\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\s*"
+         r"(?P<operator>=|<>)\s*(?P<value>true|false)\b",
+         replace_comparison,
+         query,
+         flags=re.IGNORECASE,
+     )
+
+     def replace_map_literal(match: re.Match) -> str:
+         property_name = match.group("property")
+         value = match.group("value").lower()
+         if not self._has_string_property(property_name):
+             return match.group(0)
+         return f"{property_name}: '{value}'"
+
+     return re.sub(
+         r"\b(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\s*:\s*"
+         r"(?P<value>true|false)\b",
+         replace_map_literal,
+         query,
+         flags=re.IGNORECASE,
+     )
+
+ def _coerce_string_backed_numeric_comparisons(self, query: str) -> str:
+     """Quote numeric literals compared with, or assigned to, string-typed properties."""
+     variables = self._query_variable_labels(query)
+
+     def replace_left(match: re.Match) -> str:
+         variable = match.group("variable")
+         property_name = match.group("property")
+         if not self._is_string_property(variables.get(variable, ""), property_name):
+             return match.group(0)
+         return f"{variable}.{property_name} {match.group('operator')} '{match.group('value')}'"
+     query = re.sub(
+         r"\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\s*"
+         r"(?P<operator><=|>=|<>|=|<|>)\s*"
+         r"(?P<value>-?\d+(?:\.\d+)?)\b",
+         replace_left,
+         query,
+     )
+
+     def replace_right(match: re.Match) -> str:
+         variable = match.group("variable")
+         property_name = match.group("property")
+         if not self._is_string_property(variables.get(variable, ""), property_name):
+             return match.group(0)
+         return f"'{match.group('value')}' {match.group('operator')} {variable}.{property_name}"
+
+     query = re.sub(
+         r"\b(?P<value>-?\d+(?:\.\d+)?)\s*"
+         r"(?P<operator><=|>=|<>|=|<|>)\s*"
+         r"(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\b",
+         replace_right,
+         query,
+     )
+
+     def replace_map_literal(match: re.Match) -> str:
+         property_name = match.group("property")
+         value = match.group("value")
+         if not self._has_string_property(property_name):
+             return match.group(0)
+         return f"{property_name}: '{value}'"
+
+     return re.sub(
+         r"\b(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\s*:\s*"
+         r"(?P<value>-?\d+(?:\.\d+)?)\b",
+         replace_map_literal,
+         query,
+     )
+
+ def _coerce_string_backed_date_comparisons(self, query: str) -> str:
+     """Wrap string-typed properties in ``date()`` when compared with a ``date(...)`` call."""
+     variables = self._query_variable_labels(query)
+     query = self._coerce_string_backed_date_accessors(query, variables)
+
+     def replace_left(match: re.Match) -> str:
+         variable = match.group("variable")
+         property_name = match.group("property")
+         left = match.group("left")
+         if not self._is_string_property(variables.get(variable, ""), property_name):
+             return match.group(0)
+         if re.fullmatch(r"\s*date\s*\(", left, flags=re.IGNORECASE):  # ``left`` is ``var.prop``, so this cannot match
+             return match.group(0)
+         return (
+             f"date({variable}.{property_name}) {match.group('operator')} {match.group('right')}"
+         )
+
+     query = re.sub(
+         r"(?P<left>\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*))\s*"
+         r"(?P<operator><=|>=|<>|=|<|>)\s*"
+         r"(?P<right>date\s*\([^)]+\))",
+         replace_left,
+         query,
+         flags=re.IGNORECASE,
+     )
+
+     def replace_right(match: re.Match) -> str:
+         variable = match.group("variable")
+         property_name = match.group("property")
+         if not self._is_string_property(variables.get(variable, ""), property_name):
+             return match.group(0)
+         return (
+             f"{match.group('left')} {match.group('operator')} date({variable}.{property_name})"
+         )
+
+     return re.sub(
+         r"(?P<left>date\s*\([^)]+\))\s*"
+         r"(?P<operator><=|>=|<>|=|<|>)\s*"
+         r"(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\b",
+         replace_right,
+         query,
+         flags=re.IGNORECASE,
+     )
+
+ def _coerce_string_backed_date_accessors(
+     self,
+     query: str,
+     variables: Dict[str, str],
+ ) -> str:
+     """Wrap string-backed properties in ``date()``/``datetime()`` before temporal accessors.
+
+     Handles both ``date(var.prop).year`` and ``var.prop.year`` for the accessors
+     ``year``, ``month``, ``day``, ``weekday`` and ``dayOfWeek`` (matched case-insensitively).
+     Properties that are not string-typed for the variable's label are left untouched.
+     """
+
+     def accessor_expression(variable: str, property_name: str, accessor: str) -> str:
+         base = (
+             f"datetime({variable}.{property_name})"
+             if self._looks_like_datetime_string_property(variable, property_name)
+             else f"date({variable}.{property_name})"
+         )
+         return f"{base}.{accessor}"
+
+     def replace(match: re.Match) -> str:
+         # Shared by both substitutions: they capture the same three named groups.
+         variable = match.group("variable")
+         property_name = match.group("property")
+         if not self._is_string_property(variables.get(variable, ""), property_name):
+             return match.group(0)
+         return accessor_expression(variable, property_name, match.group("accessor"))
+
+     query = re.sub(
+         r"\bdate\s*\(\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\s*\)\."
+         r"(?P<accessor>year|month|day|weekday|dayOfWeek)\b",
+         replace,
+         query,
+         flags=re.IGNORECASE,
+     )
+     return re.sub(
+         r"\b(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\."
+         r"(?P<property>[A-Za-z_][A-Za-z0-9_$#]*)\."
+         r"(?P<accessor>year|month|day|weekday|dayOfWeek)\b",
+         replace,
+         query,
+         flags=re.IGNORECASE,
+     )
+
def _looks_like_datetime_string_property(self, variable: str, property_name: str) -> bool:
    """Guess from the property name whether a string holds a datetime.

    The check is a case-insensitive suffix test; ``variable`` is not consulted.
    """
    # NOTE(review): the bare "at" suffix also matches names such as "format".
    datetime_suffixes = ("at", "time", "timestamp", "datetime")
    lowered = property_name.lower()
    return any(lowered.endswith(suffix) for suffix in datetime_suffixes)
+
def _query_variable_labels(self, query: str) -> Dict[str, str]:
    """Map node variables to their label for every ``(var:Label`` in ``query``.

    Both plain and backtick-quoted labels are recognised. When a variable
    appears more than once, the last occurrence wins.
    """
    pattern = (
        r"\(\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\s*:\s*"
        r"(?:`(?P<quoted_label>[^`]+)`|(?P<label>[A-Za-z_][A-Za-z0-9_$#-]*))"
    )
    labels: Dict[str, str] = {}
    for match in re.finditer(pattern, query):
        labels[match.group("variable")] = match.group("quoted_label") or match.group("label")
    return labels
+
def _query_edge_variable_labels(self, query: str) -> Dict[str, str]:
    """Map relationship variables to their type for every ``[var:TYPE`` in ``query``.

    Relationships declared with several alternative types (``A|B``) are
    skipped because no single label can be assigned to the variable.
    """
    pattern = (
        r"\[\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\s*:\s*"
        r"(?:`(?P<quoted_label>[^`]+)`|(?P<label>[A-Za-z_][A-Za-z0-9_$#|.-]*))"
    )
    labels: Dict[str, str] = {}
    for match in re.finditer(pattern, query):
        label = match.group("quoted_label") or match.group("label")
        if "|" not in label:
            labels[match.group("variable")] = label
    return labels
+
def _is_string_property(self, label: str, property_name: str) -> bool:
    """Return True when ``property_name`` is declared as STRING for ``label``.

    An empty label, an unknown label or an unknown property all yield False.
    The declared type is compared case-insensitively.
    """
    if label:
        declared_type = self.property_types_by_label.get(label, {}).get(property_name, "")
        return declared_type.upper() == "STRING"
    return False
+
def _has_string_property(self, property_name: str) -> bool:
    """Return True when any label declares ``property_name`` with type STRING."""
    for declared in self.property_types_by_label.values():
        if declared.get(property_name, "").upper() == "STRING":
            return True
    return False
+
def _create_constraints(self, session: Any) -> None:
    """Create a uniqueness constraint on the primary property of each vertex label.

    Vertices without a ``primary`` entry are skipped. Each statement is run on
    ``session`` and consumed immediately.
    """
    keyed_vertices = ((vertex["label"], vertex.get("primary")) for vertex in self.vertices)
    for label, primary in keyed_vertices:
        if primary:
            constraint_name = _safe_identifier(f"constraint_{label}_{primary}")
            statement = (
                f"CREATE CONSTRAINT `{constraint_name}` IF NOT EXISTS "
                f"FOR (n:`{_escape_backticks(label)}`) "
                f"REQUIRE n.`{_escape_backticks(primary)}` IS UNIQUE"
            )
            session.run(statement).consume()
+
def _load_vertices(self) -> Dict[str, int]:
    """Read and write every vertex file, returning row counts per label.

    Files carrying a ``SRC_ID`` key are treated as edge files and ignored.
    Vertices without a matching file are skipped. Keys of the returned mapping
    have the form ``vertex:<label>``.
    """
    vertex_files: Dict[str, Any] = {}
    for item in self.files:
        if "SRC_ID" not in item:
            vertex_files[item["label"]] = item

    loaded: Dict[str, int] = {}
    for vertex in self.vertices:
        source = vertex_files.get(vertex["label"])
        if source:
            rows = self._read_file(vertex, source, is_edge=False)
            self._write_vertex_batches(vertex, rows)
            loaded[f"vertex:{vertex['label']}"] = len(rows)
    return loaded
+
def _load_edges(self) -> Dict[str, int]:
    """Read and write every edge file, returning row counts per edge triple.

    Only files carrying both ``SRC_ID`` and ``DST_ID`` are considered; files
    for which no edge schema is found are skipped. Keys of the returned mapping
    have the form ``edge:<src>-[<label>]-><dst>``.
    """
    loaded: Dict[str, int] = {}
    for file_item in self.files:
        if "SRC_ID" not in file_item or "DST_ID" not in file_item:
            continue
        edge = self._find_edge_schema(file_item)
        if edge:
            rows = self._read_file(edge, file_item, is_edge=True)
            self._write_edge_batches(edge, file_item, rows)
            triple = f"edge:{file_item['SRC_ID']}-[{edge['label']}]->{file_item['DST_ID']}"
            loaded[triple] = len(rows)
    return loaded
+
+ def _find_edge_schema(self, file_item: Dict[str, Any]) -> Dict[str, Any] | None:
+ label = file_item.get("label")
+ src = file_item.get("SRC_ID")
+ dst = file_item.get("DST_ID")
+ for edge in self.edges:
+ if edge.get("label") != label:
+ continue
+ if [src, dst] in edge.get("constraints", []):
+ return edge
+ for edge in self.edges:
+ if edge.get("label") == label:
+ return edge
+ return None
+
def _read_file(
    self,
    schema_item: Dict[str, Any],
    file_item: Dict[str, Any],
    is_edge: bool,
) -> List[Dict[str, Any]]:
    """Load one CSV file into a list of typed row dictionaries.

    Args:
        schema_item: Vertex or edge schema; its ``properties`` supply the
            target type of each column (``STRING`` when absent).
        file_item: File entry providing ``path`` (relative to ``csv_root``),
            ``columns`` and the number of ``header`` rows to skip.
        is_edge: When true, the ``SRC_ID``/``DST_ID`` columns are converted
            using the primary-key type of the respective endpoint label.

    Returns:
        One dictionary per data row, keyed by column name. Cells missing from
        a short row are read as the empty string before conversion.
    """
    csv_path = self.csv_root / file_item["path"]
    columns = list(file_item.get("columns", []))
    declared_types = {
        prop["name"]: prop.get("type", "STRING") for prop in schema_item.get("properties", [])
    }
    rows_to_skip = int(file_item.get("header", 0))

    def type_for(column: str) -> str:
        # Endpoint columns follow the primary-key type of the referenced vertex.
        if is_edge and column in ("SRC_ID", "DST_ID"):
            return self._vertex_primary_type(file_item[column])
        return declared_types.get(column, "STRING")

    records: List[Dict[str, Any]] = []
    with open(csv_path, newline="", encoding="utf-8-sig") as handle:
        for line_number, cells in enumerate(csv.reader(handle)):
            if line_number < rows_to_skip:
                continue
            padded = {
                column: cells[position] if position < len(cells) else ""
                for position, column in enumerate(columns)
            }
            records.append(
                {column: _convert_value(value, type_for(column)) for column, value in padded.items()}
            )
    return records
+
def _vertex_primary_type(self, label: str) -> str:
    """Return the declared type of the primary property of ``label``.

    The primary property name defaults to ``_id``. ``STRING`` is returned when
    the label is unknown, the property is not declared, or it has no type.
    """
    vertex = self.vertex_by_label.get(label, {})
    primary = vertex.get("primary", "_id")
    declaration = next(
        (prop for prop in vertex.get("properties", []) if prop.get("name") == primary),
        None,
    )
    if declaration is None:
        return "STRING"
    return declaration.get("type", "STRING")
+
def _write_vertex_batches(self, vertex: Dict[str, Any], rows: List[Dict[str, Any]]) -> None:
    """MERGE the given rows as nodes of ``vertex['label']``.

    Nodes are matched on the vertex's primary property (``_id`` when the schema
    names none) and all row values are then set on the node. Nothing is run for
    an empty row list.
    """
    if rows:
        safe_label = _escape_backticks(vertex["label"])
        safe_key = _escape_backticks(vertex.get("primary", "_id"))
        statement = "".join(
            [
                "UNWIND $batch AS row ",
                f"MERGE (n:`{safe_label}` {{`{safe_key}`: row.`{safe_key}`}}) ",
                "SET n += row",
            ]
        )
        self._run_batches(statement, rows)
+
def _write_edge_batches(
    self,
    edge: Dict[str, Any],
    file_item: Dict[str, Any],
    rows: List[Dict[str, Any]],
) -> None:
    """CREATE one relationship per row between the matched endpoint nodes.

    Endpoints are matched by the primary property of the source/destination
    labels named in ``file_item`` (``_id`` when unknown). All non-endpoint
    columns become relationship properties. Nothing is run for an empty list.
    """
    if not rows:
        return

    source_label = file_item["SRC_ID"]
    target_label = file_item["DST_ID"]
    source_key = self.primary_by_label.get(source_label, "_id")
    target_key = self.primary_by_label.get(target_label, "_id")
    relationship = _escape_backticks(edge["label"])
    statement = "".join(
        [
            "UNWIND $batch AS row ",
            f"MATCH (src:`{_escape_backticks(source_label)}` ",
            f"{{`{_escape_backticks(source_key)}`: row.SRC_ID}}) ",
            f"MATCH (dst:`{_escape_backticks(target_label)}` ",
            f"{{`{_escape_backticks(target_key)}`: row.DST_ID}}) ",
            f"CREATE (src)-[r:`{relationship}`]->(dst) ",
            "SET r += row.props",
        ]
    )

    def to_batch_row(position: int, row: Dict[str, Any]) -> Dict[str, Any]:
        properties = {
            name: value for name, value in row.items() if name not in ("SRC_ID", "DST_ID")
        }
        # Oracle edge tables use a per-table generated EDGE_ID. Adding the
        # same validation-only value to Neo4j lets returned relationships
        # normalize to the same identity across backends.
        properties.setdefault("EDGE_ID", position)
        return {"SRC_ID": row["SRC_ID"], "DST_ID": row["DST_ID"], "props": properties}

    payload = [to_batch_row(position, row) for position, row in enumerate(rows, start=1)]
    self._run_batches(statement, payload)
+
def _run_batches(self, query: str, rows: List[Dict[str, Any]]) -> None:
    """Run ``query`` once per slice of ``rows``, passing the slice as ``$batch``.

    Slices hold at most ``self.batch_size`` rows and all run in one session
    opened on ``self.database``.
    """
    with self.driver.session(database=self.database) as session:
        for offset in range(0, len(rows), self.batch_size):
            chunk = rows[offset : offset + self.batch_size]
            session.run(query, batch=chunk).consume()
+
+
def parse_args() -> argparse.Namespace:
    """Parse the command line of the Oracle SQL/PGQ vs. Neo4j comparison tool.

    Neo4j connection defaults are taken from the ``NEO4J_URI``, ``NEO4J_USER``,
    ``NEO4J_PASSWORD`` and ``NEO4J_DATABASE`` environment variables when set.
    """
    parser = argparse.ArgumentParser(
        description="Compare Oracle SQL/PGQ query results with Neo4j Cypher results."
    )
    add = parser.add_argument
    env = os.environ.get

    # Dataset location and selection.
    add("--dataset-root", default="dataset")
    add("--dataset-output-root", default="output/dataset_prep")
    add("--output-root", default="output/oracle_neo4j_compare")
    add("--splits", nargs="+", default=["train", "dev", "test"])
    add("--databases", nargs="*", default=[])
    add("--limit-databases", type=int, default=0)
    add("--limit-queries", type=int, default=0)
    add(
        "--query-offset",
        type=int,
        default=0,
        help="Skip this many enriched query records before applying --limit-queries.",
    )
    add("--graph-prefix", default="T2GQL")

    # Which records are compared, and execution limits.
    add("--oracle-statuses", nargs="+", default=sorted(DEFAULT_VALID_ORACLE_STATUSES))
    add("--include-all-translatable", action="store_true")
    add("--oracle-timeout-ms", type=int, default=60000)
    add("--neo4j-timeout-s", type=float, default=60.0)

    # Neo4j connection.
    add("--neo4j-uri", default=env("NEO4J_URI", "bolt://localhost:7687"))
    add("--neo4j-user", default=env("NEO4J_USER", "neo4j"))
    add("--neo4j-password", default=env("NEO4J_PASSWORD", "password"))
    add("--neo4j-database", default=env("NEO4J_DATABASE", "neo4j"))
    add("--neo4j-batch-size", type=int, default=1000)

    # Loading behaviour and progress output.
    add("--keep-loaded", action="store_true")
    add(
        "--reuse-loaded",
        action="store_true",
        help="Skip Oracle/Neo4j load setup and validate against already-loaded graphs.",
    )
    add(
        "--progress-every",
        type=int,
        default=0,
        help="Print query progress every N selected records.",
    )
    return parser.parse_args()
+
+
def main() -> None:
    """Run the comparison for every selected database unit and write a summary.

    The failures file under the output root is truncated first. Oracle
    credentials are read from ``ORACLE_DSN``, ``ORACLE_USER`` and
    ``ORACLE_PASSWORD``; the Oracle client is closed even when a unit raises.
    The merged summary is written to ``summary.json`` afterwards.
    """
    args = parse_args()

    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    failures_path = output_root / "mismatched_or_failed_queries.jsonl"
    failures_path.write_text("", encoding="utf-8")

    units = discover_database_units(Path(args.dataset_root), args.splits)
    if args.databases:
        wanted = {name.lower() for name in args.databases}
        units = [unit for unit in units if unit.database.lower() in wanted]
    if args.limit_databases:
        units = units[: args.limit_databases]

    credentials = {
        "dsn": os.environ["ORACLE_DSN"],
        "user": os.environ["ORACLE_USER"],
        "password": os.environ["ORACLE_PASSWORD"],
    }
    oracle_client = OracleDBClient(credentials)
    unit_summaries: List[Dict[str, Any]] = []
    try:
        for unit in units:
            unit_name = f"{unit.split}/{unit.database}"
            print(f"[start] {unit_name}", flush=True)
            unit_summary = compare_unit(unit, oracle_client, args, failures_path)
            unit_summaries.append(unit_summary)
            matched = unit_summary["matched"]
            failed = unit_summary["failed"]
            skipped = unit_summary["skipped"]
            print(
                f"[done] {unit_name}: matched={matched} failed={failed} skipped={skipped}",
                flush=True,
            )
    finally:
        oracle_client.close()

    write_json(output_root / "summary.json", merge_compare_summaries(unit_summaries))
+
+
def compare_unit(
    unit: DatabaseUnit,
    oracle_client: OracleDBClient,
    args: argparse.Namespace,
    failures_path: Path,
) -> Dict[str, Any]:
    """Load one database unit into both backends and compare its queries.

    Args:
        unit: Dataset unit (split, database, config and CSV locations).
        oracle_client: Open Oracle client shared across units.
        args: Parsed command-line options (see ``parse_args``).
        failures_path: JSONL file that failure records are appended to, in
            chunks of 100 and once more at the end of the unit.

    Returns:
        The per-unit summary, which is also written to
        ``<output_root>/<split>/<database>/summary.json``.

    The Neo4j loader is always closed. Unless ``--keep-loaded`` is given the
    loaded data is removed as well; a failure while clearing Neo4j is reported
    on stdout instead of being silently discarded.
    """
    unit_name = f"{unit.split}/{unit.database}"
    graph_name = graph_name_for(unit, args.graph_prefix)
    oracle_loader = DatasetOracleLoader(
        oracle_client,
        unit.import_config_path,
        unit.csv_root,
        graph_name,
    )
    neo4j_loader = DatasetNeo4jLoader(
        args.neo4j_uri,
        args.neo4j_user,
        args.neo4j_password,
        args.neo4j_database,
        unit.import_config_path,
        unit.csv_root,
        args.neo4j_batch_size,
    )
    summary: Dict[str, Any] = {
        "split": unit.split,
        "database": unit.database,
        "query_file": str(unit.query_path),
        "import_config": str(unit.import_config_path),
        "graph_name": graph_name,
        "loaded": {},
        "considered": 0,
        "matched": 0,
        "failed": 0,
        "skipped": 0,
        "skip_reasons": {},
        "failure_reasons": {},
    }
    # Comparison reasons that are counted as skipped rather than failed.
    # NOTE(review): compare_record can also yield the "..._with_..." variants of
    # these reasons; they are counted as failures here — confirm this is intended.
    skipped_comparison_reasons = {
        "nondeterministic_limit_without_order",
        "suspected_order_by_limit_tie",
        "source_invalid",
    }
    # Keys copied verbatim from the comparison into a failure record.
    copied_keys = (
        "reason",
        "cypher",
        "oracle_sqlpgq",
        "oracle_status",
        "neo4j_status",
        "oracle_error",
        "neo4j_error",
        "oracle_rows_sample",
        "neo4j_rows_sample",
    )
    element_label_aliases = oracle_element_label_aliases(oracle_loader)
    failures: List[Dict[str, Any]] = []
    try:
        if args.reuse_loaded:
            print(f"[load] {unit_name}: reusing loaded graphs", flush=True)
            summary["loaded"] = {"reused": True}
        else:
            print(f"[load] {unit_name}: oracle", flush=True)
            oracle_counts = oracle_loader.setup()
            print(f"[load] {unit_name}: neo4j", flush=True)
            neo4j_counts = neo4j_loader.setup(clear=True)
            summary["loaded"] = {"oracle": oracle_counts, "neo4j": neo4j_counts}
        # NOTE(review): printed for both load paths — confirm against the
        # original nesting.
        print(f"[load] {unit_name}: done", flush=True)

        all_records = load_enriched_records(unit, Path(args.dataset_output_root))
        records = select_records_for_range(all_records, args.query_offset, args.limit_queries)
        summary["total_records"] = len(all_records)
        summary["query_offset"] = args.query_offset
        summary["selected_records"] = len(records)
        if args.limit_queries:
            summary["limit_queries"] = args.limit_queries

        valid_statuses = set(args.oracle_statuses)
        for selected_index, record in enumerate(records, start=1):
            if args.progress_every and selected_index % args.progress_every == 0:
                print(
                    f"[progress] {unit_name}: "
                    f"{selected_index}/{len(records)} selected records",
                    flush=True,
                )
            skip_reason = skip_reason_for_record(
                record,
                valid_statuses=valid_statuses,
                include_all_translatable=args.include_all_translatable,
            )
            if skip_reason:
                summary["skipped"] += 1
                increment(summary["skip_reasons"], skip_reason)
                continue

            summary["considered"] += 1
            comparison = compare_record(
                record,
                oracle_client,
                neo4j_loader,
                args,
                element_label_aliases=element_label_aliases,
            )
            if comparison["matched"]:
                summary["matched"] += 1
                continue
            if comparison["reason"] in skipped_comparison_reasons:
                summary["skipped"] += 1
                increment(summary["skip_reasons"], comparison["reason"])
                continue

            summary["failed"] += 1
            increment(summary["failure_reasons"], comparison["reason"])
            failure: Dict[str, Any] = {
                "split": unit.split,
                "database": unit.database,
                "record_id": record.get("id"),
                "record_index": record.get("oracle_dataset_meta", {}).get("record_index"),
            }
            for key in copied_keys:
                failure[key] = comparison[key]
            failure["result_diagnostics"] = comparison.get("result_diagnostics", {})
            failure["deterministic_ordering"] = comparison.get("deterministic_ordering", {})
            failures.append(failure)
            # Flush in chunks so a long run does not hold every failure in memory.
            if len(failures) >= 100:
                append_jsonl(failures_path, failures)
                failures = []

        if failures:
            append_jsonl(failures_path, failures)
        write_json(
            Path(args.output_root) / unit.split / unit.database / "summary.json",
            summary,
        )
        return summary
    finally:
        if not args.keep_loaded:
            if not args.reuse_loaded:
                oracle_loader.cleanup(ignore_errors=True)
            # NOTE(review): Neo4j is cleared even with --reuse-loaded — confirm
            # against the original nesting.
            try:
                neo4j_loader.clear()
            except Exception as error:
                # Best-effort cleanup: report the problem but keep going so the
                # loader is still closed and the unit's outcome is not masked.
                print(f"[warn] {unit_name}: neo4j clear failed: {error}", flush=True)
        neo4j_loader.close()
+
+
def compare_record(
    record: Dict[str, Any],
    oracle_client: OracleDBClient,
    neo4j_loader: DatasetNeo4jLoader,
    args: argparse.Namespace,
    element_label_aliases: Dict[str, str] | None = None,
) -> Dict[str, Any]:
    """Execute one record on both backends and classify the outcome.

    The Cypher text is taken from ``oracle_source_query``, ``initial_cypher``
    or ``cypher`` (first non-empty). The result is built by
    ``comparison_result``; its ``reason`` is one of:

    * ``source_invalid`` - the loader's optional source validator reported
      issues, or Neo4j did not return ``success``;
    * ``execution_error`` - Oracle returned neither ``success`` nor ``no_record``;
    * ``""`` - the normalized rows match (exactly or within numeric tolerance);
    * ``result_mismatch`` or one of the non-determinism / tie reasons otherwise.
    """
    oracle_sqlpgq = record.get("oracle_sqlpgq") or ""
    cypher = (
        record.get("oracle_source_query")
        or record.get("initial_cypher")
        or record.get("cypher")
        or ""
    )

    def finish(
        matched,
        reason,
        oracle_status,
        neo4j_status,
        oracle_error,
        neo4j_error,
        oracle_rows,
        neo4j_rows,
        queries=None,
    ):
        # Fills in the arguments that are the same for every outcome.
        return comparison_result(
            matched,
            reason,
            cypher,
            oracle_sqlpgq,
            oracle_status,
            neo4j_status,
            oracle_error,
            neo4j_error,
            oracle_rows,
            neo4j_rows,
            neo4j_loader.primary_by_label,
            element_label_aliases,
            queries,
        )

    source_validator = getattr(neo4j_loader, "source_validation_issues", None)
    source_issues = source_validator(cypher) if source_validator else []
    if source_issues:
        issue_text = "; ".join(f"{issue.signature}: {issue.message}" for issue in source_issues)
        return finish(False, "source_invalid", "not_executed", "source_invalid", "", issue_text, [], [])

    execution_queries = stable_execution_queries(
        oracle_sqlpgq,
        cypher,
        neo4j_loader.primary_by_label,
    )
    oracle_result = oracle_client.execute_query(
        execution_queries.oracle_sqlpgq,
        call_timeout_ms=args.oracle_timeout_ms,
    )
    neo4j_status, neo4j_rows, neo4j_error = neo4j_loader.execute(
        execution_queries.cypher,
        args.neo4j_timeout_s,
    )
    oracle_status = query_status_name(oracle_result.status_code)
    oracle_rows = oracle_result.data if isinstance(oracle_result.data, list) else []

    # A Cypher query that fails on Neo4j is treated as an invalid source query.
    if neo4j_status != "success":
        return finish(
            False,
            "source_invalid",
            oracle_status,
            neo4j_status,
            oracle_result.error or "",
            neo4j_error,
            oracle_rows,
            neo4j_rows,
            execution_queries,
        )
    if oracle_status not in ("success", "no_record"):
        return finish(
            False,
            "execution_error",
            oracle_status,
            neo4j_status,
            oracle_result.error or "",
            neo4j_error,
            oracle_rows,
            neo4j_rows,
            execution_queries,
        )

    oracle_counter = normalized_counter(
        oracle_rows,
        neo4j_loader.primary_by_label,
        element_label_aliases,
    )
    neo4j_counter = normalized_counter(
        neo4j_rows,
        neo4j_loader.primary_by_label,
        element_label_aliases,
    )
    matched = oracle_counter == neo4j_counter or normalized_rows_match_with_numeric_tolerance(
        oracle_rows,
        neo4j_rows,
        neo4j_loader.primary_by_label,
        element_label_aliases,
    )

    reason = "" if matched else "result_mismatch"
    if not matched:

        def rows_comparable():
            # Only unmodified queries that returned rows on both sides qualify
            # for a non-determinism / tie explanation.
            return not execution_queries.applied and oracle_rows and neo4j_rows

        if is_nondeterministic_limit_without_order(cypher) and rows_comparable():
            reason = "nondeterministic_limit_without_order"
        elif is_nondeterministic_with_limit_without_order(cypher) and rows_comparable():
            reason = "nondeterministic_with_limit_without_order"
        elif is_with_order_by_limit_query(cypher) and rows_comparable():
            reason = "suspected_with_order_by_limit_tie"
        elif is_order_by_limit_query(cypher) and rows_comparable():
            has_tie = has_order_by_limit_boundary_tie(
                cypher,
                oracle_sqlpgq,
                oracle_client,
                neo4j_loader,
                args,
                element_label_aliases,
            )
            # Anything other than an explicit False keeps the tie suspicion.
            if has_tie is not False:
                reason = "suspected_order_by_limit_tie"

    return finish(
        matched,
        reason,
        oracle_status,
        neo4j_status,
        "",
        "",
        oracle_rows,
        neo4j_rows,
        execution_queries,
    )
+
+
def comparison_result(
    matched: bool,
    reason: str,
    cypher: str,
    oracle_sqlpgq: str,
    oracle_status: str,
    neo4j_status: str,
    oracle_error: str,
    neo4j_error: str,
    oracle_rows: Sequence[Dict[str, Any]],
    neo4j_rows: Sequence[Dict[str, Any]],
    primary_by_label: Dict[str, str] | None = None,
    element_label_aliases: Dict[str, str] | None = None,
    execution_queries: StableExecutionQueries | None = None,
) -> Dict[str, Any]:
    """Assemble the dictionary describing one comparison outcome.

    Besides the scalar arguments, the result carries normalized samples of the
    first five rows from each backend. ``deterministic_ordering`` is added when
    rewritten queries were executed, and ``result_diagnostics`` when the
    outcome is an unmatched ``result_mismatch``.
    """

    def sample(rows: Sequence[Dict[str, Any]]) -> Any:
        return normalize_rows(rows[:5], primary_by_label, element_label_aliases)

    result = {
        "matched": matched,
        "reason": reason,
        "cypher": cypher,
        "oracle_sqlpgq": oracle_sqlpgq,
        "oracle_status": oracle_status,
        "neo4j_status": neo4j_status,
        "oracle_error": oracle_error,
        "neo4j_error": neo4j_error,
        "oracle_rows_sample": sample(oracle_rows),
        "neo4j_rows_sample": sample(neo4j_rows),
    }

    ordering_applied = execution_queries and execution_queries.applied
    if ordering_applied:
        result["deterministic_ordering"] = {
            "reason": execution_queries.reason,
            "oracle_sqlpgq": execution_queries.oracle_sqlpgq,
            "cypher": execution_queries.cypher,
        }

    if reason == "result_mismatch" and not matched:
        result["result_diagnostics"] = result_diagnostics(
            oracle_rows,
            neo4j_rows,
            primary_by_label,
            element_label_aliases,
        )
    return result
+
+
def load_enriched_records(unit: DatabaseUnit, output_root: Path) -> List[Dict[str, Any]]:
    """Load the enriched Oracle SQL/PGQ records of one database unit.

    Args:
        unit: Unit whose ``split`` and ``database`` locate the file.
        output_root: Root directory of the dataset preparation output.

    Returns:
        One parsed JSON object per non-blank line of
        ``<output_root>/<split>/<database>/oracle_sqlpgq_enriched.jsonl``.

    Raises:
        FileNotFoundError: If the enriched file does not exist.
    """
    path = output_root / unit.split / unit.database / "oracle_sqlpgq_enriched.jsonl"
    if not path.exists():
        raise FileNotFoundError(
            f"Missing enriched Oracle SQL/PGQ records for {unit.split}/{unit.database}: {path}"
        )
    # Iterate the file instead of using str.splitlines(): splitlines() also
    # breaks on U+2028/U+2029/U+0085, which may appear unescaped inside JSON
    # strings and would cut a record in half.
    with path.open(encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]
+
+
def select_records_for_range(
    records: Sequence[Dict[str, Any]],
    query_offset: int = 0,
    limit_queries: int = 0,
) -> List[Dict[str, Any]]:
    """Return the window of ``records`` selected by offset and limit.

    A negative offset is treated as 0. A limit of 0 or less means "no limit".
    """
    first = max(query_offset, 0)
    stop = first + limit_queries if limit_queries > 0 else None
    return list(records[first:stop])
+
+
def skip_reason_for_record(
    record: Dict[str, Any],
    valid_statuses: set[str],
    include_all_translatable: bool,
) -> str:
    """Explain why ``record`` should be skipped, or return ``""`` to keep it.

    Checks run in this order: translation category, presence of the Oracle
    query, presence of a Cypher source, and - unless
    ``include_all_translatable`` is set - the Oracle validation status.
    """
    if record.get("oracle_translation_category") != "Graph-IL Translatable":
        return "not_translatable"
    if not record.get("oracle_sqlpgq"):
        return "missing_oracle_sqlpgq"

    cypher_fields = ("oracle_source_query", "initial_cypher", "cypher")
    if not any(record.get(field) for field in cypher_fields):
        return "missing_cypher"

    if not include_all_translatable:
        status = record.get("oracle_validation_status")
        if status not in valid_statuses:
            return f"oracle_status:{status}"
    return ""
+
+
def stable_execution_queries(
    oracle_sqlpgq: str,
    cypher: str,
    primary_by_label: Dict[str, str] | None = None,
) -> StableExecutionQueries:
    """Return the query pair to execute, with deterministic ordering if possible.

    Two rewrites are attempted in order: one targeting paging inside a ``WITH``
    stage, then one targeting the final ``RETURN``. A rewrite is used only when
    both the Cypher and the Oracle side could be rewritten; otherwise the
    original pair is returned unchanged.
    """
    rewrite_strategies = (
        (
            lambda: _stable_cypher_stage_paging_query(cypher, primary_by_label),
            _stable_oracle_stage_paging_query,
        ),
        (
            lambda: _stable_cypher_paging_query(cypher),
            _stable_oracle_paging_query,
        ),
    )
    for rewrite_cypher, rewrite_oracle in rewrite_strategies:
        cypher_query = rewrite_cypher()
        if cypher_query is None:
            continue
        oracle_query = rewrite_oracle(
            oracle_sqlpgq,
            cypher_query.projected_column_count,
            has_existing_order=cypher_query.had_order_by,
        )
        if oracle_query is None:
            continue
        return StableExecutionQueries(
            oracle_sqlpgq=oracle_query,
            cypher=cypher_query.query,
            applied=True,
            reason=cypher_query.reason,
        )
    return StableExecutionQueries(oracle_sqlpgq, cypher)
+
+
@dataclass(frozen=True)
class StableCypherQuery:
    """A Cypher query rewritten for deterministic paging."""

    # Rewritten Cypher text.
    query: str
    # Number of projected items the ordering was derived from.
    projected_column_count: int
    # Whether the paged clause already carried an ORDER BY.
    had_order_by: bool
    # Short tag naming the kind of rewrite that was applied.
    reason: str
    # Which clause was rewritten; the stage rewrite sets "with".
    scope: str = "final"
+
+
@dataclass(frozen=True)
class FinalCypherPaging:
    """Character offsets describing the paging of the last RETURN clause."""

    # Span of the RETURN keyword itself.
    return_start: int
    return_end: int
    # Span of the projection list that follows RETURN.
    body_start: int
    body_end: int
    # Offset of the first SKIP/LIMIT after RETURN.
    pagination_start: int
    has_order_by: bool
    # Span of the ORDER BY expressions; -1 when there is no ORDER BY.
    order_body_start: int = -1
    order_body_end: int = -1
    has_limit: bool = False
    has_skip: bool = False
+
+
@dataclass(frozen=True)
class StageCypherPaging:
    """Character offsets describing the paging of a WITH stage."""

    # Span of the WITH keyword itself.
    with_start: int
    with_end: int
    # Span of the projection list that follows WITH.
    body_start: int
    body_end: int
    # Offset of the first SKIP/LIMIT inside the stage.
    pagination_start: int
    # Offset at which the stage ends.
    stage_end: int
    has_order_by: bool
    # Span of the ORDER BY expressions; -1 when there is no ORDER BY.
    order_body_start: int = -1
    order_body_end: int = -1
    has_limit: bool = False
    has_skip: bool = False
+
+
def _stable_cypher_stage_paging_query(
    query: str,
    primary_by_label: Dict[str, str] | None = None,
) -> StableCypherQuery | None:
    """Add a deterministic ORDER BY to the last paged ``WITH`` stage.

    Returns ``None`` when there is no paged WITH stage, when its items cannot
    be turned into order terms, or when an existing ORDER BY already covers
    every term. Otherwise the missing terms are appended to the existing
    ORDER BY, or a new ORDER BY is inserted before SKIP/LIMIT.
    """
    stage = _last_cypher_with_paging(query)
    if stage is None:
        return None

    before_stage = query[: stage.with_start]
    with_items = _parse_cypher_return_items(
        query[stage.body_start : stage.body_end].strip(),
        before_stage,
        _graph_variable_stable_order_terms(before_stage, primary_by_label),
    )
    if not with_items:
        return None
    order_terms = [item.order_term for item in with_items]

    if stage.has_order_by:
        existing_terms = _order_by_expressions_from_body(
            query[stage.order_body_start : stage.order_body_end]
        )
        terms = _missing_order_terms(existing_terms, order_terms)
        if not terms:
            return None
        cut, glue, reason = stage.order_body_end, ", ", "with_ordered_paging_tiebreaker"
    else:
        terms = order_terms
        cut, glue, reason = stage.pagination_start, " ORDER BY ", "with_unordered_paging"

    rewritten = query[:cut].rstrip() + glue + ", ".join(terms) + " " + query[cut:].lstrip()
    return StableCypherQuery(
        query=rewritten,
        projected_column_count=len(with_items),
        had_order_by=stage.has_order_by,
        reason=reason,
        scope="with",
    )
+
+
def _stable_cypher_paging_query(query: str) -> StableCypherQuery | None:
    """Add a deterministic ORDER BY to the paged final ``RETURN``.

    Returns ``None`` when the final RETURN is not paged, when its items cannot
    be turned into order terms, or when an existing ORDER BY already covers
    every term. Otherwise the missing terms are appended to the existing
    ORDER BY, or a new ORDER BY is inserted before SKIP/LIMIT.
    """
    final = _final_cypher_paging(query)
    if final is None:
        return None

    return_items = _parse_cypher_return_items(
        query[final.body_start : final.body_end].strip(),
        query[: final.return_start],
    )
    if not return_items:
        return None
    order_terms = [item.order_term for item in return_items]

    if final.has_order_by:
        existing_terms = _order_by_expressions_from_body(
            query[final.order_body_start : final.order_body_end]
        )
        terms = _missing_order_terms(existing_terms, order_terms)
        if not terms:
            return None
        cut, glue, reason = final.order_body_end, ", ", "ordered_paging_tiebreaker"
    else:
        terms = order_terms
        cut, glue, reason = final.pagination_start, " ORDER BY ", "unordered_paging"

    rewritten = query[:cut].rstrip() + glue + ", ".join(terms) + " " + query[cut:].lstrip()
    return StableCypherQuery(
        query=rewritten,
        projected_column_count=len(return_items),
        had_order_by=final.has_order_by,
        reason=reason,
    )
+
+
def _stable_oracle_paging_query(
    query: str,
    projected_column_count: int,
    has_existing_order: bool,
) -> str | None:
    """Order the trailing pagination of an Oracle query by column position.

    The order terms are the positions ``1..projected_column_count``. Returns
    ``None`` when there are no columns, the query contains ``UNION``, no
    trailing pagination is found, or an expected ORDER BY cannot be located.
    """
    if projected_column_count < 1:
        return None

    # Trailing whitespace and semicolons are not part of the rewritten text.
    statement = query.rstrip().rstrip(";")
    masked = _mask_string_literals(statement)
    if re.search(r"\bUNION\b", masked, flags=re.IGNORECASE):
        return None
    pagination_span = _trailing_sql_pagination_span(masked)
    if pagination_span is None:
        return None

    positions = ", ".join(str(position) for position in range(1, projected_column_count + 1))
    pagination_start = pagination_span[0]
    if not has_existing_order:
        head = statement[:pagination_start].rstrip()
        tail = statement[pagination_start:].lstrip()
        return head + "\nORDER BY " + positions + "\n" + tail

    order_span = _final_top_level_sql_order_body_span(masked, pagination_start)
    if order_span is None:
        return None
    return _append_sql_order_terms(statement, order_span[1], positions)
+
+
def _stable_oracle_stage_paging_query(
    query: str,
    projected_column_count: int,
    has_existing_order: bool,
) -> str | None:
    """Order the first non-trailing ``FETCH FIRST n ROWS ONLY`` by column position.

    The order terms are the positions ``1..projected_column_count``. Returns
    ``None`` when there are no columns, when every FETCH clause is the trailing
    one, or when an expected ORDER BY cannot be located.
    """
    if projected_column_count < 1:
        return None
    stripped = query.rstrip().rstrip(";")
    masked = _mask_string_literals(stripped)
    # End offset of the query text ignoring trailing whitespace only. Leading
    # whitespace must not be stripped here: match offsets are relative to
    # ``masked``, so shortening the length from the left would make a trailing
    # FETCH look like an inner one.
    query_end = len(masked.rstrip())
    fetch_match = next(
        (
            match
            for match in re.finditer(
                r"\bFETCH\s+FIRST\s+\d+\s+ROWS\s+ONLY\b",
                masked,
                flags=re.IGNORECASE,
            )
            if match.end() != query_end
        ),
        None,
    )
    if fetch_match is None:
        return None
    order_terms = ", ".join(str(index) for index in range(1, projected_column_count + 1))
    if has_existing_order:
        order_span = _nearest_sql_order_body_span(masked, fetch_match.start())
        if order_span is None:
            return None
        _, order_body_end = order_span
        return _append_sql_order_terms(stripped, order_body_end, order_terms)
    return (
        stripped[: fetch_match.start()].rstrip()
        + "\nORDER BY "
        + order_terms
        + "\n"
        + stripped[fetch_match.start() :].lstrip()
    )
+
+
def _append_sql_order_terms(query: str, order_body_end: int, order_terms: str) -> str:
    """Append ``order_terms`` to the ORDER BY list that ends at ``order_body_end``.

    The remainder of the query is placed on a new line when it starts with
    ``FETCH `` or ``OFFSET `` and after a single space otherwise.
    """
    head = query[:order_body_end].rstrip()
    tail = query[order_body_end:].lstrip()
    if tail.upper().startswith(("FETCH ", "OFFSET ")):
        joiner = "\n"
    else:
        joiner = " "
    return f"{head}, {order_terms}{joiner}{tail}"
+
+
def _parse_cypher_return_items(
    return_body: str,
    query_before_return: str,
    graph_variable_order_terms: Dict[str, str] | None = None,
) -> List[CypherReturnItem]:
    """Split a RETURN/WITH projection into items with usable order terms.

    A leading ``DISTINCT`` is ignored. For each item the order term is, in
    order of preference: the entry of ``graph_variable_order_terms`` for the
    bare expression, the alias, or the expression itself. An empty list is
    returned as soon as any item is not safe to order by or yields no term.
    """
    body = re.sub(r"^\s*DISTINCT\b", "", return_body, flags=re.IGNORECASE).strip()
    if not body:
        return []

    graph_variables = _graph_variables(query_before_return)
    known_terms = graph_variable_order_terms or {}
    parsed: List[CypherReturnItem] = []
    for raw_item in _split_top_level_commas(body):
        expression, alias = _split_cypher_alias(raw_item)
        bare_expression = expression.strip()
        order_term = known_terms.get(bare_expression)
        if not order_term:
            if not _is_safe_stable_order_expression(expression, graph_variables):
                return []
            order_term = _cypher_identifier(alias) if alias else bare_expression
        if not order_term or order_term == "*":
            return []
        parsed.append(
            CypherReturnItem(
                expression=bare_expression,
                alias=alias,
                order_term=order_term,
            )
        )
    return parsed
+
+
def _split_cypher_alias(item: str) -> tuple[str, str]:
    """Split ``expr AS alias`` into ``(expr, alias)``.

    The last ``AS`` found outside string literals is used; the alias is
    unquoted. Without an ``AS`` the alias is the empty string.
    """
    last_as = None
    for last_as in re.finditer(r"\s+AS\s+", _mask_string_literals(item), flags=re.IGNORECASE):
        pass
    if last_as is None:
        return item.strip(), ""
    expression = item[: last_as.start()].strip()
    alias = _unquote_cypher_identifier(item[last_as.end() :].strip())
    return expression, alias
+
+
def _is_safe_stable_order_expression(expression: str, graph_variables: set[str]) -> bool:
    """Tell whether ``expression`` may be used as an ORDER BY term.

    Rejected are: empty text, ``*``, a bare graph variable, and anything that
    calls ``collect``, ``labels``, ``nodes``, ``properties`` or
    ``relationships``.
    """
    candidate = expression.strip()
    if candidate in ("", "*") or candidate in graph_variables:
        return False
    calls_unorderable_function = re.search(
        r"\b(?:collect|labels|nodes|properties|relationships)\s*\(",
        candidate,
        flags=re.IGNORECASE,
    )
    return calls_unorderable_function is None
+
+
def _graph_variables(query: str) -> set[str]:
    """Collect the names of node, relationship and path variables in ``query``.

    String literals are masked first. Three shapes are recognised:

    * ``(var`` followed by ``:``, ``{``, ``)`` or ``WHERE``;
    * ``[var`` followed by ``:``, ``*``, ``]`` or ``{``;
    * ``MATCH var =`` / ``OPTIONAL MATCH var =``.
    """
    masked = _mask_string_literals(query)
    variable_patterns = (
        # Node variables.
        r"\(\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)"
        r"(?=\s*(?::|\{|\)|WHERE\b))",
        # Relationship variables.
        r"\[\s*(?P<variable>[A-Za-z_][A-Za-z0-9_]*)"
        r"(?=\s*(?::|\*|\]|\{))",
        # Path variables.
        r"\b(?:MATCH|OPTIONAL\s+MATCH)\s+(?P<variable>[A-Za-z_][A-Za-z0-9_]*)\s*=",
    )
    variables: set[str] = set()
    for pattern in variable_patterns:
        variables.update(
            match.group("variable")
            for match in re.finditer(pattern, masked, flags=re.IGNORECASE)
        )
    return variables
+
+
+def _graph_variable_stable_order_terms(
+ query: str,
+ primary_by_label: Dict[str, str] | None,
+) -> Dict[str, str]:
+ if not primary_by_label:
+ return {}
+ masked = _mask_string_literals(query)
+ order_terms: Dict[str, str] = {}
+ for match in re.finditer(
+ r"\(\s*(?P[A-Za-z_][A-Za-z0-9_]*)\s*:\s*"
+ r"`?(?P[^`(){},\s]+)`?",
+ masked,
+ flags=re.IGNORECASE,
+ ):
+ variable = match.group("variable")
+ label = match.group("label")
+ primary_key = (
+ primary_by_label.get(label)
+ or primary_by_label.get(OracleNameSanitizer.clean(label, fallback=label))
+ or primary_by_label.get(label.lower())
+ )
+ if primary_key:
+ order_terms[variable] = f"{variable}.{_cypher_identifier(primary_key)}"
+ return order_terms
+
+
def _final_cypher_paging(query: str) -> FinalCypherPaging | None:
    """Locate the paging clauses that follow the last ``RETURN``.

    String literals are masked before searching. Returns ``None`` when there
    is no RETURN or when neither ``SKIP n`` nor ``LIMIT n`` follows it. An
    ORDER BY counts only if it starts before the first SKIP/LIMIT. All offsets
    in the result are relative to the start of ``query``.
    """
    masked = _mask_string_literals(query)

    last_return = None
    for last_return in re.finditer(r"\bRETURN\b", masked, flags=re.IGNORECASE):
        pass
    if last_return is None:
        return None

    base = last_return.end()
    tail = masked[base:]
    order_match = re.search(r"\bORDER\s+BY\b", tail, flags=re.IGNORECASE)
    skip_match = re.search(r"\bSKIP\s+\d+\b", tail, flags=re.IGNORECASE)
    limit_match = re.search(r"\bLIMIT\s+\d+\b", tail, flags=re.IGNORECASE)

    paging_matches = [found for found in (skip_match, limit_match) if found is not None]
    if not paging_matches:
        return None
    pagination_start = base + min(found.start() for found in paging_matches)

    ordered = order_match is not None and base + order_match.start() < pagination_start
    if ordered:
        body_end = base + order_match.start()
        order_body_start, order_body_end = base + order_match.end(), pagination_start
    else:
        body_end = pagination_start
        order_body_start, order_body_end = -1, -1

    return FinalCypherPaging(
        return_start=last_return.start(),
        return_end=base,
        body_start=base,
        body_end=body_end,
        pagination_start=pagination_start,
        has_order_by=ordered,
        order_body_start=order_body_start,
        order_body_end=order_body_end,
        has_limit=limit_match is not None,
        has_skip=skip_match is not None,
    )
+
+
+ def _last_cypher_with_paging(query: str) -> StageCypherPaging | None:
+ """Locate the last WITH stage that carries numeric SKIP/LIMIT paging.
+
+ WITH keywords are visited from the last one backwards. A stage runs from the
+ end of its WITH keyword to the next clause keyword; a WITH with no following
+ clause keyword, or with no SKIP/LIMIT in its stage, is skipped. Returns None
+ when no stage qualifies. Offsets refer to the literal-masked copy of ``query``.
+ """
+ masked = _mask_string_literals(query)
+ with_matches = list(re.finditer(r"\bWITH\b", masked, flags=re.IGNORECASE))
+ for with_match in reversed(with_matches):
+ stage_end = _next_cypher_clause_start(masked, with_match.end())
+ if stage_end is None:
+ continue
+ stage_text = masked[with_match.end() : stage_end]
+ # These three searches yield offsets relative to the end of WITH.
+ order_match = re.search(r"\bORDER\s+BY\b", stage_text, flags=re.IGNORECASE)
+ skip_match = re.search(r"\bSKIP\s+\d+\b", stage_text, flags=re.IGNORECASE)
+ limit_match = re.search(r"\bLIMIT\s+\d+\b", stage_text, flags=re.IGNORECASE)
+ if skip_match is None and limit_match is None:
+ continue
+ pagination_offsets = [
+ match.start() for match in (skip_match, limit_match) if match is not None
+ ]
+ # Absolute offset of whichever of SKIP/LIMIT comes first.
+ pagination_start = with_match.end() + min(pagination_offsets)
+ # An ORDER BY counts only when it starts before the paging keywords.
+ has_order_by = bool(
+ order_match and with_match.end() + order_match.start() < pagination_start
+ )
+ if has_order_by and order_match is not None:
+ body_end = with_match.end() + order_match.start()
+ order_body_start = with_match.end() + order_match.end()
+ order_body_end = pagination_start
+ else:
+ # No usable ORDER BY: the ORDER BY offsets are reported as -1.
+ body_end = pagination_start
+ order_body_start = -1
+ order_body_end = -1
+ return StageCypherPaging(
+ with_start=with_match.start(),
+ with_end=with_match.end(),
+ body_start=with_match.end(),
+ body_end=body_end,
+ pagination_start=pagination_start,
+ stage_end=stage_end,
+ has_order_by=has_order_by,
+ order_body_start=order_body_start,
+ order_body_end=order_body_end,
+ has_limit=limit_match is not None,
+ has_skip=skip_match is not None,
+ )
+ return None
+
+
+ def _next_cypher_clause_start(masked_query: str, start: int) -> int | None:
+     """Return the absolute offset of the next clause keyword at or after ``start``.
+
+     Recognised keywords are MATCH, OPTIONAL MATCH, RETURN and WITH (any case).
+     Returns None when none of them occurs in ``masked_query[start:]``.
+     """
+     clause_pattern = r"\b(?:MATCH|OPTIONAL\s+MATCH|RETURN|WITH)\b"
+     tail = masked_query[start:]
+     found = re.search(clause_pattern, tail, flags=re.IGNORECASE)
+     if not found:
+         return None
+     # The match offset is relative to the slice, so shift it back.
+     return start + found.start()
+
+
+ def _order_by_expressions_from_body(order_body: str) -> List[str]:
+     """Split an ORDER BY body into its sort expressions, dropping ASC/DESC.
+
+     Items are separated at top-level commas; items that are empty once the
+     trailing direction keyword is removed are discarded.
+     """
+     without_direction = (
+         re.sub(r"\s+(?:ASC|DESC)\s*$", "", part.strip(), flags=re.IGNORECASE)
+         for part in _split_top_level_commas(order_body)
+     )
+     return [expression for expression in without_direction if expression]
+
+
+ def _missing_order_terms(existing_terms: Sequence[str], order_terms: Sequence[str]) -> List[str]:
+     """Return the entries of ``order_terms`` not already in ``existing_terms``.
+
+     Terms are compared after ``_normalize_order_term``; the returned items keep
+     their original spelling and order.
+     """
+     seen = set()
+     for term in existing_terms:
+         seen.add(_normalize_order_term(term))
+     missing = []
+     for candidate in order_terms:
+         if _normalize_order_term(candidate) not in seen:
+             missing.append(candidate)
+     return missing
+
+
+ def _normalize_order_term(term: str) -> str:
+     """Canonicalise a sort term: unquote, drop all whitespace, lower-case."""
+     unquoted = _unquote_cypher_identifier(term.strip())
+     compact = re.sub(r"\s+", "", unquoted)
+     return compact.lower()
+
+
+ def _trailing_sql_pagination_span(masked_sql: str) -> tuple[int, int] | None:
+     """Return the ``(start, end)`` span of a trailing OFFSET/FETCH clause, if any.
+
+     The combined ``OFFSET … FETCH FIRST …`` form is tried first, then a lone
+     ``FETCH FIRST``, then a lone ``OFFSET``; the first pattern that matches wins.
+     """
+     ordered_patterns = (
+         r"\s+OFFSET\s+\d+\s+ROWS\s+FETCH\s+FIRST\s+\d+\s+ROWS\s+ONLY\s*$",
+         r"\s+FETCH\s+FIRST\s+\d+\s+ROWS\s+ONLY\s*$",
+         r"\s+OFFSET\s+\d+\s+ROWS\s*$",
+     )
+     # Lazy generator: later patterns are searched only if earlier ones miss.
+     attempts = (
+         re.search(pattern, masked_sql, flags=re.IGNORECASE) for pattern in ordered_patterns
+     )
+     first_hit = next((hit for hit in attempts if hit), None)
+     if first_hit is None:
+         return None
+     return first_hit.span()
+
+
+ def _final_top_level_sql_order_body_span(
+     masked_sql: str,
+     search_end: int,
+ ) -> tuple[int, int] | None:
+     """Span from the last parenthesis-depth-0 ORDER BY to ``search_end``.
+
+     Only ``masked_sql[:search_end]`` is searched. Returns None when no ORDER BY
+     sits outside all parentheses.
+     """
+     top_level_hits = [
+         hit
+         for hit in re.finditer(r"\bORDER\s+BY\b", masked_sql[:search_end], flags=re.IGNORECASE)
+         if _paren_depth_at(masked_sql, hit.start()) == 0
+     ]
+     if not top_level_hits:
+         return None
+     return top_level_hits[-1].end(), search_end
+
+
+ def _nearest_sql_order_body_span(
+     masked_sql: str,
+     search_end: int,
+ ) -> tuple[int, int] | None:
+     """Span from the last ORDER BY at the same paren depth as ``search_end``.
+
+     Only ``masked_sql[:search_end]`` is searched. Returns None when no ORDER BY
+     is found at the parenthesis depth measured at ``search_end``.
+     """
+     wanted_depth = _paren_depth_at(masked_sql, search_end)
+     same_depth_hits = [
+         hit
+         for hit in re.finditer(r"\bORDER\s+BY\b", masked_sql[:search_end], flags=re.IGNORECASE)
+         if _paren_depth_at(masked_sql, hit.start()) == wanted_depth
+     ]
+     if not same_depth_hits:
+         return None
+     return same_depth_hits[-1].end(), search_end
+
+
+ def _paren_depth_at(value: str, position: int) -> int:
+     """Count the parentheses still open after scanning ``value[:position]``.
+
+     An unmatched ``)`` never drives the count below zero.
+     """
+     level = 0
+     for symbol in value[:position]:
+         if symbol == ")":
+             if level > 0:
+                 level -= 1
+         elif symbol == "(":
+             level += 1
+     return level
+
+
+ def is_nondeterministic_limit_without_order(query: str) -> bool:
+     """Report whether the final RETURN paging has a LIMIT but no ORDER BY."""
+     paging = _final_cypher_paging(query)
+     if not paging:
+         return False
+     return bool(paging.has_limit and not paging.has_order_by)
+
+
+ def is_nondeterministic_with_limit_without_order(query: str) -> bool:
+     """Report whether the last paged WITH stage has a LIMIT but no ORDER BY."""
+     paging = _last_cypher_with_paging(query)
+     if not paging:
+         return False
+     return bool(paging.has_limit and not paging.has_order_by)
+
+
+ def is_order_by_limit_query(query: str) -> bool:
+     """Report whether the final RETURN paging has both ORDER BY and LIMIT."""
+     paging = _final_cypher_paging(query)
+     if not paging:
+         return False
+     return bool(paging.has_limit and paging.has_order_by)
+
+
+ def is_with_order_by_limit_query(query: str) -> bool:
+     """Report whether the last paged WITH stage has both ORDER BY and LIMIT."""
+     paging = _last_cypher_with_paging(query)
+     if not paging:
+         return False
+     return bool(paging.has_limit and paging.has_order_by)
+
+
+ def has_order_by_limit_boundary_tie(
+ cypher: str,
+ oracle_sqlpgq: str,
+ oracle_client: OracleDBClient,
+ neo4j_loader: DatasetNeo4jLoader,
+ args: argparse.Namespace,
+ element_label_aliases: Dict[str, str] | None = None,
+ ) -> bool | None:
+ """Check whether rows ``limit`` and ``limit + 1`` share the same sort key.
+
+ Both queries are re-run with their trailing limit raised by one, and the
+ row at the cut-off is compared with the row just past it.
+
+ Returns:
+ True if either engine's rows show a tie at the boundary, False if both
+ show none, and None when the check cannot be made: no trailing LIMIT
+ (or LIMIT < 1), no ORDER BY expressions, a limit that could not be
+ rewritten in either query, a failed execution, or an undecidable tie.
+ """
+ limit = _trailing_cypher_limit(cypher)
+ if limit is None or limit < 1:
+ return None
+ sort_expressions = _order_by_expressions(cypher)
+ if not sort_expressions:
+ return None
+ expanded_cypher = _replace_trailing_cypher_limit(cypher, limit + 1)
+ expanded_sql = _replace_trailing_sql_fetch(oracle_sqlpgq, limit + 1)
+ # An unchanged string means the trailing LIMIT / FETCH FIRST was not rewritten.
+ if expanded_cypher == cypher or expanded_sql == oracle_sqlpgq:
+ return None
+
+ oracle_result = oracle_client.execute_query(
+ expanded_sql,
+ call_timeout_ms=args.oracle_timeout_ms,
+ )
+ # The Neo4j query runs before the Oracle status is inspected.
+ neo4j_status, neo4j_rows, _ = neo4j_loader.execute(expanded_cypher, args.neo4j_timeout_s)
+ if query_status_name(oracle_result.status_code) not in {"success", "no_record"}:
+ return None
+ if neo4j_status != "success":
+ return None
+ # Any non-list Oracle payload is treated as an empty result.
+ oracle_rows = oracle_result.data if isinstance(oracle_result.data, list) else []
+ oracle_tie = _rows_have_boundary_tie(
+ oracle_rows,
+ limit,
+ sort_expressions,
+ element_label_aliases,
+ )
+ neo4j_tie = _rows_have_boundary_tie(
+ neo4j_rows,
+ limit,
+ sort_expressions,
+ element_label_aliases,
+ )
+ # One confirmed tie is enough; "no tie" needs agreement from both sides.
+ if oracle_tie is True or neo4j_tie is True:
+ return True
+ if oracle_tie is False and neo4j_tie is False:
+ return False
+ return None
+
+
+ def _rows_have_boundary_tie(
+     rows: Sequence[Dict[str, Any]],
+     limit: int,
+     sort_expressions: Sequence[str],
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> bool | None:
+     """Compare the sort keys of the last kept row and the first row past ``limit``.
+
+     Returns False when there is no row beyond ``limit``, None when either sort
+     key cannot be resolved, and otherwise whether the two keys are equal.
+     """
+     if limit >= len(rows):
+         return False
+     last_kept, first_dropped = rows[limit - 1], rows[limit]
+     kept_key = _sort_key_for_row(last_kept, sort_expressions, element_label_aliases)
+     dropped_key = _sort_key_for_row(first_dropped, sort_expressions, element_label_aliases)
+     if kept_key is None or dropped_key is None:
+         return None
+     return kept_key == dropped_key
+
+
+ def _sort_key_for_row(
+     row: Dict[str, Any],
+     sort_expressions: Sequence[str],
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> tuple[Any, ...] | None:
+     """Build the tuple of normalised values ``row`` holds for each sort expression.
+
+     Returns None as soon as one expression cannot be resolved to a column.
+     """
+     key_parts = []
+     for sort_expression in sort_expressions:
+         raw = _row_value_for_sort_expression(row, sort_expression)
+         if raw is _MISSING:
+             return None
+         normalised = _normalize_value(raw, element_label_aliases=element_label_aliases)
+         key_parts.append(normalised)
+     return tuple(key_parts)
+
+
+ # Unique sentinel for "no matching column"; distinct from None so a row value
+ # of None is still reported as found.
+ _MISSING = object()
+
+
+ def _row_value_for_sort_expression(row: Dict[str, Any], expression: str) -> Any:
+     """Fetch the value in ``row`` that corresponds to a sort expression.
+
+     Candidate column names are tried with an exact match first and then
+     case-insensitively. Returns ``_MISSING`` when nothing matches.
+     """
+     names = _sort_expression_candidates(expression)
+     # Pass 1: exact key match, in candidate order.
+     exact = next((name for name in names if name in row), None)
+     if exact is not None:
+         return row[exact]
+     # Pass 2: case-insensitive match; a later key wins a lower-case collision.
+     folded = {}
+     for original_key in row:
+         folded[str(original_key).lower()] = original_key
+     for name in names:
+         original_key = folded.get(name.lower())
+         if original_key is not None:
+             return row[original_key]
+     return _MISSING
+
+
+ def _sort_expression_candidates(expression: str) -> List[str]:
+     """List the column names a sort expression may appear under in a result row.
+
+     Candidates, in order: the expression with one pair of enclosing backticks
+     removed; for dotted expressions, the part after the last dot; the
+     sanitized name; and the sanitizer's alias. Empty strings and duplicates
+     are dropped while order is preserved.
+     """
+     expression = expression.strip()
+     # The group needs a name so the replacement can refer to it with \g<...>.
+     expression = re.sub(r"^`(?P<inner>.*)`$", r"\g<inner>", expression)
+     candidates = [expression]
+     if "." in expression:
+         candidates.append(expression.rsplit(".", 1)[1].strip("`"))
+     candidates.append(OracleNameSanitizer.clean(expression, fallback=expression))
+     candidates.append(OracleNameSanitizer.alias(expression))
+     # dict.fromkeys de-duplicates while keeping first-seen order.
+     return list(dict.fromkeys(candidate for candidate in candidates if candidate))
+
+
+ def _order_by_expressions(query: str) -> List[str]:
+     """Return the sort expressions of the first ORDER BY clause in ``query``.
+
+     String literals are collapsed first. The ORDER BY body runs up to the next
+     SKIP or LIMIT keyword, or to the end of the query. Returns an empty list
+     when the query has no ORDER BY.
+     """
+     masked = _strip_string_literals(query)
+     # The group must be named "body": it is read back by that name below.
+     match = re.search(
+         r"\bORDER\s+BY\s+(?P<body>.*?)(?:\bSKIP\b|\bLIMIT\b|$)",
+         masked,
+         flags=re.IGNORECASE | re.DOTALL,
+     )
+     if not match:
+         return []
+     # Same split-and-strip-direction step as the body helper, so reuse it.
+     return _order_by_expressions_from_body(match.group("body"))
+
+
+ def _split_top_level_commas(value: str) -> List[str]:
+ """Split ``value`` at commas outside quotes and outside (), [] and {}.
+
+ Each piece is stripped, and empty pieces are dropped from the result.
+ """
+ parts = []
+ start = 0
+ # Open-bracket counters, keyed by the opening character.
+ depths = {"(": 0, "[": 0, "{": 0}
+ # Active quote character, or "" when not inside a quoted span.
+ quote = ""
+ index = 0
+ while index < len(value):
+ char = value[index]
+ if quote:
+ # Backslash escapes are honoured only inside '...' and "...".
+ if char == "\\" and quote in {"'", '"'}:
+ index += 2
+ continue
+ if char == quote:
+ # A doubled quote character is an escaped quote, not a terminator.
+ if index + 1 < len(value) and value[index + 1] == quote:
+ index += 2
+ continue
+ quote = ""
+ index += 1
+ continue
+ if char in {"'", '"', "`"}:
+ quote = char
+ elif char == "(":
+ depths["("] += 1
+ elif char == ")":
+ # Counters are clamped at zero so a stray closer cannot go negative.
+ depths["("] = max(depths["("] - 1, 0)
+ elif char == "[":
+ depths["["] += 1
+ elif char == "]":
+ depths["["] = max(depths["["] - 1, 0)
+ elif char == "{":
+ depths["{"] += 1
+ elif char == "}":
+ depths["{"] = max(depths["{"] - 1, 0)
+ elif char == "," and not any(depths.values()):
+ parts.append(value[start:index].strip())
+ start = index + 1
+ index += 1
+ # Whatever follows the last top-level comma is the final piece.
+ parts.append(value[start:].strip())
+ return [part for part in parts if part]
+
+
+ def _trailing_cypher_limit(query: str) -> int | None:
+     """Return the integer of a ``LIMIT n`` that ends ``query``, or None.
+
+     Surrounding whitespace is ignored; a LIMIT followed by anything else
+     (for example SKIP) does not count.
+     """
+     # The group must be named "limit": it is read back by that name below.
+     match = re.search(r"\bLIMIT\s+(?P<limit>\d+)\s*$", query.strip(), flags=re.IGNORECASE)
+     return int(match.group("limit")) if match else None
+
+
+ def _replace_trailing_cypher_limit(query: str, limit: int) -> str:
+     """Return the stripped query with its trailing ``LIMIT n`` set to ``limit``.
+
+     A query that does not end in ``LIMIT n`` is returned stripped but otherwise
+     unchanged.
+     """
+     trimmed = query.strip()
+     replacement = f"LIMIT {limit}"
+     trailing_limit = re.compile(r"\bLIMIT\s+\d+\s*$", flags=re.IGNORECASE)
+     return trailing_limit.sub(replacement, trimmed, count=1)
+
+
+ def _replace_trailing_sql_fetch(query: str, limit: int) -> str:
+     """Return the stripped SQL with its trailing ``FETCH FIRST n ROWS ONLY`` set to ``limit``.
+
+     SQL that does not end in such a clause is returned stripped but otherwise
+     unchanged.
+     """
+     trimmed = query.strip()
+     replacement = f"FETCH FIRST {limit} ROWS ONLY"
+     trailing_fetch = re.compile(
+         r"\bFETCH\s+FIRST\s+\d+\s+ROWS\s+ONLY\s*$",
+         flags=re.IGNORECASE,
+     )
+     return trailing_fetch.sub(replacement, trimmed, count=1)
+
+
+ def _strip_string_literals(query: str) -> str:
+     """Replace every single- or double-quoted literal in ``query`` with ``''``.
+
+     A falsy ``query`` is treated as the empty string.
+     """
+     text = query if query else ""
+     literal_pattern = r"'(?:''|\\'|[^'])*'|\"(?:\\\"|[^\"])*\""
+     return re.sub(literal_pattern, "''", text)
+
+
+ def _mask_string_literals(query: str) -> str:
+ """Blank out the contents of quoted spans while keeping the string length.
+
+ Characters inside '...', "..." and `...` are overwritten with spaces, so
+ offsets found in the result are valid in ``query`` too. A falsy ``query``
+ yields "".
+ """
+ if not query:
+ return ""
+ chars = list(query)
+ index = 0
+ while index < len(chars):
+ char = chars[index]
+ if char not in {"'", '"', "`"}:
+ index += 1
+ continue
+ # Opening delimiter found: leave it in place and scan the quoted span.
+ quote = char
+ index += 1
+ while index < len(chars):
+ # Backslash escapes are honoured only inside '...' and "...".
+ if chars[index] == "\\" and quote in {"'", '"'}:
+ chars[index] = " "
+ if index + 1 < len(chars):
+ chars[index + 1] = " "
+ index += 2
+ continue
+ if chars[index] == quote:
+ # A doubled quote character is an escaped quote: blank both.
+ if index + 1 < len(chars) and chars[index + 1] == quote:
+ chars[index] = " "
+ chars[index + 1] = " "
+ index += 2
+ continue
+ # Closing delimiter: leave it in place and end the span.
+ index += 1
+ break
+ chars[index] = " "
+ index += 1
+ return "".join(chars)
+
+
+ def _rewrite_outside_string_literals(value: str, rewrite) -> str:
+ """Apply ``rewrite`` to the text between string literals only.
+
+ Spans quoted with ' or " are copied through verbatim; every non-empty
+ stretch between them is passed to ``rewrite`` and replaced by its result.
+ Backticks are not treated as quotes here.
+ """
+ parts = []
+ # Start of the pending not-yet-rewritten stretch outside any literal.
+ start = 0
+ index = 0
+ while index < len(value):
+ char = value[index]
+ if char not in ("'", '"'):
+ index += 1
+ continue
+ if start < index:
+ parts.append(rewrite(value[start:index]))
+ literal_start = index
+ quote = char
+ index += 1
+ while index < len(value):
+ # A backslash skips the following character, whichever quote is open.
+ if value[index] == "\\":
+ index += 2
+ continue
+ if value[index] == quote:
+ # A doubled quote character is an escaped quote, not a terminator.
+ if index + 1 < len(value) and value[index + 1] == quote:
+ index += 2
+ continue
+ index += 1
+ break
+ index += 1
+ # The literal, delimiters included, is emitted unchanged.
+ parts.append(value[literal_start:index])
+ start = index
+ if start < len(value):
+ parts.append(rewrite(value[start:]))
+ return "".join(parts)
+
+
+ def _cypher_identifier(value: str) -> str:
+     """Return ``value`` as a Cypher identifier, backtick-quoting it when needed.
+
+     Names made only of ASCII letters, digits and underscores (and not starting
+     with a digit) are returned as-is; anything else is wrapped in backticks
+     with embedded backticks doubled.
+     """
+     is_plain = re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", value) is not None
+     if not is_plain:
+         escaped = value.replace("`", "``")
+         return "`" + escaped + "`"
+     return value
+
+
+ def _unquote_cypher_identifier(value: str) -> str:
+     """Strip whitespace and, if backtick-quoted, remove the quoting.
+
+     Doubled backticks inside a quoted identifier collapse to one backtick.
+     """
+     text = value.strip()
+     is_quoted = len(text) >= 2 and text[0] == "`" and text[-1] == "`"
+     if not is_quoted:
+         return text
+     return text[1:-1].replace("``", "`")
+
+
+ def normalized_counter(
+     rows: Sequence[Dict[str, Any]],
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> Counter[str]:
+     """Count normalised rows, keyed by their canonical JSON encoding."""
+     tally: Counter[str] = Counter()
+     for row in normalize_rows(rows, primary_by_label, element_label_aliases):
+         encoded = json.dumps(row, sort_keys=True, ensure_ascii=False)
+         tally[encoded] += 1
+     return tally
+
+
+ def normalized_rows_match_with_numeric_tolerance(
+     oracle_rows: Sequence[Dict[str, Any]],
+     neo4j_rows: Sequence[Dict[str, Any]],
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+     absolute_tolerance: float = 1e-4,
+     relative_tolerance: float = 1e-8,
+ ) -> bool:
+     """Check that both row sets match as multisets, allowing numeric tolerance.
+
+     Rows are normalised first. Each Oracle row is then paired with the first
+     still-unpaired Neo4j row equal to it within the given tolerances. Row
+     order is ignored; differing row counts fail immediately.
+     """
+     if len(oracle_rows) != len(neo4j_rows):
+         return False
+     oracle_normalized = normalize_rows(oracle_rows, primary_by_label, element_label_aliases)
+     remaining = list(normalize_rows(neo4j_rows, primary_by_label, element_label_aliases))
+     for oracle_row in oracle_normalized:
+         for position, candidate in enumerate(remaining):
+             if _values_equal_with_numeric_tolerance(
+                 oracle_row,
+                 candidate,
+                 absolute_tolerance,
+                 relative_tolerance,
+             ):
+                 # Consume the partner so it cannot be paired twice.
+                 del remaining[position]
+                 break
+         else:
+             # No unpaired Neo4j row matches this Oracle row.
+             return False
+     return not remaining
+
+
+ def _values_equal_with_numeric_tolerance(
+     left: Any,
+     right: Any,
+     absolute_tolerance: float,
+     relative_tolerance: float,
+ ) -> bool:
+     """Compare two normalised values, allowing a tolerance on numbers.
+
+     Booleans are compared exactly. Two numbers are equal when they are
+     identical or when their difference is within
+     ``max(absolute_tolerance, relative_tolerance * max(|left|, |right|, 1))``.
+     Lists are compared element-wise and dicts key-wise, recursively. Anything
+     else falls back to ``==``.
+     """
+     import math
+
+     if isinstance(left, bool) or isinstance(right, bool):
+         return left == right
+     if isinstance(left, (int, float)) and isinstance(right, (int, float)):
+         # Exact equality first: covers same-signed infinities (inf - inf is
+         # nan) and equal integers too large to convert to float.
+         if left == right:
+             return True
+         # Any remaining inf/nan would make the allowed tolerance infinite (or
+         # the delta nan), so non-finite values never match approximately.
+         if not (math.isfinite(left) and math.isfinite(right)):
+             return False
+         delta = abs(float(left) - float(right))
+         allowed = max(
+             absolute_tolerance, relative_tolerance * max(abs(float(left)), abs(float(right)), 1.0)
+         )
+         return delta <= allowed
+     if isinstance(left, list) and isinstance(right, list):
+         return len(left) == len(right) and all(
+             _values_equal_with_numeric_tolerance(
+                 left_item,
+                 right_item,
+                 absolute_tolerance,
+                 relative_tolerance,
+             )
+             for left_item, right_item in zip(left, right, strict=True)
+         )
+     if isinstance(left, dict) and isinstance(right, dict):
+         return left.keys() == right.keys() and all(
+             _values_equal_with_numeric_tolerance(
+                 left[key],
+                 right[key],
+                 absolute_tolerance,
+                 relative_tolerance,
+             )
+             for key in left
+         )
+     return left == right
+
+
+ def result_diagnostics(
+     oracle_rows: Sequence[Dict[str, Any]],
+     neo4j_rows: Sequence[Dict[str, Any]],
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+     sample_limit: int = 5,
+ ) -> Dict[str, Any]:
+     """Summarise how the Oracle and Neo4j result sets differ.
+
+     Reports raw and distinct row counts for both sides, how many normalised
+     rows appear on only one side (with multiplicity), and up to
+     ``sample_limit`` example rows for each direction.
+     """
+     oracle_counter = normalized_counter(oracle_rows, primary_by_label, element_label_aliases)
+     neo4j_counter = normalized_counter(neo4j_rows, primary_by_label, element_label_aliases)
+     # Counter subtraction keeps only positive surpluses.
+     only_in_oracle = oracle_counter - neo4j_counter
+     only_in_neo4j = neo4j_counter - oracle_counter
+     report: Dict[str, Any] = {}
+     report["oracle_row_count"] = len(oracle_rows)
+     report["neo4j_row_count"] = len(neo4j_rows)
+     report["oracle_distinct_row_count"] = len(oracle_counter)
+     report["neo4j_distinct_row_count"] = len(neo4j_counter)
+     report["missing_from_neo4j_count"] = sum(only_in_oracle.values())
+     report["extra_in_neo4j_count"] = sum(only_in_neo4j.values())
+     report["missing_from_neo4j_sample"] = _counter_rows_sample(only_in_oracle, sample_limit)
+     report["extra_in_neo4j_sample"] = _counter_rows_sample(only_in_neo4j, sample_limit)
+     return report
+
+
+ def _counter_rows_sample(counter: Counter[str], sample_limit: int) -> List[Any]:
+     """Decode up to ``sample_limit`` rows from a counter of JSON-encoded rows.
+
+     Each encoded row is repeated according to its count, in the counter's
+     iteration order. A non-positive ``sample_limit`` yields an empty list.
+     Every returned row is decoded separately, so repeated rows do not share
+     one mutable object.
+     """
+     from itertools import islice
+
+     if sample_limit <= 0:
+         return []
+     # Counter.elements() repeats each key by its count and skips counts < 1.
+     return [json.loads(encoded_row) for encoded_row in islice(counter.elements(), sample_limit)]
+
+
+ def normalize_rows(
+     rows: Sequence[Dict[str, Any]],
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> List[Any]:
+     """Normalise every row with ``normalize_row``, preserving row order."""
+     normalised: List[Any] = []
+     for row in rows:
+         normalised.append(normalize_row(row, primary_by_label, element_label_aliases))
+     return normalised
+
+
+ def normalize_row(
+     row: Dict[str, Any],
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> Any:
+     """Normalise one result row into a comparable, alias-free form.
+
+     Column names are discarded and only the values are kept, because Oracle
+     aliases often differ from Cypher aliases. A row whose single value looks
+     like a path is returned as that path's normalised element list.
+     """
+     values = list(row.values())
+     if len(values) == 1:
+         (only_value,) = values
+         if _looks_like_path(only_value):
+             return _normalize_path(only_value, primary_by_label, element_label_aliases)
+     normalised = []
+     for value in values:
+         normalised.append(_normalize_value(value, primary_by_label, element_label_aliases))
+     return normalised
+
+
+ def _normalize_value(
+ value: Any,
+ primary_by_label: Dict[str, str] | None = None,
+ element_label_aliases: Dict[str, str] | None = None,
+ ) -> Any:
+ """Convert one result value into a canonical, comparable form.
+
+ The checks below are order-sensitive: bool is handled before int, and
+ datetime before date, because each is a subclass of the latter. Values
+ that match none of the cases are returned unchanged.
+ """
+ if value is None:
+ return None
+ # Booleans become 1/0.
+ if isinstance(value, bool):
+ return 1 if value else 0
+ # Integral Decimals become int; others become float rounded to 6 places.
+ if isinstance(value, Decimal):
+ return int(value) if value == value.to_integral_value() else round(float(value), 6)
+ if isinstance(value, int):
+ return value
+ if isinstance(value, float):
+ if value.is_integer():
+ return int(value)
+ # Round to 6 places; a float that rounds to a whole number becomes int.
+ rounded = round(value, 6)
+ return int(rounded) if rounded.is_integer() else rounded
+ if isinstance(value, timedelta):
+ return value.total_seconds()
+ if isinstance(value, datetime):
+ # A datetime exactly at midnight is rendered as a bare date.
+ if value.hour == value.minute == value.second == value.microsecond == 0:
+ return value.date().isoformat()
+ return value.isoformat()
+ if isinstance(value, date):
+ return value.isoformat()
+ if isinstance(value, str):
+ return _normalize_temporal_string(value)
+ if _looks_like_path(value):
+ return _normalize_path(value, primary_by_label, element_label_aliases)
+ # Tuples are returned as lists.
+ if isinstance(value, (list, tuple)):
+ return [_normalize_value(item, primary_by_label, element_label_aliases) for item in value]
+ if isinstance(value, dict):
+ # Dicts carrying ELEM_TABLE/KEY_VALUE are reduced to a graph identity.
+ oracle_identity = _normalize_oracle_graph_identity(
+ value,
+ primary_by_label,
+ element_label_aliases,
+ )
+ if oracle_identity is not None:
+ return oracle_identity
+ return {
+ str(key): _normalize_value(item, primary_by_label, element_label_aliases)
+ for key, item in sorted(value.items())
+ }
+ # Duck-typed graph objects: items() + labels is handled as a node,
+ # items() + type as a relationship.
+ if hasattr(value, "items") and hasattr(value, "labels"):
+ return _normalize_neo4j_node(value, primary_by_label, element_label_aliases)
+ if hasattr(value, "items") and hasattr(value, "type"):
+ return _normalize_neo4j_relationship(
+ value,
+ primary_by_label,
+ element_label_aliases,
+ )
+ # Other objects exposing iso_format()/isoformat() are rendered through it.
+ if hasattr(value, "iso_format"):
+ return _normalize_temporal_string(str(value.iso_format()))
+ if hasattr(value, "isoformat"):
+ return _normalize_temporal_string(str(value.isoformat()))
+ return value
+
+
+ def _normalize_temporal_string(value: str) -> str:
+     """Canonicalise an ISO-like ``date time`` string; return other strings as-is.
+
+     For ``YYYY-MM-DD[T ]HH:MM:SS[.fraction][zone]`` input, the separator becomes
+     ``T``, trailing zeros of the fraction are dropped, and a ``Z`` / ``+00:00``
+     zone is removed. A midnight value left with no fraction and no zone
+     collapses to the bare date.
+     """
+     # All four groups must be named: they are read back by name below.
+     match = re.fullmatch(
+         r"(?P<date>\d{4}-\d{2}-\d{2})[T ]"
+         r"(?P<time>\d{2}:\d{2}:\d{2})"
+         r"(?:\.(?P<fraction>\d{1,9}))?"
+         r"(?P<zone>Z|[+-]\d{2}:\d{2})?",
+         value,
+     )
+     if match:
+         fraction = (match.group("fraction") or "").rstrip("0")
+         zone = match.group("zone") or ""
+         if zone in ("Z", "+00:00"):
+             zone = ""
+         if match.group("time") == "00:00:00" and not fraction and not zone:
+             return match.group("date")
+         base = f"{match.group('date')}T{match.group('time')}"
+         return f"{base}{'.' + fraction if fraction else ''}{zone}"
+     return value
+
+
+ def _normalize_oracle_graph_identity(
+     value: Dict[str, Any],
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> Dict[str, Any] | None:
+     """Reduce a dict holding ELEM_TABLE and KEY_VALUE to ``{"element", "key"}``.
+
+     Returns None when either of those two keys is absent. The ``"key"`` entry
+     is omitted when the normalised key value is falsy.
+     """
+     if not ("ELEM_TABLE" in value and "KEY_VALUE" in value):
+         return None
+     identity = {
+         "element": _canonical_element_label(str(value["ELEM_TABLE"]), element_label_aliases),
+     }
+     key = _normalize_value(value["KEY_VALUE"], primary_by_label, element_label_aliases)
+     if key:
+         identity["key"] = key
+     return identity
+
+
+ def _normalize_neo4j_node(
+     value: Any,
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> Dict[str, Any]:
+     """Reduce a node-like object to its element label plus key or properties.
+
+     The alphabetically first label is used. When ``_node_key`` finds a key
+     property the node is identified by that key alone; otherwise all of its
+     normalised properties are kept.
+     """
+     label_names = sorted(str(label) for label in value.labels)
+     first_label = label_names[0] if label_names else ""
+     label = _canonical_element_label(first_label, element_label_aliases)
+     properties = dict(value.items())
+     key = _node_key(label, properties, primary_by_label)
+     if not key:
+         return {
+             "element": label,
+             "properties": _normalize_value(properties, primary_by_label, element_label_aliases),
+         }
+     return {
+         "element": OracleNameSanitizer.clean(label, fallback=label),
+         "key": _normalize_value(key, primary_by_label, element_label_aliases),
+     }
+
+
+ def _normalize_neo4j_relationship(
+     value: Any,
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> Dict[str, Any]:
+     """Reduce a relationship-like object to its type plus key or properties.
+
+     A relationship with a non-None ``EDGE_ID`` property is identified by that
+     id alone; otherwise its normalised properties are attached when it has any.
+     """
+     result: Dict[str, Any] = {
+         "element": _canonical_element_label(str(value.type), element_label_aliases),
+     }
+     properties = dict(value.items())
+     edge_id = properties.get("EDGE_ID")
+     if edge_id is None:
+         if properties:
+             result["properties"] = _normalize_value(
+                 properties,
+                 primary_by_label,
+                 element_label_aliases,
+             )
+         return result
+     # NOTE(review): element_label_aliases is not forwarded here, unlike the
+     # other _normalize_value calls in this function — confirm this is intended.
+     result["key"] = {"EDGE_ID": _normalize_value(edge_id, primary_by_label)}
+     return result
+
+
+ def _node_key(
+     label: str,
+     properties: Dict[str, Any],
+     primary_by_label: Dict[str, str] | None = None,
+ ) -> Dict[str, Any]:
+     """Pick the property that identifies a node, as a one-entry dict.
+
+     Preference order: the configured primary key for the label (raw, then
+     sanitized), then ``_id``, ``vid``, ``id``, ``<label>_id``, then any
+     property whose name ends in ``_id`` (sorted). Returns ``{}`` when no
+     candidate is present in ``properties``.
+     """
+     names = []
+     if primary_by_label:
+         names.append(primary_by_label.get(label))
+         names.append(primary_by_label.get(OracleNameSanitizer.clean(label, fallback=label)))
+     names += ["_id", "vid", "id", f"{label}_id"]
+     names += sorted(name for name in properties if name.lower().endswith("_id"))
+     for name in names:
+         # Configured lookups may have produced None; skip falsy candidates.
+         if name and name in properties:
+             return {name: properties[name]}
+     return {}
+
+
+ def _looks_like_path(value: Any) -> bool:
+     """Duck-type check: does ``value`` expose ``nodes`` and ``relationships``?"""
+     return all(hasattr(value, attribute) for attribute in ("nodes", "relationships"))
+
+
+ def _normalize_path(
+     value: Any,
+     primary_by_label: Dict[str, str] | None = None,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> List[Any]:
+     """Flatten a path into ``[node, relationship, node, ...]`` of normalised items.
+
+     Each node is followed by the relationship at the same index, when one
+     exists.
+     """
+     nodes = list(value.nodes)
+     relationships = list(value.relationships)
+     flattened = []
+     for position, node in enumerate(nodes):
+         flattened.append(_normalize_value(node, primary_by_label, element_label_aliases))
+         if position >= len(relationships):
+             continue
+         flattened.append(
+             _normalize_value(relationships[position], primary_by_label, element_label_aliases)
+         )
+     return flattened
+
+
+ def _canonical_element_label(
+     label: str,
+     element_label_aliases: Dict[str, str] | None = None,
+ ) -> str:
+     """Sanitize ``label`` and translate it through the alias map, if one is given.
+
+     The sanitized label is looked up first, then the raw label; when neither
+     has an alias the sanitized label is returned.
+     """
+     cleaned = OracleNameSanitizer.clean(label, fallback=label)
+     if element_label_aliases:
+         raw_or_cleaned = element_label_aliases.get(label, cleaned)
+         return element_label_aliases.get(cleaned, raw_or_cleaned)
+     return cleaned
+
+
+ def oracle_element_label_aliases(loader: DatasetOracleLoader) -> Dict[str, str]:
+     """Map each sanitized graph label in the loader manifest to its source label.
+
+     Vertices are processed before edges, so on a label collision the edge
+     entry wins. An element without ``graph_label`` uses its ``label`` for both
+     sides of the mapping.
+     """
+     aliases: Dict[str, str] = {}
+     for section in ("vertices", "edges"):
+         for element in loader.manifest.get(section, []):
+             source_name = element["label"]
+             graph_label = OracleNameSanitizer.clean(
+                 element.get("graph_label", source_name),
+                 fallback=source_name,
+             )
+             aliases[graph_label] = OracleNameSanitizer.clean(source_name, fallback=source_name)
+     return aliases
+
+
+ def _convert_value(value: str, type_name: str) -> Any:
+     """Convert a raw string cell to the Python value for ``type_name``.
+
+     Args:
+         value: Raw text; the empty string maps to None.
+         type_name: Declared type, matched case-insensitively.
+
+     Returns:
+         int for integer types, float for floating types, bool for boolean
+         types, date / datetime for temporal types (or the stripped string when
+         it cannot be parsed), and ``value`` unchanged for any other type.
+
+     Raises:
+         ValueError: if a numeric type is declared but the text is not numeric.
+     """
+     if value == "":
+         return None
+     normalized = type_name.upper()
+     if normalized in ("INT8", "INT16", "INT32", "INT64", "INTEGER"):
+         # Parse as int first: going through float loses precision above 2**53,
+         # which corrupts large INT64 values. Fall back to float only for
+         # spellings such as "12.0" or "1e3".
+         try:
+             return int(value)
+         except ValueError:
+             return int(float(value))
+     if normalized in ("FLOAT", "DOUBLE", "FLOAT32", "FLOAT64"):
+         return float(value)
+     if normalized in ("BOOL", "BOOLEAN"):
+         return value.strip().lower() in ("true", "1", "yes", "y")
+     if normalized == "DATE":
+         parsed = _parse_datetime(value)
+         return parsed.date() if isinstance(parsed, datetime) else parsed
+     if normalized in ("DATETIME", "TIMESTAMP"):
+         parsed = _parse_datetime(value)
+         # Promote a bare date to midnight so the result is always a datetime.
+         if isinstance(parsed, date) and not isinstance(parsed, datetime):
+             return datetime.combine(parsed, datetime.min.time())
+         return parsed
+     return value
+
+
+ def _parse_datetime(value: str) -> date | datetime | str:
+     """Parse a date/time string, returning the stripped text if nothing fits.
+
+     ISO-8601 parsing is tried first (with ``Z`` rewritten to ``+00:00``), then
+     a fixed list of ``strptime`` layouts. Date-only layouts yield a ``date``;
+     everything else yields a ``datetime``.
+     """
+     text = value.strip()
+     try:
+         return datetime.fromisoformat(text.replace("Z", "+00:00"))
+     except ValueError:
+         pass
+     # (layout, date_only) pairs, tried in order.
+     fallback_layouts = (
+         ("%Y-%m-%d %H:%M:%S", False),
+         ("%Y-%m-%d", True),
+         ("%m/%d/%Y %H:%M:%S", False),
+         ("%m/%d/%Y", True),
+         ("%Y/%m/%d %H:%M:%S", False),
+         ("%Y/%m/%d", True),
+     )
+     for layout, date_only in fallback_layouts:
+         try:
+             parsed = datetime.strptime(text, layout)
+         except ValueError:
+             continue
+         return parsed.date() if date_only else parsed
+     return text
+
+
+ def graph_name_for(unit: DatabaseUnit, prefix: str) -> str:
+     """Build the sanitized graph name ``<prefix>_<split>_<database>`` for ``unit``."""
+     raw_name = "{}_{}_{}".format(prefix, unit.split, unit.database)
+     return OracleNameSanitizer.clean(raw_name, fallback="T2GQL_GRAPH")
+
+
+ def query_status_name(status: QueryStatus) -> str:
+     """Return the snake_case name of a known status, else ``str(status)``."""
+     known_names = (
+         (QueryStatus.SUCCESS, "success"),
+         (QueryStatus.NO_RECORD, "no_record"),
+         (QueryStatus.CLIENT_ERROR, "client_error"),
+         (QueryStatus.SERVER_ERROR, "server_error"),
+     )
+     for member, name in known_names:
+         if status == member:
+             return name
+     return str(status)
+
+
+ def increment(mapping: Dict[str, int], key: str) -> None:
+     """Add one to ``mapping[key]`` in place, starting from 0 for a new key."""
+     current = mapping.get(key, 0)
+     mapping[key] = current + 1
+
+
+ def merge_compare_summaries(summaries: Iterable[Dict[str, Any]]) -> Dict[str, Any]:
+     """Aggregate per-database comparison summaries into one overall summary.
+
+     Totals the considered/matched/failed/skipped counters, sums the per-reason
+     skip and failure counts, counts the summaries seen as ``databases``, and
+     keeps every input summary under ``units`` in input order.
+     """
+     counter_fields = ("considered", "matched", "failed", "skipped")
+     reason_fields = ("skip_reasons", "failure_reasons")
+     merged: Dict[str, Any] = {
+         "databases": 0,
+         "considered": 0,
+         "matched": 0,
+         "failed": 0,
+         "skipped": 0,
+         "skip_reasons": {},
+         "failure_reasons": {},
+         "units": [],
+     }
+     for summary in summaries:
+         merged["databases"] += 1
+         for field in counter_fields:
+             merged[field] += int(summary.get(field, 0))
+         for field in reason_fields:
+             bucket = merged[field]
+             for reason, count in summary.get(field, {}).items():
+                 bucket[reason] = bucket.get(reason, 0) + int(count)
+         merged["units"].append(summary)
+     return merged
+
+
+ def _escape_backticks(value: str) -> str:
+     """Double every backtick in ``value``."""
+     pieces = value.split("`")
+     return "``".join(pieces)
+
+
+ def _safe_identifier(value: str) -> str:
+     """Replace non-alphanumeric characters with ``_`` and cap at 120 characters.
+
+     An empty result falls back to ``"constraint_name"``.
+     """
+     sanitized = "".join(map(lambda char: char if char.isalnum() else "_", value))
+     truncated = sanitized[:120]
+     if truncated:
+         return truncated
+     return "constraint_name"
+
+
+ # Call main() only when this file is executed as a script, not when imported.
+ if __name__ == "__main__":
+ main()
diff --git a/dataset_prep/cypher_schema.py b/dataset_prep/cypher_schema.py
new file mode 100644
index 0000000..ed27be7
--- /dev/null
+++ b/dataset_prep/cypher_schema.py
@@ -0,0 +1,724 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+import json
+from pathlib import Path
+import re
+from typing import Any, Iterable
+
+from app.impl.oracle_sqlpgq.utils.sqlpgq import OracleNameSanitizer
+
+# Names that can appear on the left of a ``name.member`` reference without
+# being graph variables; property validation skips references whose
+# lower-cased left-hand side is in this set.
+IGNORED_PROPERTY_REFERENCE_NAMESPACES = {
+ "apoc",
+ "date",
+ "datetime",
+ "duration",
+ "localdatetime",
+ "localtime",
+ "time",
+}
+
+
+@dataclass(frozen=True)
+class CypherSchemaIssue:
+ """A single finding produced while validating a Cypher query."""
+
+ # Short category tag such as "invalid_schema_label".
+ signature: str
+ # Human-readable description of the finding.
+ message: str
+
+
+class CypherSchema:
+    def __init__(self, config: dict[str, Any]):
+        """Index a parsed import configuration for schema lookups.
+
+        Args:
+            config: Configuration dictionary; its ``schema`` entry (when
+                present) is a list of items typed ``"VERTEX"`` or ``"EDGE"``.
+        """
+
+        def named_properties(item: dict[str, Any]) -> list[dict[str, Any]]:
+            # Properties without a truthy name are ignored everywhere below.
+            return [prop for prop in item.get("properties", []) if prop.get("name")]
+
+        def property_names(item: dict[str, Any]) -> set[str]:
+            return {prop.get("name") for prop in named_properties(item)}
+
+        def endpoint_pairs(item: dict[str, Any]) -> set[tuple[Any, Any]]:
+            pairs = set()
+            for constraint in item.get("constraints", []):
+                if isinstance(constraint, list) and len(constraint) == 2:
+                    pairs.add((constraint[0], constraint[1]))
+            return pairs
+
+        self.config = config
+        self.schema = list(config.get("schema") or [])
+        self.vertices = [entry for entry in self.schema if entry.get("type") == "VERTEX"]
+        self.edges = [entry for entry in self.schema if entry.get("type") == "EDGE"]
+
+        self.node_props = {
+            vertex.get("label"): property_names(vertex) for vertex in self.vertices
+        }
+        self.edge_props = {edge.get("label"): property_names(edge) for edge in self.edges}
+        # Property types default to STRING when the schema omits them.
+        self.property_types_by_label = {
+            entry.get("label"): {
+                prop.get("name"): prop.get("type", "STRING")
+                for prop in named_properties(entry)
+            }
+            for entry in self.schema
+        }
+        self.edge_constraints = {
+            edge.get("label"): endpoint_pairs(edge) for edge in self.edges
+        }
+        self.node_primary = {
+            vertex.get("label"): vertex.get("primary", "_id") for vertex in self.vertices
+        }
+
+        self.node_label_aliases = self._schema_name_aliases(self.node_props)
+        self.edge_label_aliases = self._schema_name_aliases(self.edge_props)
+        self.property_aliases_by_label = {}
+        for label, properties in self.property_types_by_label.items():
+            self.property_aliases_by_label[label] = self._schema_name_aliases(properties)
+        self.global_property_aliases = self._global_property_aliases()
+
+    @classmethod
+    def from_path(cls, path: Path) -> CypherSchema:
+        """Build a schema from a UTF-8 encoded JSON configuration file."""
+        raw_text = path.read_text(encoding="utf-8")
+        config = json.loads(raw_text)
+        return cls(config)
+
+ def validation_issues(self, query: str) -> list[CypherSchemaIssue]:
+ """Collect schema problems found in a Cypher query.
+
+ Checks node and edge labels, endpoint constraints of directed edges and
+ property references, then appends the findings of
+ ``unsafe_numeric_issues``. The combined list is passed through
+ ``_dedupe_issues`` before being returned.
+ """
+ # Falsy input (e.g. None) is treated as an empty query.
+ query = str(query or "")
+ issues: list[CypherSchemaIssue] = []
+ node_variables, edge_variables = cypher_variable_labels(query)
+ declared_variables = cypher_graph_variables(query)
+ # 1) Every labelled variable must use a label known to the schema.
+ for variable, label in node_variables.items():
+ if self.canonical_node_label(label) not in self.node_props:
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_label",
+ f'Node label "{label}" for variable "{variable}" is not in schema.',
+ )
+ )
+ for variable, label in edge_variables.items():
+ if self.canonical_edge_label(label) not in self.edge_props:
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_label",
+ f'Edge label "{label}" for variable "{variable}" is not in schema.',
+ )
+ )
+ # 2) Directed edges must connect an endpoint pair listed in the constraints.
+ for left_label, direction, edge_labels, right_label in cypher_edge_triples(query):
+ # Undirected patterns are not checked.
+ if direction == "undirected":
+ continue
+ left = self.canonical_node_label(left_label)
+ right = self.canonical_node_label(right_label)
+ for edge_label in edge_labels:
+ edge = self.canonical_edge_label(edge_label)
+ constraints = self.edge_constraints.get(edge)
+ # Nothing to check without constraints or with an empty endpoint label.
+ if not constraints or not left or not right:
+ continue
+ # For a "left" arrow the source is the right-hand node.
+ expected = (left, right) if direction == "right" else (right, left)
+ if expected not in constraints:
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_direction",
+ (
+ f'Edge "{edge_label}" does not allow {left_label} '
+ f"{direction} {right_label}."
+ ),
+ )
+ )
+ break
+ # 3) Every ``variable.property`` reference must resolve against the schema.
+ for variable, property_name in cypher_property_references(query):
+ if variable.lower() in IGNORED_PROPERTY_REFERENCE_NAMESPACES:
+ continue
+ if variable in node_variables:
+ label = self.canonical_node_label(node_variables[variable])
+ if not self._valid_node_property(query, variable, label, property_name):
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_property",
+ f'Property "{property_name}" is not valid for node "{variable}".',
+ )
+ )
+ elif variable in edge_variables:
+ label = self.canonical_edge_label(edge_variables[variable])
+ if not self._valid_edge_property(label, property_name):
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_property",
+ f'Property "{property_name}" is not valid for edge "{variable}".',
+ )
+ )
+ # From here on the variable carries no label in this query.
+ elif property_name.lower() in {"identity", "id"}:
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_property",
+ f'Cannot resolve pseudo-property "{property_name}" for "{variable}".',
+ )
+ )
+ elif variable in declared_variables and not self._property_has_unique_schema_owner(
+ property_name
+ ):
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_property",
+ f'Cannot resolve property "{property_name}" for unlabeled "{variable}".',
+ )
+ )
+ elif not self._property_known_anywhere(property_name):
+ issues.append(
+ CypherSchemaIssue(
+ "invalid_schema_property",
+ f'Cannot resolve property "{property_name}" for "{variable}".',
+ )
+ )
+ issues.extend(self.unsafe_numeric_issues(query))
+ return _dedupe_issues(issues)
+
+    def unsafe_numeric_issues(self, query: str) -> list[CypherSchemaIssue]:
+        """Report numeric or temporal operations the schema types make unsafe.
+
+        All patterns run on the query with string literals masked. Flagged are:
+        ``toInteger``/``toFloat`` over, ``AVG``/``SUM``/``MIN``/``MAX`` over, and
+        ``+ - * / %`` arithmetic on a string-typed property whose name
+        ``_looks_unsafe_numeric_text_property``; ordering comparisons between
+        a temporal and a numeric property; and ``+``/``-`` between temporal
+        aggregates or their aliases.
+
+        Args:
+            query: Cypher text to inspect.
+
+        Returns:
+            Issues in discovery order; duplicates are not removed here.
+        """
+        issues: list[CypherSchemaIssue] = []
+        node_variables, edge_variables = cypher_variable_labels(query)
+        variables = {**node_variables, **edge_variables}
+        # Mask once; every pattern below scans the same protected text.
+        protected = mask_string_literals(query)
+
+        conversion = re.compile(
+            r"\bto(?:Integer|Float)\s*\(\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)\."
+            r"(?P<prop>[A-Za-z_][A-Za-z0-9_$#-]*)\s*\)",
+            flags=re.IGNORECASE,
+        )
+        for match in conversion.finditer(protected):
+            variable = match.group("var")
+            property_name = self.canonical_property_name(variable, match.group("prop"), variables)
+            property_type = self.property_type(variable, property_name, variables)
+            if self._is_unsafe_numeric_text_property(property_name, property_type):
+                issues.append(
+                    CypherSchemaIssue(
+                        "unsafe_numeric_conversion",
+                        f'Unsafe numeric conversion for "{variable}.{property_name}".',
+                    )
+                )
+
+        # Template for a ``variable.property`` reference with prefixed group names.
+        property_ref = (
+            r"(?P<{prefix}_var>[A-Za-z_][A-Za-z0-9_]*)\."
+            r"(?P<{prefix}_prop>[A-Za-z_][A-Za-z0-9_$#-]*)"
+        )
+        # NOTE(review): the operator group is never read back; its name ("op")
+        # is a placeholder and should be confirmed against the upstream file.
+        comparison = re.compile(
+            property_ref.format(prefix="left")
+            + r"\s*(?P<op><=|>=|<|>)\s*"
+            + property_ref.format(prefix="right"),
+            flags=re.IGNORECASE,
+        )
+        for match in comparison.finditer(protected):
+            left_type = self.property_type(
+                match.group("left_var"),
+                self.canonical_property_name(
+                    match.group("left_var"), match.group("left_prop"), variables
+                ),
+                variables,
+            )
+            right_type = self.property_type(
+                match.group("right_var"),
+                self.canonical_property_name(
+                    match.group("right_var"), match.group("right_prop"), variables
+                ),
+                variables,
+            )
+            temporal_vs_numeric = self._is_temporal_type(left_type) and self._is_numeric_type(
+                right_type
+            )
+            numeric_vs_temporal = self._is_numeric_type(left_type) and self._is_temporal_type(
+                right_type
+            )
+            if temporal_vs_numeric or numeric_vs_temporal:
+                issues.append(
+                    CypherSchemaIssue(
+                        "unsafe_temporal_numeric_comparison",
+                        "Temporal and numeric properties are compared directly.",
+                    )
+                )
+
+        # NOTE(review): the function-name group is never read back; "func" is
+        # a placeholder name to be confirmed against the upstream file.
+        aggregate_ref = re.compile(
+            r"\b(?P<func>AVG|SUM|MIN|MAX)\s*\(\s*"
+            + property_ref.format(prefix="arg")
+            + r"\s*\)",
+            flags=re.IGNORECASE,
+        )
+        aggregate_matches = list(aggregate_ref.finditer(protected))
+        for match in aggregate_matches:
+            variable = match.group("arg_var")
+            if variable not in variables:
+                continue
+            property_name = self.canonical_property_name(
+                variable,
+                match.group("arg_prop"),
+                variables,
+            )
+            property_type = self.property_type(variable, property_name, variables)
+            if self._is_unsafe_numeric_text_property(property_name, property_type):
+                issues.append(
+                    CypherSchemaIssue(
+                        "unsafe_numeric_conversion",
+                        f'Unsafe numeric aggregate over "{variable}.{property_name}".',
+                    )
+                )
+        if self._has_arithmetic_between_temporal_aggregates(query, aggregate_matches, variables):
+            issues.append(
+                CypherSchemaIssue(
+                    "unsafe_temporal_arithmetic",
+                    "Temporal aggregate arithmetic requires explicit numeric conversion.",
+                )
+            )
+        if self._has_arithmetic_between_temporal_aggregate_aliases(
+            query,
+            aggregate_matches,
+            variables,
+        ):
+            issues.append(
+                CypherSchemaIssue(
+                    "unsafe_temporal_arithmetic",
+                    "Temporal aggregate alias arithmetic requires explicit numeric conversion.",
+                )
+            )
+
+        for variable, property_name in cypher_property_references(query):
+            if variable not in variables:
+                continue
+            canonical_property = self.canonical_property_name(variable, property_name, variables)
+            property_type = self.property_type(variable, canonical_property, variables)
+            if not self._is_unsafe_numeric_text_property(canonical_property, property_type):
+                continue
+            # The operator search uses the name as written in the query.
+            if self._property_reference_has_numeric_operator(query, variable, property_name):
+                issues.append(
+                    CypherSchemaIssue(
+                        "unsafe_numeric_conversion",
+                        f'Unsafe numeric arithmetic over "{variable}.{canonical_property}".',
+                    )
+                )
+        return issues
+
+    def canonical_node_label(self, label: str) -> str:
+        """Resolve ``label`` through the node-label alias table."""
+        aliases = self.node_label_aliases
+        return self._canonical_schema_name(label, aliases)
+
+    def canonical_edge_label(self, label: str) -> str:
+        """Resolve ``label`` through the edge-label alias table."""
+        aliases = self.edge_label_aliases
+        return self._canonical_schema_name(label, aliases)
+
+    def canonical_property_name(
+        self,
+        variable: str,
+        property_name: str,
+        variables: dict[str, str],
+    ) -> str:
+        """Resolve a property reference to its schema spelling.
+
+        Args:
+            variable: Query variable the property is read from.
+            property_name: Property name as written in the query.
+            variables: Mapping of query variable to its label.
+
+        Returns:
+            The label's primary key for ``id``/``identity`` when one is known,
+            otherwise the label-specific alias match, otherwise the global
+            alias match.
+        """
+        label = variables.get(variable, "")
+        if label:
+            if label in self.node_label_aliases or label in self.node_props:
+                canonical_label = self.canonical_node_label(label)
+            else:
+                canonical_label = self.canonical_edge_label(label)
+            primary = self.node_primary.get(canonical_label, "")
+            if property_name.lower() in {"identity", "id"} and primary:
+                return primary
+            label_aliases = self.property_aliases_by_label.get(canonical_label, {})
+            resolved = self._canonical_schema_name(property_name, label_aliases)
+            # An unchanged name means the label-level lookup found nothing.
+            if resolved != property_name:
+                return resolved
+        return self._canonical_schema_name(property_name, self.global_property_aliases)
+
+    def property_type(
+        self,
+        variable: str,
+        property_name: str,
+        variables: dict[str, str],
+    ) -> str:
+        """Look up the declared type of ``variable.property_name``.
+
+        The variable's label is tried first as a node label, then as an edge
+        label. Returns ``""`` when the variable has no label or no candidate
+        declares the property.
+        """
+        label = variables.get(variable, "")
+        if label:
+            candidates = (self.canonical_node_label(label), self.canonical_edge_label(label))
+            for candidate in candidates:
+                declared = self.property_types_by_label.get(candidate, {})
+                if property_name in declared:
+                    return declared[property_name]
+                alias_table = self.property_aliases_by_label.get(candidate, {})
+                resolved = self._canonical_schema_name(property_name, alias_table)
+                if resolved in declared:
+                    return declared[resolved]
+        return ""
+
+    def redirected_property_target(
+        self,
+        query: str,
+        variable: str,
+        property_name: str,
+    ) -> tuple[str, str]:
+        """Find the single adjacent edge that owns a property read from a node.
+
+        Returns:
+            ``(edge_variable, canonical_property)`` when exactly one labelled
+            edge adjacent to ``variable`` declares the property, otherwise
+            ``("", "")``.
+        """
+        node_variables, edge_variables = cypher_variable_labels(query)
+        if variable not in node_variables:
+            return "", ""
+        matches: set[tuple[str, str]] = set()
+        for left_var, edge_var, right_var in cypher_variable_edge_adjacencies(query):
+            touches_variable = variable in {left_var, right_var}
+            if touches_variable and edge_var in edge_variables:
+                edge_label = self.canonical_edge_label(edge_variables[edge_var])
+                alias_table = self.property_aliases_by_label.get(edge_label, {})
+                resolved = self._canonical_schema_name(property_name, alias_table)
+                if resolved in self.edge_props.get(edge_label, set()):
+                    matches.add((edge_var, resolved))
+        # Ambiguous or missing owners are both reported as "no redirect".
+        if len(matches) != 1:
+            return "", ""
+        return next(iter(matches))
+
+    def _valid_node_property(
+        self,
+        query: str,
+        variable: str,
+        label: str,
+        property_name: str,
+    ) -> bool:
+        """Return whether ``property_name`` may be read from a node variable.
+
+        ``id``/``identity`` are valid when the label has a primary key. Other
+        names are valid when the label declares them or when
+        ``redirected_property_target`` finds an adjacent edge owning them.
+        """
+        if property_name.lower() in {"identity", "id"}:
+            return bool(self.node_primary.get(label))
+        alias_table = self.property_aliases_by_label.get(label, {})
+        resolved = self._canonical_schema_name(property_name, alias_table)
+        if resolved in self.node_props.get(label, set()):
+            return True
+        redirect = self.redirected_property_target(query, variable, property_name)
+        return bool(redirect[0])
+
+    def _valid_edge_property(self, label: str, property_name: str) -> bool:
+        """Return whether ``property_name`` may be read from an edge of ``label``.
+
+        ``id``/``identity`` are always accepted on edges.
+        """
+        if property_name.lower() in {"identity", "id"}:
+            return True
+        alias_table = self.property_aliases_by_label.get(label, {})
+        resolved = self._canonical_schema_name(property_name, alias_table)
+        declared = self.edge_props.get(label, set())
+        return resolved in declared
+
+ def _property_known_anywhere(self, property_name: str) -> bool:
+ aliases = {
+ property_name,
+ OracleNameSanitizer.clean(property_name, fallback=property_name),
+ re.sub(r"(? bool:
+ aliases = {
+ property_name,
+ OracleNameSanitizer.clean(property_name, fallback=property_name),
+ re.sub(r"(? dict[str, str]:
+ aliases: dict[str, str] = {}
+ for name in names:
+ cleaned = OracleNameSanitizer.clean(name, fallback=name)
+ for alias in {name, cleaned, name.lower(), cleaned.lower()}:
+ aliases.setdefault(alias, name)
+ return aliases
+
+ def _global_property_aliases(self) -> dict[str, str]:
+ candidates: dict[str, set[str]] = {}
+ for properties in self.property_types_by_label.values():
+ for property_name in properties:
+ cleaned = OracleNameSanitizer.clean(property_name, fallback=property_name)
+ snake = re.sub(r"(? str:
+ if not name:
+ return ""
+ cleaned = OracleNameSanitizer.clean(name, fallback=name)
+ snake = re.sub(r"(? bool:
+ lower = property_name.lower()
+ return any(
+ token in lower
+ for token in [
+ "percent",
+ "percentage",
+ "sla",
+ "requirement",
+ "embedding",
+ "vector",
+ "list",
+ "array",
+ ]
+ )
+
+    def _is_unsafe_numeric_text_property(self, property_name: str, property_type: str) -> bool:
+        """Return True for a string-typed property whose name looks unsafe."""
+        if not self._is_string_type(property_type):
+            return False
+        return self._looks_unsafe_numeric_text_property(property_name)
+
+    def _property_reference_has_numeric_operator(
+        self,
+        query: str,
+        variable: str,
+        property_name: str,
+    ) -> bool:
+        """Return whether ``variable.property_name`` sits next to ``+ - * / %``.
+
+        The operator may follow or precede the reference (optionally separated
+        by whitespace); string literals are masked before searching.
+        """
+        protected = mask_string_literals(query)
+        escaped_property = re.escape(property_name)
+        reference = (
+            rf"\b{re.escape(variable)}\."
+            rf"(?:`{escaped_property}`|{escaped_property})\b"
+        )
+        if re.search(reference + r"\s*[-+*/%]", protected):
+            return True
+        return bool(re.search(r"[-+*/%]\s*" + reference, protected))
+
+    def _has_arithmetic_between_temporal_aggregates(
+        self,
+        query: str,
+        aggregate_matches: list[re.Match],
+        variables: dict[str, str],
+    ) -> bool:
+        """Return whether two temporal aggregates are joined by ``+`` or ``-``.
+
+        Only aggregates whose argument has a temporal type are considered, and
+        only whitespace may surround the operator between them.
+        """
+        protected = mask_string_literals(query)
+
+        def is_temporal(match: re.Match) -> bool:
+            variable = match.group("arg_var")
+            resolved = self.canonical_property_name(
+                variable,
+                match.group("arg_prop"),
+                variables,
+            )
+            return self._is_temporal_type(self.property_type(variable, resolved, variables))
+
+        spans = [match.span() for match in aggregate_matches if is_temporal(match)]
+        return any(
+            re.fullmatch(r"\s*[-+]\s*", protected[left_end:right_start])
+            for _left_start, left_end in spans
+            for right_start, _right_end in spans
+            if left_end <= right_start
+        )
+
+    def _has_arithmetic_between_temporal_aggregate_aliases(
+        self,
+        query: str,
+        aggregate_matches: list[re.Match],
+        variables: dict[str, str],
+    ) -> bool:
+        """Return whether two aliases of temporal aggregates are added or subtracted.
+
+        An alias is taken from an ``AS <name>`` clause directly following a
+        temporal aggregate. At least two distinct aliases must exist before the
+        query is searched for ``alias +/- alias``.
+        """
+        protected = mask_string_literals(query)
+        # The group must be named "alias" because it is read back by name below.
+        alias_clause = re.compile(
+            r"\s+AS\s+(?P<alias>`[^`]+`|[A-Za-z_][A-Za-z0-9_]*)",
+            flags=re.IGNORECASE,
+        )
+        temporal_aliases = set()
+        for match in aggregate_matches:
+            variable = match.group("arg_var")
+            property_name = self.canonical_property_name(
+                variable,
+                match.group("arg_prop"),
+                variables,
+            )
+            if not self._is_temporal_type(self.property_type(variable, property_name, variables)):
+                continue
+            alias_match = alias_clause.match(protected, match.end())
+            if alias_match:
+                temporal_aliases.add(alias_match.group("alias").strip("`"))
+        if len(temporal_aliases) < 2:
+            return False
+        # Sorted only to make the generated pattern deterministic.
+        alias_pattern = "|".join(re.escape(alias) for alias in sorted(temporal_aliases))
+        return bool(
+            re.search(
+                rf"\b(?:{alias_pattern})\b\s*[-+]\s*\b(?:{alias_pattern})\b",
+                protected,
+                flags=re.IGNORECASE,
+            )
+        )
+
+    def _is_string_type(self, type_name: str) -> bool:
+        """Return True when the type name contains CHAR, STRING or TEXT (any case)."""
+        upper = type_name.upper()
+        return any(token in upper for token in ("CHAR", "STRING", "TEXT"))
+
+    def _is_temporal_type(self, type_name: str) -> bool:
+        """Return True when the type name contains DATE or TIME (any case)."""
+        normalized = type_name.upper()
+        for token in ("DATE", "TIME"):
+            if token in normalized:
+                return True
+        return False
+
+    def _is_numeric_type(self, type_name: str) -> bool:
+        """Return True when the type name contains a numeric token (any case)."""
+        normalized = type_name.upper()
+        numeric_tokens = ("INT", "NUMBER", "FLOAT", "DOUBLE", "DECIMAL")
+        for token in numeric_tokens:
+            if token in normalized:
+                return True
+        return False
+
+
+def cypher_variable_labels(query: str) -> tuple[dict[str, str], dict[str, str]]:
+    """Map labelled pattern variables in a Cypher query to their labels.
+
+    Returns:
+        ``(node_labels, edge_labels)``: variables bound in ``(var:Label`` and
+        ``[var:LABEL`` patterns respectively. Patterns without a variable are
+        ignored; a later occurrence of a variable overrides an earlier one.
+    """
+    # Group names must stay "var"/"quoted"/"label": they are read back by name.
+    node_pattern = (
+        r"\(\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*"
+        r"(?:`(?P<quoted>[^`]+)`|(?P<label>[A-Za-z_][A-Za-z0-9_$#-]*))"
+    )
+    edge_pattern = (
+        r"\[\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*"
+        r"(?:`(?P<quoted>[^`]+)`|(?P<label>[A-Za-z_][A-Za-z0-9_$#|.-]*))"
+    )
+    node_labels: dict[str, str] = {}
+    edge_labels: dict[str, str] = {}
+    for pattern, labels in ((node_pattern, node_labels), (edge_pattern, edge_labels)):
+        for match in re.finditer(pattern, query):
+            variable = match.group("var")
+            if variable:
+                labels[variable] = _clean_schema_name(
+                    match.group("quoted") or match.group("label")
+                )
+    return node_labels, edge_labels
+
+
+def cypher_graph_variables(query: str) -> set[str]:
+    """Return the variables declared in node or edge patterns of a query.
+
+    A node variable is an identifier after ``(`` followed by ``:``, ``{``,
+    ``)`` or ``WHERE``; an edge variable is an identifier after ``[`` followed
+    by ``:``, ``*``, ``]`` or ``{``. String literals are masked first.
+    """
+    protected = mask_string_literals(query)
+    # The group must be named "var" because it is read back by name below.
+    node_declaration = (
+        r"\(\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)"
+        r"(?=\s*(?::|\{|\)|WHERE\b))"
+    )
+    edge_declaration = (
+        r"\[\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)"
+        r"(?=\s*(?::|\*|\]|\{))"
+    )
+    variables: set[str] = set()
+    for pattern in (node_declaration, edge_declaration):
+        variables.update(
+            match.group("var")
+            for match in re.finditer(pattern, protected, flags=re.IGNORECASE)
+        )
+    return variables
+
+
+def cypher_property_references(query: str) -> list[tuple[str, str]]:
+    """List every ``variable.property`` reference in a query, in order.
+
+    Property names may be bare or backtick-quoted; quotes are removed from the
+    returned name. String literals are masked before scanning.
+    """
+    protected = mask_string_literals(query)
+    # Group names must stay "var"/"quoted"/"bare": they are read back by name.
+    reference = re.compile(
+        r"\b(?P<var>[A-Za-z_][A-Za-z0-9_]*)\."
+        r"(?:`(?P<quoted>[^`]+)`|(?P<bare>[A-Za-z_][A-Za-z0-9_$#-]*))"
+    )
+    return [
+        (match.group("var"), _clean_schema_name(match.group("quoted") or match.group("bare")))
+        for match in reference.finditer(protected)
+    ]
+
+
+def cypher_edge_triples(query: str) -> list[tuple[str, str, list[str], str]]:
+    """Extract ``(left_label, direction, edge_labels, right_label)`` triples.
+
+    Only patterns whose two nodes and edge are all labelled are matched.
+    ``direction`` is ``"right"`` for ``-[..]->``, ``"left"`` for ``<-[..]-``
+    and ``"undirected"`` for ``-[..]-``. Alternative edge labels separated by
+    ``|`` are returned as a list.
+    """
+    # NAME_Q / NAME are placeholders that ``str.replace`` turns into
+    # left_Q / left and right_Q / right, so each side gets unique group names.
+    node = (
+        r"\(\s*(?:[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*"
+        r"(?:`(?P<NAME_Q>[^`]+)`|(?P<NAME>[A-Za-z_][A-Za-z0-9_$#.-]*))"
+        r"(?:\s*\{[^}]*\})?\s*\)"
+    )
+    edge = (
+        r"\[\s*(?:[A-Za-z_][A-Za-z0-9_]*)?\s*:\s*"
+        r"(?P<edge>`[^`]+`|[A-Za-z_][A-Za-z0-9_$#.-]*"
+        r"(?:\s*\|\s*(?:`[^`]+`|[A-Za-z_][A-Za-z0-9_$#.-]*))*)"
+        r"(?:\s*\{[^}]*\})?\s*(?:\*\s*(?:\d+\s*)?(?:\.\.\s*\d*)?)?\s*\]"
+    )
+    left_node = node.replace("NAME", "left")
+    right_node = node.replace("NAME", "right")
+    patterns = [
+        ("right", left_node + r"\s*-\s*" + edge + r"\s*->\s*" + right_node),
+        ("left", left_node + r"\s*<-\s*" + edge + r"\s*-\s*" + right_node),
+        ("undirected", left_node + r"\s*-\s*" + edge + r"\s*-\s*" + right_node),
+    ]
+    triples = []
+    for direction, pattern in patterns:
+        for match in re.finditer(pattern, query):
+            left = match.group("left_Q") or match.group("left")
+            right = match.group("right_Q") or match.group("right")
+            edge_labels = [
+                _clean_schema_name(item)
+                for item in re.split(r"\s*\|\s*", match.group("edge"))
+                if item.strip()
+            ]
+            triples.append(
+                (_clean_schema_name(left), direction, edge_labels, _clean_schema_name(right))
+            )
+    return triples
+
+
+def cypher_variable_edge_adjacencies(query: str) -> list[tuple[str, str, str]]:
+    """List ``(left_var, edge_var, right_var)`` for patterns with named parts.
+
+    Both nodes and the edge must carry a variable. ``-[e]->`` and ``-[e]-``
+    patterns are collected first, followed by ``<-[e]-`` patterns.
+    """
+    # NODE is a placeholder renamed per side; EDGE is read back under that name.
+    node = r"\(\s*(?P<NODE>[A-Za-z_][A-Za-z0-9_]*)\s*(?::[^)]*)?\)"
+    edge = r"\[\s*(?P<EDGE>[A-Za-z_][A-Za-z0-9_]*)\s*(?::[^\]]*)?\]"
+    left_node = node.replace("NODE", "left")
+    right_node = node.replace("NODE", "right")
+    patterns = (
+        left_node + r"\s*-\s*" + edge + r"\s*->?\s*" + right_node,
+        left_node + r"\s*<-\s*" + edge + r"\s*-\s*" + right_node,
+    )
+    adjacencies = []
+    for pattern in patterns:
+        for match in re.finditer(pattern, query):
+            adjacencies.append((match.group("left"), match.group("EDGE"), match.group("right")))
+    return adjacencies
+
+
+def mask_string_literals(query: str) -> str:
+    """Replace every single- or double-quoted literal in ``query`` with ``''``.
+
+    A falsy ``query`` is treated as the empty string.
+    """
+    literal = re.compile(r"'(?:''|\\'|[^'])*'|\"(?:\\\"|[^\"])*\"")
+    text = query or ""
+    return literal.sub("''", text)
+
+
+def _clean_schema_name(value: str) -> str:
+    """Strip surrounding whitespace, then backticks, then double quotes."""
+    text = str(value or "").strip()
+    # Order matters: backticks are removed before double quotes.
+    for quote in ("`", '"'):
+        text = text.strip(quote)
+    return text
+
+
+def _dedupe_issues(issues: list[CypherSchemaIssue]) -> list[CypherSchemaIssue]:
+    """Drop repeated issues, keeping the first of each (signature, message)."""
+    first_by_key = {}
+    for issue in issues:
+        # setdefault keeps the earliest issue and preserves insertion order.
+        first_by_key.setdefault((issue.signature, issue.message), issue)
+    return list(first_by_key.values())
diff --git a/dataset_prep/discover.py b/dataset_prep/discover.py
new file mode 100644
index 0000000..5bc3699
--- /dev/null
+++ b/dataset_prep/discover.py
@@ -0,0 +1,85 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable, List
+
+
+@dataclass(frozen=True)
+class DatabaseUnit:
+ """One query file of a dataset database together with its import assets."""
+
+ # Dataset split the unit was found under (e.g. "train").
+ split: str
+ # Name of the domain directory.
+ database: str
+ # The domain directory itself.
+ root: Path
+ # JSON file holding the queries of this unit.
+ query_path: Path
+ # The import_config.json selected for the database.
+ import_config_path: Path
+ # Directory that contains the import config.
+ csv_root: Path
+
+
+def discover_database_units(dataset_root: Path, splits: Iterable[str]) -> List[DatabaseUnit]:
+    """Discover all database units below ``dataset_root`` for the given splits.
+
+    Missing split directories are ignored. The ``train`` split uses its own
+    layout; every other split uses the per-domain ``Cypher`` layout. Units are
+    returned sorted by split, database and query path.
+    """
+
+    def sort_key(unit: DatabaseUnit):
+        return (unit.split, unit.database, str(unit.query_path))
+
+    found: List[DatabaseUnit] = []
+    for split in splits:
+        split_root = dataset_root / split
+        if split_root.exists():
+            discovered = (
+                _discover_train(split_root)
+                if split == "train"
+                else _discover_split_domains(split_root, split)
+            )
+            found.extend(discovered)
+    found.sort(key=sort_key)
+    return found
+
+
+def _discover_train(split_root: Path) -> List[DatabaseUnit]:
+    """Discover units in the ``train`` layout.
+
+    A domain directory yields one unit when it contains
+    ``4_level_results_ek_results.json`` and at least one
+    ``cypher/*/import_config.json``; the first config in sorted order is used.
+    """
+    found: List[DatabaseUnit] = []
+    for domain_root in split_root.iterdir():
+        if not domain_root.is_dir():
+            continue
+        query_path = domain_root / "4_level_results_ek_results.json"
+        config_paths = sorted((domain_root / "cypher").glob("*/import_config.json"))
+        if not (query_path.exists() and config_paths):
+            continue
+        chosen_config = config_paths[0]
+        unit = DatabaseUnit(
+            split="train",
+            database=domain_root.name,
+            root=domain_root,
+            query_path=query_path,
+            import_config_path=chosen_config,
+            csv_root=chosen_config.parent,
+        )
+        found.append(unit)
+    return found
+
+
+def _discover_split_domains(split_root: Path, split: str) -> List[DatabaseUnit]:
+    """Discover units in the per-domain ``Cypher`` layout.
+
+    Each ``*_cypher.json`` file inside a domain's ``Cypher`` directory becomes
+    one unit; all of them share the first ``import_config.json`` found (in
+    sorted order) anywhere below that directory.
+    """
+    found: List[DatabaseUnit] = []
+    for domain_root in split_root.iterdir():
+        if not domain_root.is_dir():
+            continue
+        cypher_root = domain_root / "Cypher"
+        if not cypher_root.exists():
+            continue
+        query_files = sorted(
+            candidate for candidate in cypher_root.glob("*_cypher.json") if candidate.is_file()
+        )
+        config_files = sorted(cypher_root.glob("**/import_config.json"))
+        if query_files and config_files:
+            chosen_config = config_files[0]
+            found.extend(
+                DatabaseUnit(
+                    split=split,
+                    database=domain_root.name,
+                    root=domain_root,
+                    query_path=query_file,
+                    import_config_path=chosen_config,
+                    csv_root=chosen_config.parent,
+                )
+                for query_file in query_files
+            )
+    return found
+
+
+def source_query(record: dict) -> tuple[str, str]:
+    """Return ``(field_name, query)`` for the first non-blank query field.
+
+    Fields are tried in the order ``initial_cypher``, ``cypher``, ``query``,
+    ``initial_gql``; ``("", "")`` is returned when none holds a non-blank
+    string.
+    """
+    candidate_fields = ("initial_cypher", "cypher", "query", "initial_gql")
+    usable = (
+        (field, record.get(field))
+        for field in candidate_fields
+        if isinstance(record.get(field), str) and record.get(field).strip()
+    )
+    return next(usable, ("", ""))
diff --git a/dataset_prep/export_validated_dataset.py b/dataset_prep/export_validated_dataset.py
new file mode 100644
index 0000000..a4af47a
--- /dev/null
+++ b/dataset_prep/export_validated_dataset.py
@@ -0,0 +1,374 @@
+from __future__ import annotations
+
+# ruff: noqa: E402,I001
+
+import argparse
+import json
+import os
+import shutil
+import sys
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Sequence
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient
+from dataset_prep.compare_oracle_neo4j_results import (
+ DEFAULT_VALID_ORACLE_STATUSES,
+ DatasetNeo4jLoader,
+ compare_record,
+ load_enriched_records,
+ oracle_element_label_aliases,
+ select_records_for_range,
+ skip_reason_for_record,
+)
+from dataset_prep.discover import DatabaseUnit, discover_database_units
+from dataset_prep.oracle_loader import DatasetOracleLoader
+from dataset_prep.reporting import write_json
+from dataset_prep.translate_validate import graph_name_for
+
+
+# Record keys starting with this prefix are dropped from exported records
+# unless --include-oracle-metadata is given.
+ORACLE_EXPORT_PREFIX = "oracle_"
+# Keys that may hold a source query; the SQL/PGQ field is inserted right
+# after the last of these keys present in a record.
+SOURCE_QUERY_FIELD_NAMES = ("initial_cypher", "initial_gql", "cypher", "query")
+
+
+def main() -> None:
+    """Export the validated dataset for every selected database unit.
+
+    Reads the Oracle connection from ``ORACLE_DSN``, ``ORACLE_USER`` and
+    ``ORACLE_PASSWORD``, exports each unit in turn and finally writes the
+    merged ``export_summary.json`` into the output root.
+
+    Raises:
+        SystemExit: When a required environment variable is missing, the
+            output root cannot be used, or no unit matches the filters.
+    """
+    args = parse_args()
+
+    # Check the environment before the (possibly destructive) output-root
+    # preparation so a missing variable cannot cost an existing export.
+    required_env = ("ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD")
+    missing_env = [name for name in required_env if name not in os.environ]
+    if missing_env:
+        raise SystemExit(
+            "Missing required environment variable(s): " + ", ".join(missing_env)
+        )
+
+    dataset_root = Path(args.dataset_root).resolve()
+    output_root = Path(args.output_root)
+    prepare_output_root(output_root, overwrite=args.overwrite)
+
+    units = discover_database_units(dataset_root, args.splits)
+    if args.databases:
+        requested = {name.lower() for name in args.databases}
+        units = [unit for unit in units if unit.database.lower() in requested]
+    if args.limit_databases:
+        units = units[: args.limit_databases]
+    if not units:
+        raise SystemExit("No dataset units matched the requested filters.")
+
+    # Query files are rewritten by the export, so they are excluded from the
+    # plain asset copy.
+    query_paths = {unit.query_path.resolve() for unit in units}
+    if args.copy_assets:
+        copy_dataset_assets(units, dataset_root, output_root, query_paths)
+
+    oracle_client = OracleDBClient(
+        {
+            "dsn": os.environ["ORACLE_DSN"],
+            "user": os.environ["ORACLE_USER"],
+            "password": os.environ["ORACLE_PASSWORD"],
+        }
+    )
+    summaries: List[Dict[str, Any]] = []
+    try:
+        for unit in units:
+            print(f"[start] {unit.split}/{unit.database}", flush=True)
+            summary = export_unit(unit, oracle_client, dataset_root, output_root, args)
+            summaries.append(summary)
+            print(
+                f"[done] {unit.split}/{unit.database}: "
+                f"exported={summary['exported']} failed={summary['failed']} "
+                f"skipped={summary['skipped']}",
+                flush=True,
+            )
+    finally:
+        # Always release the Oracle connection, even when a unit fails.
+        oracle_client.close()
+
+    write_json(output_root / "export_summary.json", merge_export_summaries(summaries))
+
+
+def parse_args() -> argparse.Namespace:
+ """Define and parse the command-line options of the export script."""
+ parser = argparse.ArgumentParser(
+ description=(
+ "Export only Oracle SQL/PGQ records whose translated SQL/PGQ and source "
+ "Cypher results match on Oracle and Neo4j."
+ )
+ )
+ # Input and output locations.
+ parser.add_argument("--dataset-root", default="dataset")
+ parser.add_argument("--dataset-output-root", default="output/dataset_prep")
+ parser.add_argument("--output-root", default="output/oracle_sqlpgq_dataset")
+ # Selection of splits, databases and query ranges; 0 means "no limit".
+ parser.add_argument("--splits", nargs="+", default=["train", "dev", "test"])
+ parser.add_argument("--databases", nargs="*", default=[])
+ parser.add_argument("--limit-databases", type=int, default=0)
+ parser.add_argument("--limit-queries", type=int, default=0)
+ parser.add_argument("--query-offset", type=int, default=0)
+ parser.add_argument("--graph-prefix", default="T2GQL")
+ # Shape of the exported records.
+ parser.add_argument("--sql-pgq-field", default="initial_sql_pgq")
+ parser.add_argument("--include-oracle-metadata", action="store_true")
+ # BooleanOptionalAction also provides --no-copy-assets.
+ parser.add_argument("--copy-assets", action=argparse.BooleanOptionalAction, default=True)
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument(
+ "--oracle-statuses",
+ nargs="+",
+ default=sorted(DEFAULT_VALID_ORACLE_STATUSES),
+ help="Prior Oracle validation statuses eligible for export.",
+ )
+ parser.add_argument("--include-all-translatable", action="store_true")
+ parser.add_argument("--oracle-timeout-ms", type=int, default=60000)
+ parser.add_argument("--neo4j-timeout-s", type=float, default=60.0)
+ # Neo4j connection settings default to the NEO4J_* environment variables.
+ parser.add_argument("--neo4j-uri", default=os.environ.get("NEO4J_URI", "bolt://localhost:7687"))
+ parser.add_argument("--neo4j-user", default=os.environ.get("NEO4J_USER", "neo4j"))
+ parser.add_argument("--neo4j-password", default=os.environ.get("NEO4J_PASSWORD", "password"))
+ parser.add_argument("--neo4j-database", default=os.environ.get("NEO4J_DATABASE", "neo4j"))
+ parser.add_argument("--neo4j-batch-size", type=int, default=1000)
+ parser.add_argument("--keep-loaded", action="store_true")
+ parser.add_argument(
+ "--reuse-loaded",
+ action="store_true",
+ help="Skip Oracle/Neo4j load setup and export against already-loaded graphs.",
+ )
+ # 0 disables progress output.
+ parser.add_argument("--progress-every", type=int, default=0)
+ return parser.parse_args()
+
+
+def prepare_output_root(output_root: Path, overwrite: bool = False) -> None:
+    """Ensure ``output_root`` exists and is empty.
+
+    A non-empty directory is deleted when ``overwrite`` is true.
+
+    Raises:
+        SystemExit: When the directory is non-empty and ``overwrite`` is false.
+    """
+    has_content = output_root.exists() and any(output_root.iterdir())
+    if has_content and not overwrite:
+        raise SystemExit(
+            f"Output root already exists and is not empty: {output_root}. "
+            "Use --overwrite or choose a different --output-root."
+        )
+    if has_content:
+        shutil.rmtree(output_root)
+    output_root.mkdir(parents=True, exist_ok=True)
+
+
+def export_unit(
+ unit: DatabaseUnit,
+ oracle_client: OracleDBClient,
+ dataset_root: Path,
+ output_root: Path,
+ args: argparse.Namespace,
+) -> Dict[str, Any]:
+ """Export the records of one unit whose comparison reports a match.
+
+ Each selected record is either skipped, exported, or counted as failed;
+ the exported records are written to the unit's query-file path mirrored
+ under ``output_root``.
+
+ Returns:
+ A summary dictionary with counters, reason histograms and load info.
+ """
+ graph_name = graph_name_for(unit, args.graph_prefix)
+ oracle_loader = DatasetOracleLoader(
+ oracle_client,
+ unit.import_config_path,
+ unit.csv_root,
+ graph_name,
+ )
+ neo4j_loader = DatasetNeo4jLoader(
+ args.neo4j_uri,
+ args.neo4j_user,
+ args.neo4j_password,
+ args.neo4j_database,
+ unit.import_config_path,
+ unit.csv_root,
+ args.neo4j_batch_size,
+ )
+ summary: Dict[str, Any] = {
+ "split": unit.split,
+ "database": unit.database,
+ "query_file": str(unit.query_path),
+ "import_config": str(unit.import_config_path),
+ "graph_name": graph_name,
+ "total_records": 0,
+ "selected_records": 0,
+ "considered": 0,
+ "exported": 0,
+ "failed": 0,
+ "skipped": 0,
+ "skip_reasons": {},
+ "failure_reasons": {},
+ "output_query_file": str(output_root / unit.query_path.relative_to(dataset_root)),
+ }
+ exported_records: List[Dict[str, Any]] = []
+ element_label_aliases = oracle_element_label_aliases(oracle_loader)
+ valid_statuses = set(args.oracle_statuses)
+ try:
+ # Load both databases unless the caller asked to reuse loaded graphs.
+ if args.reuse_loaded:
+ summary["loaded"] = {"reused": True}
+ print(f"[load] {unit.split}/{unit.database}: reusing loaded graphs", flush=True)
+ else:
+ print(f"[load] {unit.split}/{unit.database}: oracle", flush=True)
+ oracle_counts = oracle_loader.setup()
+ print(f"[load] {unit.split}/{unit.database}: neo4j", flush=True)
+ neo4j_counts = neo4j_loader.setup(clear=True)
+ summary["loaded"] = {"oracle": oracle_counts, "neo4j": neo4j_counts}
+ print(f"[load] {unit.split}/{unit.database}: done", flush=True)
+
+ all_records = load_enriched_records(unit, Path(args.dataset_output_root))
+ records = select_records_for_range(all_records, args.query_offset, args.limit_queries)
+ summary["total_records"] = len(all_records)
+ # Negative offsets are reported as 0.
+ summary["query_offset"] = max(args.query_offset, 0)
+ summary["selected_records"] = len(records)
+ if args.limit_queries:
+ summary["limit_queries"] = args.limit_queries
+
+ for selected_index, record in enumerate(records, start=1):
+ if args.progress_every and selected_index % args.progress_every == 0:
+ print(
+ f"[progress] {unit.split}/{unit.database}: "
+ f"{selected_index}/{len(records)} selected records",
+ flush=True,
+ )
+ # Records rejected up front are counted as skipped, not considered.
+ skip_reason = skip_reason_for_record(
+ record,
+ valid_statuses=valid_statuses,
+ include_all_translatable=args.include_all_translatable,
+ )
+ if skip_reason:
+ summary["skipped"] += 1
+ increment(summary["skip_reasons"], skip_reason)
+ continue
+
+ summary["considered"] += 1
+ comparison = compare_record(
+ record,
+ oracle_client,
+ neo4j_loader,
+ args,
+ element_label_aliases=element_label_aliases,
+ )
+ if comparison["matched"]:
+ exported_records.append(project_export_record(record, args))
+ summary["exported"] += 1
+ continue
+
+ # These comparison reasons are counted as skips rather than failures.
+ if comparison["reason"] in {
+ "nondeterministic_limit_without_order",
+ "nondeterministic_with_limit_without_order",
+ "suspected_order_by_limit_tie",
+ "source_invalid",
+ }:
+ summary["skipped"] += 1
+ increment(summary["skip_reasons"], comparison["reason"])
+ continue
+
+ summary["failed"] += 1
+ increment(summary["failure_reasons"], comparison["reason"] or "unknown")
+
+ write_records_like_source(
+ unit.query_path,
+ output_root / unit.query_path.relative_to(dataset_root),
+ exported_records,
+ )
+ return summary
+ finally:
+ # Oracle cleanup runs only when neither --keep-loaded nor --reuse-loaded
+ # is set; errors while clearing Neo4j are deliberately swallowed.
+ if not args.keep_loaded:
+ if not args.reuse_loaded:
+ oracle_loader.cleanup(ignore_errors=True)
+ try:
+ neo4j_loader.clear()
+ except Exception:
+ pass
+ neo4j_loader.close()
+
+
+def project_export_record(record: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
+    """Build the exported form of a record.
+
+    Keys starting with ``ORACLE_EXPORT_PREFIX`` are dropped unless
+    ``args.include_oracle_metadata`` is set; the record's ``oracle_sqlpgq``
+    value is then inserted under ``args.sql_pgq_field``.
+    """
+    sql_pgq = record.get("oracle_sqlpgq")
+    if args.include_oracle_metadata:
+        base_record = dict(record)
+    else:
+        base_record = {}
+        for key, value in record.items():
+            if not key.startswith(ORACLE_EXPORT_PREFIX):
+                base_record[key] = value
+    return insert_sql_pgq_field(base_record, args.sql_pgq_field, sql_pgq)
+
+
+def insert_sql_pgq_field(
+    record: Dict[str, Any],
+    field_name: str,
+    sql_pgq: Any,
+) -> Dict[str, Any]:
+    """Return a copy of ``record`` with ``sql_pgq`` stored under ``field_name``.
+
+    The field is placed right after the last source-query key present in the
+    record, or after the last key when there is none. An existing
+    ``field_name`` entry is replaced; if its anchor position is not reached,
+    the field is appended at the end.
+    """
+    keys = list(record)
+    query_positions = [
+        position for position, key in enumerate(keys) if key in SOURCE_QUERY_FIELD_NAMES
+    ]
+    anchor = query_positions[-1] if query_positions else len(keys) - 1
+    output: Dict[str, Any] = {}
+    for position, key in enumerate(keys):
+        if key == field_name:
+            continue
+        output[key] = record[key]
+        if position == anchor:
+            output[field_name] = sql_pgq
+    # Old values are skipped above, so presence means "already inserted".
+    if field_name not in output:
+        output[field_name] = sql_pgq
+    return output
+
+
+def write_records_like_source(
+    source_path: Path,
+    output_path: Path,
+    records: Sequence[Dict[str, Any]],
+) -> None:
+    """Write ``records`` as JSON using the container type of the source file.
+
+    When the source file holds a JSON object, records are keyed by the string
+    form of their ``id`` (or their index when ``id`` is absent); otherwise
+    they are written as a list. Parent directories are created as needed.
+    """
+    source_data = json.loads(source_path.read_text(encoding="utf-8"))
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    if isinstance(source_data, dict):
+        payload: Any = {}
+        for index, record in enumerate(records):
+            payload[str(record.get("id", index))] = record
+    else:
+        payload = list(records)
+    serialized = json.dumps(payload, indent=2, ensure_ascii=False)
+    output_path.write_text(serialized, encoding="utf-8")
+
+
+def copy_dataset_assets(
+    units: Sequence[DatabaseUnit],
+    dataset_root: Path,
+    output_root: Path,
+    query_paths: set[Path],
+) -> None:
+    """Copy every non-query file of the units' roots into ``output_root``.
+
+    Each distinct root is walked once. Files keep their path relative to
+    ``dataset_root``; files whose resolved path is in ``query_paths`` are
+    left out.
+    """
+    visited_roots: set[Path] = set()
+    for unit in units:
+        root = unit.root.resolve()
+        if root in visited_roots:
+            continue
+        visited_roots.add(root)
+        for source_path in root.rglob("*"):
+            if source_path.is_dir():
+                continue
+            if source_path.resolve() in query_paths:
+                continue
+            target_path = output_root / source_path.relative_to(dataset_root)
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copy2(source_path, target_path)
+
+
+def merge_export_summaries(summaries: Iterable[Dict[str, Any]]) -> Dict[str, Any]:
+    """Combine per-unit export summaries into one report.
+
+    Counters are summed, the two reason histograms are merged key by key, and
+    every input summary is kept (in order) under ``units``.
+    """
+    counter_keys = (
+        "total_records",
+        "selected_records",
+        "considered",
+        "exported",
+        "failed",
+        "skipped",
+    )
+    merged: Dict[str, Any] = {"databases": 0}
+    merged.update(dict.fromkeys(counter_keys, 0))
+    merged["skip_reasons"] = {}
+    merged["failure_reasons"] = {}
+    merged["units"] = []
+    for summary in summaries:
+        merged["databases"] += 1
+        for name in counter_keys:
+            merged[name] += int(summary.get(name, 0))
+        for histogram in ("skip_reasons", "failure_reasons"):
+            target = merged[histogram]
+            for reason, count in summary.get(histogram, {}).items():
+                target[reason] = target.get(reason, 0) + int(count)
+        merged["units"].append(summary)
+    return merged
+
+
+def increment(counter: Dict[str, int], key: str, amount: int = 1) -> None:
+    """Add ``amount`` to ``counter[key]`` in place; a missing key counts as zero."""
+    previous = counter.get(key, 0)
+    counter[key] = previous + amount
+
+
+# Run the export only when this file is executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/dataset_prep/oracle_loader.py b/dataset_prep/oracle_loader.py
new file mode 100644
index 0000000..8ce98cc
--- /dev/null
+++ b/dataset_prep/oracle_loader.py
@@ -0,0 +1,250 @@
+from __future__ import annotations
+
+import csv
+from datetime import date, datetime
+import json
+from pathlib import Path
+from typing import Any, Dict, Iterable, List
+
+from app.core.validator.db_client import QueryStatus
+from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient
+from app.impl.oracle_sqlpgq.schema.schema_parser import OracleSqlPgqSchemaParser
+from app.impl.oracle_sqlpgq.utils.sqlpgq import OracleNameSanitizer, split_sql_statements
+
+
+class DatasetOracleLoader:
+    """Create, populate, and drop the Oracle objects backing one dataset graph.
+
+    The import configuration JSON is read once at construction time and turned
+    into a manifest by ``OracleSqlPgqSchemaParser``. This class reads the
+    manifest keys ``table_ddl``, ``property_graph_ddl``, ``vertices`` and
+    ``edges``. Each vertex/edge entry is read for ``label``, the optional
+    ``graph_label``, ``table``, ``quoted_table``, ``primary`` and ``columns``
+    (``name``/``type`` pairs); edges are additionally read for ``src``/``dst``.
+    """
+
+    # strptime fallbacks tried in order once ISO parsing has failed. The flag
+    # marks date-only formats, whose result is narrowed to a ``date``.
+    _DATETIME_FORMATS = (
+        ("%Y-%m-%d %H:%M:%S", False),
+        ("%Y-%m-%d", True),
+        ("%m/%d/%Y %H:%M:%S", False),
+        ("%m/%d/%Y", True),
+        ("%Y/%m/%d %H:%M:%S", False),
+        ("%Y/%m/%d", True),
+    )
+
+    def __init__(
+        self,
+        client: OracleDBClient,
+        import_config_path: Path,
+        csv_root: Path,
+        graph_name: str,
+    ):
+        """Read the import configuration and build the manifest.
+
+        Args:
+            client: Oracle client used for every statement this loader runs.
+            import_config_path: JSON import configuration; its ``files`` list
+                drives CSV loading.
+            csv_root: Directory that the ``path`` of each file entry is
+                resolved against.
+            graph_name: Requested graph name; stored after sanitising.
+        """
+        self.client = client
+        self.import_config_path = import_config_path
+        self.csv_root = csv_root
+        self.graph_name = OracleNameSanitizer.clean(graph_name, fallback="GRAPH")
+        self.config = json.loads(import_config_path.read_text(encoding="utf-8"))
+        self.manifest = self._build_manifest()
+
+    def setup(self) -> Dict[str, int]:
+        """Recreate the tables, load the CSV data, then create the property graph.
+
+        Returns:
+            Number of rows loaded, keyed by table name.
+        """
+        self.cleanup(ignore_errors=True)
+        self._execute_script(self.manifest["table_ddl"])
+        counts = self._load_csv_files()
+        self._execute_script(self.manifest["property_graph_ddl"])
+        return counts
+
+    def node_label_map(self) -> Dict[str, List[str]]:
+        """Map vertex labels (and CSV file-stem aliases) to graph labels."""
+        return self._label_map(self.manifest["vertices"])
+
+    def edge_label_map(self) -> Dict[str, List[str]]:
+        """Map edge labels (and CSV file-stem aliases) to graph labels."""
+        return self._label_map(self.manifest["edges"])
+
+    def property_type_map(self) -> Dict[str, Dict[str, str]]:
+        """Map each label and graph label to its ``{column name: Oracle type}``."""
+        mapping: Dict[str, Dict[str, str]] = {}
+        for item in self.manifest["vertices"] + self.manifest["edges"]:
+            column_types = {column["name"]: column["type"] for column in item["columns"]}
+            for label in (item["label"], item.get("graph_label", item["label"])):
+                mapping.setdefault(label, {}).update(column_types)
+        return mapping
+
+    def node_primary_key_map(self) -> Dict[str, str]:
+        """Map each vertex label and graph label to its primary key column."""
+        return self._primary_key_map(self.manifest["vertices"])
+
+    def edge_primary_key_map(self) -> Dict[str, str]:
+        """Map each edge label and graph label to its primary key column."""
+        return self._primary_key_map(self.manifest["edges"])
+
+    @staticmethod
+    def _primary_key_map(items: Iterable[Dict[str, Any]]) -> Dict[str, str]:
+        """Build the label -> primary key mapping shared by nodes and edges."""
+        mapping: Dict[str, str] = {}
+        for item in items:
+            mapping[item["label"]] = item["primary"]
+            mapping[item.get("graph_label", item["label"])] = item["primary"]
+        return mapping
+
+    def cleanup(self, ignore_errors: bool = False) -> None:
+        """Drop the property graph, then the edge and vertex tables.
+
+        ``ignore_errors`` only governs the ``DROP PROPERTY GRAPH`` statement;
+        the table drops always tolerate failures, edges before vertices.
+        """
+        self._execute(
+            f"DROP PROPERTY GRAPH {OracleNameSanitizer.quote(self.graph_name)}",
+            ignore_errors=ignore_errors,
+        )
+        for edge in reversed(self.manifest["edges"]):
+            self._execute(f"DROP TABLE {edge['quoted_table']} CASCADE CONSTRAINTS PURGE", True)
+        for vertex in reversed(self.manifest["vertices"]):
+            self._execute(f"DROP TABLE {vertex['quoted_table']} CASCADE CONSTRAINTS PURGE", True)
+
+    def _build_manifest(self) -> Dict[str, Any]:
+        """Parse the import configuration into the Oracle manifest."""
+        parser = OracleSqlPgqSchemaParser(
+            db_id=self.graph_name,
+            instance_path=str(self.import_config_path),
+            enforced=False,
+            include_foreign_keys=False,
+            promote_mixed_property_types=True,
+        )
+        return parser.build_manifest(parser.get_schema_graph(), graph_name=self.graph_name)
+
+    def _label_map(self, items: Iterable[Dict[str, Any]]) -> Dict[str, List[str]]:
+        """Map labels to graph labels, adding CSV file stems as extra aliases.
+
+        A file entry contributes its path stem as an alias for the item with
+        the same label, unless the stem equals the label itself.
+        """
+        mapping: Dict[str, List[str]] = {}
+        files_by_label: Dict[str, List[Dict[str, Any]]] = {}
+        # getattr keeps this usable on an instance whose config was never set.
+        for file_item in getattr(self, "config", {}).get("files", []):
+            files_by_label.setdefault(file_item.get("label", ""), []).append(file_item)
+        for item in items:
+            graph_label = item.get("graph_label", item["label"])
+            mapping.setdefault(item["label"], []).append(graph_label)
+            for file_item in files_by_label.get(item["label"], []):
+                alias = Path(str(file_item.get("path", ""))).stem
+                if alias and alias != item["label"]:
+                    mapping.setdefault(alias, []).append(graph_label)
+        return mapping
+
+    def _execute_script(self, script: str) -> None:
+        """Run every statement of ``script``, stopping at the first failure."""
+        for statement in split_sql_statements(script):
+            self._execute(statement)
+
+    def _execute(self, sql: str, ignore_errors: bool = False) -> None:
+        """Run one statement.
+
+        Raises:
+            RuntimeError: If the statement does not succeed and
+                ``ignore_errors`` is false.
+        """
+        result = self.client.execute_query(sql)
+        # ignore_errors covers every non-success outcome, including a failed
+        # result that carries no error text.
+        if result.status_code == QueryStatus.SUCCESS or ignore_errors:
+            return
+        raise RuntimeError(f"Oracle SQL failed:\n{sql}\n\nError:\n{result.error}")
+
+    def _load_csv_files(self) -> Dict[str, int]:
+        """Load vertex files, then edge files; return row counts per table."""
+        files = self.config.get("files", [])
+        counts: Dict[str, int] = {}
+        # Entries without SRC_ID are treated as vertex files. Entries lacking a
+        # label are skipped, matching the tolerant .get("label") lookups
+        # used elsewhere in this class.
+        by_label = {
+            item["label"]: item for item in files if "label" in item and "SRC_ID" not in item
+        }
+        for vertex in self.manifest["vertices"]:
+            file_item = by_label.get(vertex["label"])
+            if file_item:
+                counts[vertex["table"]] = self._load_file(vertex, file_item)
+        for edge in self.manifest["edges"]:
+            file_item = self._find_edge_file(edge, files)
+            if file_item:
+                counts[edge["table"]] = self._load_file(edge, file_item, is_edge=True)
+        return counts
+
+    def _find_edge_file(self, edge: Dict[str, Any], files: Iterable[Dict[str, Any]]):
+        """Pick the file entry for ``edge``, or ``None`` when nothing fits.
+
+        An entry whose label, ``SRC_ID`` and ``DST_ID`` all match wins. When
+        same-label entries declare endpoints but none match, nothing is
+        returned; otherwise the first same-label entry is used.
+        """
+        same_label_files = []
+        for file_item in files:
+            if file_item.get("label") != edge["label"]:
+                continue
+            same_label_files.append(file_item)
+            if file_item.get("SRC_ID") == edge["src"] and file_item.get("DST_ID") == edge["dst"]:
+                return file_item
+        if any("SRC_ID" in item or "DST_ID" in item for item in same_label_files):
+            return None
+        if same_label_files:
+            return same_label_files[0]
+        return None
+
+    def _load_file(
+        self,
+        item: Dict[str, Any],
+        file_item: Dict[str, Any],
+        is_edge: bool = False,
+    ) -> int:
+        """Insert the rows of one CSV file into the table of ``item``.
+
+        Only columns present in both the file entry and the table are
+        inserted; for edges the ``EDGE_ID`` column is never taken from the
+        file.
+
+        Returns:
+            Number of rows inserted.
+
+        Raises:
+            FileNotFoundError: If the CSV file does not exist.
+            RuntimeError: If the bulk insert does not succeed.
+        """
+        csv_path = self.csv_root / file_item["path"]
+        if not csv_path.exists():
+            raise FileNotFoundError(f"CSV file not found: {csv_path}")
+        source_columns = list(file_item.get("columns", []))
+        target_columns = [column["name"] for column in item["columns"]]
+        if is_edge:
+            target_columns = [column for column in target_columns if column != "EDGE_ID"]
+        insert_columns = [column for column in source_columns if column in target_columns]
+        type_by_column = {column["name"]: column["type"] for column in item["columns"]}
+        rows = self._read_rows(
+            csv_path,
+            int(file_item.get("header", 0)),
+            source_columns,
+            insert_columns,
+            type_by_column,
+        )
+        if not rows:
+            return 0
+        quoted_columns = ", ".join(OracleNameSanitizer.quote(column) for column in insert_columns)
+        binds = ", ".join(f":{index + 1}" for index in range(len(insert_columns)))
+        statement = f"INSERT INTO {item['quoted_table']} ({quoted_columns}) VALUES ({binds})"
+        result = self.client.executemany(statement, rows)
+        if result.status_code != QueryStatus.SUCCESS:
+            raise RuntimeError(f"Failed loading {csv_path}: {result.error}")
+        return len(rows)
+
+    def _read_rows(
+        self,
+        csv_path: Path,
+        header_rows: int,
+        source_columns: List[str],
+        insert_columns: List[str],
+        type_by_column: Dict[str, str],
+    ) -> List[List[Any]]:
+        """Read and convert the insertable columns of every data row.
+
+        The first ``header_rows`` records are skipped, as are completely empty
+        lines (``csv.reader`` yields ``[]`` for them).
+
+        Raises:
+            ValueError: If a record has too few fields for the mapped columns.
+        """
+        positions = [source_columns.index(column) for column in insert_columns]
+        column_types = [type_by_column[column] for column in insert_columns]
+        required_fields = max(positions) + 1 if positions else 0
+        rows: List[List[Any]] = []
+        with open(csv_path, newline="", encoding="utf-8-sig") as file:
+            for index, row in enumerate(csv.reader(file)):
+                if index < header_rows or not row:
+                    continue
+                if len(row) < required_fields:
+                    # Fail with context instead of a bare IndexError.
+                    raise ValueError(
+                        f"{csv_path}: record {index + 1} has {len(row)} fields, "
+                        f"expected at least {required_fields}"
+                    )
+                rows.append(
+                    [
+                        self._convert_value(row[position], oracle_type)
+                        for position, oracle_type in zip(positions, column_types)
+                    ]
+                )
+        return rows
+
+    def _convert_value(self, value: str, oracle_type: str) -> Any:
+        """Convert one CSV field to the bind value for ``oracle_type``.
+
+        Empty strings become ``None``. ``VARCHAR2(n)`` values are truncated to
+        ``n`` UTF-8 bytes, ``true``/``false`` become 1/0 for ``NUMBER(1)``,
+        and ``DATE``/``TIMESTAMP`` values are parsed when possible. Anything
+        else is passed through as the original string.
+        """
+        if value == "":
+            return None
+        normalized_type = oracle_type.upper()
+        if normalized_type.startswith("VARCHAR2("):
+            max_length = self._varchar2_length(normalized_type)
+            if max_length:
+                return self._truncate_utf8(value, max_length)
+            return value
+        lowered = value.lower()
+        if normalized_type == "NUMBER(1)" and lowered in ("true", "false"):
+            return 1 if lowered == "true" else 0
+        if normalized_type.startswith("DATE"):
+            parsed = self._parse_datetime(value)
+            return parsed.date() if isinstance(parsed, datetime) else parsed
+        if normalized_type.startswith("TIMESTAMP"):
+            parsed = self._parse_datetime(value)
+            # A plain date is widened to midnight of that day.
+            if isinstance(parsed, date) and not isinstance(parsed, datetime):
+                return datetime.combine(parsed, datetime.min.time())
+            return parsed
+        return value
+
+    def _varchar2_length(self, oracle_type: str) -> int:
+        """Return ``n`` from ``VARCHAR2(n ...)``, or 0 when it is not a plain number."""
+        if not oracle_type.startswith("VARCHAR2("):
+            return 0
+        length = oracle_type.removeprefix("VARCHAR2(").split(")", 1)[0]
+        return int(length) if length.isdigit() else 0
+
+    def _truncate_utf8(self, value: str, max_bytes: int) -> str:
+        """Cut ``value`` to at most ``max_bytes`` UTF-8 bytes."""
+        encoded = value.encode("utf-8")
+        if len(encoded) <= max_bytes:
+            return value
+        # errors="ignore" drops a multi-byte character split by the cut.
+        return encoded[:max_bytes].decode("utf-8", errors="ignore")
+
+    def _parse_datetime(self, value: str) -> date | datetime | str:
+        """Parse ``value`` as ISO 8601 or one of the fallback formats.
+
+        Returns:
+            A ``datetime``, a ``date`` for date-only fallback formats, or the
+            stripped input string when nothing matches.
+        """
+        normalized = value.strip()
+        try:
+            return datetime.fromisoformat(normalized.replace("Z", "+00:00"))
+        except ValueError:
+            pass
+        for fmt, date_only in self._DATETIME_FORMATS:
+            try:
+                parsed = datetime.strptime(normalized, fmt)
+            except ValueError:
+                continue
+            return parsed.date() if date_only else parsed
+        return normalized
diff --git a/dataset_prep/preflight.py b/dataset_prep/preflight.py
new file mode 100644
index 0000000..d884ad7
--- /dev/null
+++ b/dataset_prep/preflight.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+import os
+import platform
+from typing import List
+
+
+@dataclass
+class PreflightResult:
+ """Outcome of ``run_preflight``."""
+
+ # True when ``errors`` is empty (run_preflight sets ``ok=not errors``).
+ ok: bool
+ # Messages for failed imports and missing environment variables.
+ errors: List[str]
+ # Informational messages, e.g. that the macOS expat workaround was applied.
+ warnings: List[str]
+
+
+# Directory that the macOS expat workaround prepends to DYLD_LIBRARY_PATH.
+EXPAT_DYLD_PATH = "/opt/homebrew/opt/expat/lib"
+
+
+def run_preflight(require_oracle_env: bool = True) -> PreflightResult:
+    """Check that the translation pipeline's dependencies are usable.
+
+    The check imports ``antlr4`` and ``oracledb``, imports the Cypher to
+    Oracle translator, and optionally verifies the Oracle connection
+    environment variables. Problems are collected rather than raised.
+
+    Args:
+        require_oracle_env: When true, ``ORACLE_DSN``, ``ORACLE_USER`` and
+            ``ORACLE_PASSWORD`` must be set to non-empty values.
+
+    Returns:
+        A ``PreflightResult`` whose ``ok`` is true only if no error was found.
+    """
+    # Function-scope import: importlib.import_module is the documented
+    # replacement for calling __import__ directly.
+    from importlib import import_module
+
+    errors: List[str] = []
+    warnings: List[str] = []
+    _apply_expat_workaround_if_present(warnings)
+
+    for module_name in ("antlr4", "oracledb"):
+        try:
+            import_module(module_name)
+        except Exception as exc:  # any import-time failure is reported, not raised
+            errors.append(f"Failed to import {module_name}: {exc}")
+
+    try:
+        from examples.cypher2oracle_sqlpgq import cypher2oracle_sqlpgq  # noqa: F401
+    except Exception as exc:
+        message = str(exc)
+        # An expat-related failure gets the extra how-to-fix hint.
+        if "pyexpat" in message or "libexpat" in message:
+            warnings.append(_expat_help_text())
+        errors.append(f"Failed to import Cypher to Oracle translator: {exc}")
+
+    if require_oracle_env:
+        for env_name in ("ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD"):
+            if not os.getenv(env_name):
+                errors.append(f"Missing required environment variable: {env_name}")
+
+    return PreflightResult(ok=not errors, errors=errors, warnings=warnings)
+
+
+def _apply_expat_workaround_if_present(warnings: List[str]) -> None:
+    """Prepend the Homebrew expat directory to ``DYLD_LIBRARY_PATH`` on macOS.
+
+    Nothing happens on other platforms, when the directory is already listed,
+    or when it does not exist. A note is appended to ``warnings`` when the
+    variable is changed.
+
+    NOTE(review): this only updates ``os.environ``; whether the dynamic loader
+    of the already-running interpreter honours the change is not shown here
+    and should be confirmed.
+    """
+    if platform.system() != "Darwin":
+        return
+    current = os.environ.get("DYLD_LIBRARY_PATH", "")
+    # Compare whole path entries: a substring test would treat an entry such
+    # as "/opt/homebrew/opt/expat/lib64" as if the directory were present.
+    listed_entries = {entry.rstrip("/") for entry in current.split(os.pathsep)}
+    if EXPAT_DYLD_PATH in listed_entries:
+        return
+    if os.path.isdir(EXPAT_DYLD_PATH):
+        os.environ["DYLD_LIBRARY_PATH"] = (
+            f"{EXPAT_DYLD_PATH}:{current}" if current else EXPAT_DYLD_PATH
+        )
+        warnings.append(
+            "Applied macOS expat workaround from oracle_sqlpgq_data_generation_workflow.md."
+        )
+
+
+def _expat_help_text() -> str:
+    """Return the hint shown when importing fails because of pyexpat/libexpat."""
+    fragments = [
+        "pyexpat/libexpat issue detected. Run with",
+        'DYLD_LIBRARY_PATH="/opt/homebrew/opt/expat/lib:${DYLD_LIBRARY_PATH:-}"',
+        "as documented in doc/en-us/development/oracle_sqlpgq_data_generation_workflow.md.",
+    ]
+    return " ".join(fragments)
diff --git a/dataset_prep/reporting.py b/dataset_prep/reporting.py
new file mode 100644
index 0000000..a206196
--- /dev/null
+++ b/dataset_prep/reporting.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Dict, Iterable, List
+
+
+def write_json(path: Path, data: Dict[str, Any]) -> None:
+    """Write ``data`` to ``path`` as indented UTF-8 JSON, creating parent directories."""
+    serialized = json.dumps(data, indent=2, ensure_ascii=False)
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(serialized, encoding="utf-8")
+
+
+def append_jsonl(path: Path, records: Iterable[Dict[str, Any]]) -> int:
+    """Append each record to ``path`` as one JSON line; return how many were written."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    written = 0
+    with open(path, "a", encoding="utf-8") as handle:
+        for written, record in enumerate(records, start=1):
+            line = json.dumps(record, ensure_ascii=False)
+            handle.write(line + "\n")
+    return written
+
+
+def summarize_records(records: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """Count records per translation category and per validation status.
+
+    Records missing either field are counted under ``"unknown"``. The result
+    always carries ``"complete": True``.
+    """
+
+    def tally(field: str) -> Dict[str, int]:
+        # Histogram of one record field, in first-seen order.
+        histogram: Dict[str, int] = {}
+        for record in records:
+            value = record.get(field, "unknown")
+            histogram[value] = histogram.get(value, 0) + 1
+        return histogram
+
+    return {
+        "total": len(records),
+        "translation_categories": tally("oracle_translation_category"),
+        "validation_statuses": tally("oracle_validation_status"),
+        "complete": True,
+    }
+
+
+def merge_global_summaries(summaries: Iterable[Dict[str, Any]]) -> Dict[str, Any]:
+    """Combine per-database summaries into totals across all databases."""
+
+    def absorb(target: Dict[str, int], source: Dict[str, Any]) -> None:
+        # Add the counts of ``source`` onto ``target`` key by key.
+        for name, amount in source.items():
+            target[name] = target.get(name, 0) + int(amount)
+
+    database_count = 0
+    record_total = 0
+    statuses: Dict[str, int] = {}
+    categories: Dict[str, int] = {}
+    for summary in summaries:
+        database_count += 1
+        record_total += int(summary.get("total", 0))
+        absorb(statuses, summary.get("validation_statuses", {}))
+        absorb(categories, summary.get("translation_categories", {}))
+    return {
+        "databases": database_count,
+        "total": record_total,
+        "translation_categories": categories,
+        "validation_statuses": statuses,
+    }
diff --git a/dataset_prep/translate_validate.py b/dataset_prep/translate_validate.py
new file mode 100644
index 0000000..13492ff
--- /dev/null
+++ b/dataset_prep/translate_validate.py
@@ -0,0 +1,545 @@
+from __future__ import annotations
+
+import argparse
+import json
+import os
+from pathlib import Path
+import re
+from typing import Any, Dict, List
+
+from app.core.validator.db_client import QueryStatus
+from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient
+from app.impl.oracle_sqlpgq.utils.sqlpgq import OracleNameSanitizer
+from dataset_prep.cypher_schema import CypherSchema
+from dataset_prep.discover import DatabaseUnit, discover_database_units, source_query
+from dataset_prep.oracle_loader import DatasetOracleLoader
+from dataset_prep.preflight import run_preflight
+from dataset_prep.reporting import merge_global_summaries, summarize_records, write_json
+from examples.cypher2oracle_sqlpgq import cypher2oracle_sqlpgq
+
+# Feature name -> case-insensitive pattern. detect_unsupported_features() reports
+# a feature when its pattern is found in the query after string literals have
+# been masked.
+UNSUPPORTED_PATTERNS = {
+ # Any SHORTEST keyword; the (ANY|ALL) alternative is subsumed by the bare one.
+ "shortest_path": re.compile(r"\b(ANY|ALL)\s+SHORTEST\b|\bSHORTEST\b", re.IGNORECASE),
+ # ANY/ALL CHEAPEST, TOTAL COST, or a COST( call.
+ "cost": re.compile(
+ r"\b(?:ANY|ALL)\s+CHEAPEST\b|\bTOTAL\s+COST\b|\bCOST\s*\(",
+ re.IGNORECASE,
+ ),
+ # CALL { ... } and EXISTS { ... } subquery blocks.
+ "inline_subquery": re.compile(r"\bCALL\s*\{|\bEXISTS\s*\{", re.IGNORECASE),
+ "lateral": re.compile(r"\bLATERAL\b", re.IGNORECASE),
+ "optional_match": re.compile(r"\bOPTIONAL\s+MATCH\b", re.IGNORECASE),
+ # date() plus or minus duration(...).
+ "relative_duration": re.compile(r"\bdate\s*\(\s*\)\s*[-+]\s*duration\s*\(", re.IGNORECASE),
+ "unwind": re.compile(r"\bUNWIND\b", re.IGNORECASE),
+ # Relationship quantifiers with a missing bound: `*N..`, `*..`, `*..M`, or bare `*`.
+ "open_ended_variable_length_path": re.compile(
+ r"-\s*\[[^\]]*\*\s*(?:(?:\d+\s*)?\.\.\s*|\.\.\s*\d+)\]\s*(?:->|-)|"
+ r"(?:<-|-)\s*\[[^\]]*\*\s*(?:(?:\d+\s*)?\.\.\s*|\.\.\s*\d+)\]\s*-|"
+ r"-\s*\[[^\]]*\*\s*\]\s*(?:->|-)|"
+ r"(?:<-|-)\s*\[[^\]]*\*\s*\]\s*-",
+ re.IGNORECASE,
+ ),
+ # A CASE followed, before any END, by an identifier and a colon.
+ "case_label_predicate": re.compile(
+ r"\bCASE\b(?:(?!\bEND\b).)*\b[A-Za-z_][A-Za-z0-9_]*\s*:",
+ re.IGNORECASE | re.DOTALL,
+ ),
+}
+
+
+def main() -> None:
+    """Translate and validate every selected database unit, then write the global summary.
+
+    Exits with status 2 when the preflight check fails. A failing unit is
+    reported and skipped unless ``--fail-fast`` is set. With ``--resume``,
+    units whose ``summary.json`` is marked complete are reused, not rerun.
+    """
+    args = parse_args()
+    dataset_root = Path(args.dataset_root)
+    output_root = Path(args.output_root)
+
+    preflight = run_preflight(require_oracle_env=not args.skip_live_validation)
+    for warning in preflight.warnings:
+        print(f"[preflight warning] {warning}")
+    if not preflight.ok:
+        for error in preflight.errors:
+            print(f"[preflight error] {error}")
+        raise SystemExit(2)
+
+    units = discover_database_units(dataset_root, args.splits)
+    if args.databases:
+        wanted = {name.lower() for name in args.databases}
+        units = [unit for unit in units if unit.database.lower() in wanted]
+    if args.limit_databases:
+        units = units[: args.limit_databases]
+
+    global_summaries: List[Dict[str, Any]] = []
+    unsupported_samples: List[Dict[str, Any]] = []
+    for unit in units:
+        unit_tag = f"{unit.split}/{unit.database}"
+        summary_path = output_root / unit.split / unit.database / "summary.json"
+        if args.resume and summary_path.exists():
+            cached = json.loads(summary_path.read_text(encoding="utf-8"))
+            if cached.get("complete"):
+                global_summaries.append(cached)
+                print(f"[skip] {unit_tag}")
+                continue
+        try:
+            print(f"[start] {unit_tag}", flush=True)
+            summary, samples = process_unit(unit, output_root, args)
+            global_summaries.append(summary)
+            unsupported_samples.extend(samples)
+            print(f"[done] {unit_tag}: {summary['validation_statuses']}")
+        except Exception as exc:
+            print(f"[error] {unit_tag}: {exc}")
+            if args.fail_fast:
+                raise
+
+    write_json(output_root / "global_summary.json", merge_global_summaries(global_summaries))
+    if unsupported_samples:
+        write_jsonl(output_root / "unsupported_samples.jsonl", unsupported_samples)
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse the command-line options of the translate/validate script.
+
+    The two Oracle validation defaults are read from the
+    ``ORACLE_VALIDATION_TIMEOUT_MS`` and ``ORACLE_VALIDATION_FETCH_LIMIT``
+    environment variables when they are set.
+    """
+    parser = argparse.ArgumentParser(
+        description="Translate Text2GQL dataset queries to Oracle SQL/PGQ."
+    )
+    add = parser.add_argument
+
+    for option, default in (
+        ("--dataset-root", "dataset"),
+        ("--output-root", "output/dataset_prep"),
+    ):
+        add(option, default=default)
+    add("--splits", nargs="+", default=["train", "dev", "test"])
+    add("--databases", nargs="*", default=[])
+    add("--graph-prefix", default="T2GQL")
+    for option in ("--limit-databases", "--limit-queries"):
+        add(option, type=int, default=0)
+    add(
+        "--query-offset",
+        type=int,
+        default=0,
+        help="Skip this many source records before applying --limit-queries.",
+    )
+    add(
+        "--progress-every",
+        type=int,
+        default=0,
+        help="Print query progress every N selected records. Use 1 for every record.",
+    )
+    for option in ("--keep-db-on-failure", "--skip-live-validation"):
+        add(option, action="store_true")
+
+    timeout_default = int(os.environ.get("ORACLE_VALIDATION_TIMEOUT_MS", "60000"))
+    add(
+        "--oracle-validation-timeout-ms",
+        type=int,
+        default=timeout_default,
+        help="Oracle timeout for each translated query validation call. Use 0 to disable.",
+    )
+    fetch_limit_default = int(os.environ.get("ORACLE_VALIDATION_FETCH_LIMIT", "1"))
+    add(
+        "--oracle-validation-fetch-limit",
+        type=int,
+        default=fetch_limit_default,
+        help="Rows fetched for each translated query validation call. Use 0 to fetch all.",
+    )
+    for option in ("--resume", "--fail-fast"):
+        add(option, action="store_true")
+    return parser.parse_args()
+
+
+def process_unit(
+    unit: DatabaseUnit,
+    output_root: Path,
+    args: argparse.Namespace,
+) -> tuple[Dict[str, Any], List[Dict[str, Any]]]:
+    """Translate (and optionally validate) the queries of one database unit.
+
+    Unless ``args.skip_live_validation`` is set, the unit's data is loaded
+    into Oracle first and each translated query is executed against it. The
+    enriched records go to ``oracle_sqlpgq_enriched.jsonl`` and the summary to
+    ``summary.json`` under ``output_root/<split>/<database>``.
+
+    The Oracle objects are dropped afterwards unless
+    ``args.keep_db_on_failure`` is set. As written, that flag keeps them
+    whether or not the unit succeeded.
+
+    Returns:
+        The unit summary and the list of unsupported-query samples.
+    """
+    graph_name = graph_name_for(unit, args.graph_prefix)
+    out_dir = output_root / unit.split / unit.database
+    out_dir.mkdir(parents=True, exist_ok=True)
+    output_path = out_dir / "oracle_sqlpgq_enriched.jsonl"
+    all_records = load_records(unit.query_path)
+    query_offset = max(args.query_offset, 0)
+    records = all_records[query_offset:]
+    if args.limit_queries:
+        records = records[: args.limit_queries]
+
+    client = None
+    loader = None
+    load_counts: Dict[str, int] = {}
+    node_label_map: Dict[str, List[str]] = {}
+    edge_label_map: Dict[str, List[str]] = {}
+    property_type_map: Dict[str, Dict[str, str]] = {}
+    node_primary_key_map: Dict[str, str] = {}
+    edge_primary_key_map: Dict[str, str] = {}
+    source_schema = CypherSchema.from_path(unit.import_config_path)
+
+    enriched: List[Dict[str, Any]] = []
+    unsupported_samples: List[Dict[str, Any]] = []
+    try:
+        if not args.skip_live_validation:
+            client = OracleDBClient(
+                {
+                    "dsn": os.environ["ORACLE_DSN"],
+                    "user": os.environ["ORACLE_USER"],
+                    "password": os.environ["ORACLE_PASSWORD"],
+                }
+            )
+            loader = DatasetOracleLoader(client, unit.import_config_path, unit.csv_root, graph_name)
+            load_counts = loader.setup()
+            node_label_map = loader.node_label_map()
+            edge_label_map = loader.edge_label_map()
+            property_type_map = loader.property_type_map()
+            node_primary_key_map = loader.node_primary_key_map()
+            edge_primary_key_map = loader.edge_primary_key_map()
+
+        for selected_index, record in enumerate(records):
+            # Index of the record in the full query file, before offset/limit.
+            source_index = query_offset + selected_index
+            if args.progress_every and selected_index % args.progress_every == 0:
+                print(
+                    "[query] "
+                    f"{unit.split}/{unit.database} "
+                    f"selected_index={selected_index} "
+                    f"record_index={source_index} "
+                    f"id={record.get('id')}",
+                    flush=True,
+                )
+            enriched_record = translate_record(
+                record,
+                graph_name,
+                client,
+                node_label_map=node_label_map,
+                edge_label_map=edge_label_map,
+                property_type_map=property_type_map,
+                node_primary_key_map=node_primary_key_map,
+                edge_primary_key_map=edge_primary_key_map,
+                source_schema=source_schema,
+                validation_timeout_ms=args.oracle_validation_timeout_ms,
+                validation_fetch_limit=args.oracle_validation_fetch_limit,
+            )
+            enriched_record["oracle_dataset_meta"] = {
+                "split": unit.split,
+                "database": unit.database,
+                "query_file": str(unit.query_path),
+                "import_config": str(unit.import_config_path),
+                "graph_name": graph_name,
+                "record_index": source_index,
+                "selected_index": selected_index,
+            }
+            enriched.append(enriched_record)
+            if enriched_record["oracle_validation_status"] == "unsupported":
+                unsupported_samples.append(
+                    {
+                        "split": unit.split,
+                        "database": unit.database,
+                        "id": record.get("id"),
+                        "unsupported_features": enriched_record["oracle_unsupported_features"],
+                        "query": enriched_record.get("oracle_source_query"),
+                    }
+                )
+        write_jsonl(output_path, enriched)
+        summary = summarize_records(enriched)
+        summary.update(
+            {
+                "split": unit.split,
+                "database": unit.database,
+                "graph_name": graph_name,
+                "query_file": str(unit.query_path),
+                "import_config": str(unit.import_config_path),
+                "query_offset": query_offset,
+                "load_counts": load_counts,
+            }
+        )
+        write_json(out_dir / "summary.json", summary)
+        return summary, unsupported_samples
+    finally:
+        # One cleanup pass covers both success and failure. The former extra
+        # pass in an ``except`` block only repeated the same DROP statements.
+        try:
+            if loader is not None and not args.keep_db_on_failure:
+                loader.cleanup(ignore_errors=True)
+        finally:
+            # Close the connection even if cleanup itself raises.
+            if client is not None:
+                client.close()
+
+
+def load_records(query_path: Path) -> List[Dict[str, Any]]:
+    """Load query records from a JSON file holding either a list or a keyed object.
+
+    For a keyed object, each key becomes the record's ``id`` (replacing any
+    ``id`` already present); non-dict values are wrapped as
+    ``{"id": key, "value": value}``.
+    """
+    payload = json.loads(query_path.read_text(encoding="utf-8"))
+    if not isinstance(payload, dict):
+        return list(payload)
+    records: List[Dict[str, Any]] = []
+    for key, value in payload.items():
+        if isinstance(value, dict):
+            records.append(dict(value, id=key))
+        else:
+            records.append({"id": key, "value": value})
+    return records
+
+
+def translate_record(
+    record: Dict[str, Any],
+    graph_name: str,
+    client: OracleDBClient | None,
+    node_label_map: Dict[str, List[str]] | None = None,
+    edge_label_map: Dict[str, List[str]] | None = None,
+    property_type_map: Dict[str, Dict[str, str]] | None = None,
+    node_primary_key_map: Dict[str, str] | None = None,
+    edge_primary_key_map: Dict[str, str] | None = None,
+    source_schema: CypherSchema | None = None,
+    validation_timeout_ms: int = 0,
+    validation_fetch_limit: int = 0,
+) -> Dict[str, Any]:
+    """Return a copy of ``record`` enriched with its Oracle SQL/PGQ translation.
+
+    The copy always gains the source query, its field name, the detected
+    unsupported features, and the four status fields built by ``_status``.
+    Records without a query, with unsupported features, or rejected by the
+    translator are marked ``"unsupported"``. With a ``client`` the translated
+    query is executed and the status reflects the outcome; without one it is
+    ``"syntax_ok"``.
+    """
+    output = dict(record)
+    query_field, query = source_query(record)
+    output["oracle_source_query_field"] = query_field
+    output["oracle_source_query"] = query
+    unsupported_features = detect_unsupported_features(query, source_schema=source_schema)
+    output["oracle_unsupported_features"] = unsupported_features
+
+    if not query:
+        rejection = _status(None, "missing_source_query", "unsupported", "No source query found.")
+        output.update(rejection)
+        return output
+
+    if unsupported_features:
+        rejection = _status(
+            None,
+            "Graph-IL Not Support",
+            "unsupported",
+            "Query uses constructs intentionally not emitted for Oracle SQL/PGQ.",
+        )
+        output.update(rejection)
+        return output
+
+    translated, category = cypher2oracle_sqlpgq(
+        query,
+        graph_name=graph_name,
+        node_label_map=node_label_map,
+        edge_label_map=edge_label_map,
+        property_type_map=property_type_map,
+        node_primary_key_map=node_primary_key_map,
+        edge_primary_key_map=edge_primary_key_map,
+        strict_property_validation=bool(property_type_map),
+    )
+    if category != "Graph-IL Translatable":
+        # The translator's output is stored as the error text in this case.
+        output.update(_status(None, category, "unsupported", translated))
+        return output
+
+    if client is None:
+        output.update(_status(translated, category, "syntax_ok", ""))
+        return output
+
+    result = client.execute_query(
+        translated,
+        fetch_limit=validation_fetch_limit,
+        call_timeout_ms=validation_timeout_ms,
+    )
+    if result.status_code == QueryStatus.SUCCESS:
+        validation_status, error = "success", ""
+    elif result.status_code == QueryStatus.NO_RECORD:
+        validation_status, error = "no_record", ""
+    elif result.status_code == QueryStatus.CLIENT_ERROR:
+        validation_status, error = "syntax_error", result.error or ""
+    else:
+        validation_status, error = "runtime_error", result.error or ""
+    output.update(_status(translated, category, validation_status, error))
+    return output
+
+
+def detect_unsupported_features(
+    query: str,
+    source_schema: CypherSchema | None = None,
+) -> List[str]:
+    """List the features that keep ``query`` from being translated.
+
+    Pattern-based checks run on the query with string literals masked. Schema
+    validation issues, when a schema is given, are added by signature.
+    ``optional_match`` is dropped again when any of the supported OPTIONAL
+    MATCH shapes applies; those shape checks receive the unmasked query.
+    Duplicates are removed while keeping first-seen order.
+
+    NOTE(review): the ``\\bWITH\\b`` count also matches the ``WITH`` of
+    ``STARTS WITH`` / ``ENDS WITH`` - confirm that this is intended.
+    """
+    masked_query = mask_string_literals(query or "")
+    features: List[str] = []
+    if query:
+        features.extend(
+            name for name, pattern in UNSUPPORTED_PATTERNS.items() if pattern.search(masked_query)
+        )
+        if has_quantified_relationship_property_map(masked_query):
+            features.append("quantified_relationship_property_map")
+        if has_expensive_undirected_variable_length_path(masked_query):
+            features.append("expensive_variable_length_path")
+        if len(re.findall(r"\bWITH\b", masked_query, flags=re.IGNORECASE)) > 1:
+            features.append("multiple_with")
+    if source_schema is not None:
+        features.extend(issue.signature for issue in source_schema.validation_issues(query))
+
+    supported_optional_shapes = (
+        is_supported_correlated_optional_match,
+        is_supported_standalone_optional_match,
+        is_supported_optional_after_with_match,
+        is_supported_match_optional_with,
+        is_supported_optional_null_antijoin,
+    )
+    for is_supported_shape in supported_optional_shapes:
+        if "optional_match" in features and is_supported_shape(query):
+            features.remove("optional_match")
+    return list(dict.fromkeys(features))
+
+
+def has_quantified_relationship_property_map(query: str) -> bool:
+    """Report whether one ``[...]`` pattern holds both a ``*`` quantifier and a ``{...}`` map.
+
+    Either order inside the brackets is recognised.
+    """
+    quantifier_then_map = r"\[[^\]]*\*\s*(?:\d+\s*)?(?:\.\.\s*\d*)?[^\]]*\{[^}]+\}[^\]]*\]"
+    map_then_quantifier = r"\[[^\]]*\{[^}]+\}[^\]]*\*\s*(?:\d+\s*)?(?:\.\.\s*\d*)?[^\]]*\]"
+    combined = f"{quantifier_then_map}|{map_then_quantifier}"
+    return re.search(combined, query, flags=re.IGNORECASE) is not None
+
+
+def has_expensive_undirected_variable_length_path(query: str) -> bool:
+    """Report whether an undirected ``-[...*N..M]-`` path has an upper bound above 3.
+
+    NOTE(review): the pattern in this file had lost its named groups and its
+    leading part, leaving an invalid regex. The prefix ``(?<!<)-\\s*\\[[^\\]]*\\*``
+    is reconstructed by analogy with the other relationship patterns in this
+    module - confirm against the upstream source.
+    """
+    for match in re.finditer(
+        # No "<" before the first dash and no ">" after the last one, so only
+        # undirected relationships are matched.
+        r"(?<!<)-\s*\[[^\]]*\*\s*(?P<lower>\d+)\s*\.\.\s*(?P<upper>\d+)\s*\]\s*-(?!>)",
+        query,
+        flags=re.IGNORECASE,
+    ):
+        upper = int(match.group("upper"))
+        if upper > 3:
+            return True
+    return False
+
+
+def mask_string_literals(query: str) -> str:
+    """Blank out the contents of quoted literals while keeping the query length.
+
+    Quote characters and whitespace inside literals are preserved; every other
+    character inside a literal becomes a space. Backslash escapes and doubled
+    single quotes (``''``) are treated as part of the literal. An unterminated
+    literal runs to the end of the query.
+    """
+    masked: List[str] = []
+    length = len(query)
+    position = 0
+    while position < length:
+        current = query[position]
+        masked.append(current)
+        position += 1
+        if current not in ("'", '"'):
+            continue
+
+        # ``current`` opened a literal; consume it up to the closing quote.
+        delimiter = current
+        while position < length:
+            inner = query[position]
+            if inner == "\\" and position + 1 < length:
+                # Escaped character: mask the backslash and what follows.
+                masked.extend("  ")
+                position += 2
+            elif inner == delimiter:
+                is_doubled_quote = delimiter == "'" and query[position + 1 : position + 2] == "'"
+                if is_doubled_quote:
+                    masked.extend("  ")
+                    position += 2
+                else:
+                    masked.append(delimiter)
+                    position += 1
+                    break
+            else:
+                masked.append(inner if inner.isspace() else " ")
+                position += 1
+    return "".join(masked)
+
+
+def is_supported_correlated_optional_match(query: str) -> bool:
+    """Report whether a single OPTIONAL MATCH reuses a variable declared before it.
+
+    The query must not start with OPTIONAL MATCH, must contain exactly one
+    OPTIONAL MATCH and at most one WITH. The optional clause is taken to end
+    at the next WITH or RETURN.
+    """
+    text = " ".join(str(query or "").split())
+    if re.match(r"^OPTIONAL\s+MATCH\b", text, flags=re.IGNORECASE) is not None:
+        return False
+    optional_hits = list(re.finditer(r"\bOPTIONAL\s+MATCH\b", text, flags=re.IGNORECASE))
+    if len(optional_hits) != 1:
+        return False
+    if len(re.findall(r"\bWITH\b", text, flags=re.IGNORECASE)) > 1:
+        return False
+
+    hit = optional_hits[0]
+    before_optional = text[: hit.start()]
+    after_optional = text[hit.end() :]
+    clause_end = re.search(r"\b(?:WITH|RETURN)\b", after_optional, flags=re.IGNORECASE)
+    optional_clause = after_optional if clause_end is None else after_optional[: clause_end.start()]
+
+    declared_before = set(_declared_cypher_variables(before_optional))
+    return not declared_before.isdisjoint(_declared_cypher_variables(optional_clause))
+
+
+def _declared_cypher_variables(fragment: str) -> List[str]:
+    """Return the node and relationship variables declared in ``fragment``.
+
+    A variable is an identifier directly after ``(`` or ``[`` that is followed
+    by ``:``, ``{`` or the closing bracket. Node variables come first, then
+    relationship variables, each in order of appearance and without duplicates.
+    """
+    variables: List[str] = []
+    for pattern in (
+        # The "var" group name is required by match.group("var") below; the
+        # patterns in this file had lost it, which made them invalid regexes.
+        r"\(\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)\s*(?::|[){])",
+        r"\[\s*(?P<var>[A-Za-z_][A-Za-z0-9_]*)\s*(?::|[\]{])",
+    ):
+        for match in re.finditer(pattern, fragment or ""):
+            variable = match.group("var")
+            if variable not in variables:
+                variables.append(variable)
+    return variables
+
+
+def is_supported_standalone_optional_match(query: str) -> bool:
+    """Report whether the query is one leading OPTIONAL MATCH without WITH or ``count(*)``."""
+    text = " ".join(str(query or "").split())
+    flags = re.IGNORECASE
+    starts_with_optional = re.match(r"^OPTIONAL\s+MATCH\b", text, flags=flags) is not None
+    optional_count = len(re.findall(r"\bOPTIONAL\s+MATCH\b", text, flags=flags))
+    has_with = re.search(r"\bWITH\b", text, flags=flags) is not None
+    returns_count_star = (
+        re.search(r"\bRETURN\s+count\s*\(\s*\*\s*\)", text, flags=flags) is not None
+    )
+    return (
+        starts_with_optional
+        and optional_count == 1
+        and not has_with
+        and not returns_count_star
+    )
+
+
+def is_supported_optional_after_with_match(query: str) -> bool:
+    """Report whether the query has the shape MATCH .. WITH .. OPTIONAL MATCH .. RETURN.
+
+    Exactly one OPTIONAL MATCH and one WITH are required, the query must not
+    start with OPTIONAL MATCH, and it must not return ``count(*)``.
+    """
+    text = " ".join(str(query or "").split())
+    flags = re.IGNORECASE
+    disqualified = (
+        re.match(r"^OPTIONAL\s+MATCH\b", text, flags=flags) is not None
+        or len(re.findall(r"\bOPTIONAL\s+MATCH\b", text, flags=flags)) != 1
+        or len(re.findall(r"\bWITH\b", text, flags=flags)) != 1
+        or re.search(r"\bRETURN\s+count\s*\(\s*\*\s*\)", text, flags=flags) is not None
+    )
+    if disqualified:
+        return False
+    shape = r"\bMATCH\b.+\bWITH\b.+\bOPTIONAL\s+MATCH\b.+\bRETURN\b"
+    return re.search(shape, text, flags=flags) is not None
+
+
+def is_supported_match_optional_with(query: str) -> bool:
+    """Report whether the query has the shape MATCH .. OPTIONAL MATCH .. WITH .. RETURN.
+
+    Exactly one OPTIONAL MATCH and one WITH are required, and the query must
+    begin with a plain MATCH.
+    """
+    text = " ".join(str(query or "").split())
+    flags = re.IGNORECASE
+    if re.match(r"^OPTIONAL\s+MATCH\b", text, flags=flags) is not None:
+        return False
+    optional_count = len(re.findall(r"\bOPTIONAL\s+MATCH\b", text, flags=flags))
+    with_count = len(re.findall(r"\bWITH\b", text, flags=flags))
+    if optional_count != 1 or with_count != 1:
+        return False
+    shape = r"^\s*MATCH\b.+\bOPTIONAL\s+MATCH\b.+\bWITH\b.+\bRETURN\b"
+    return re.search(shape, text, flags=flags) is not None
+
+
+def is_supported_optional_null_antijoin(query: str) -> bool:
+    """Report whether the query is MATCH .. OPTIONAL MATCH .. WHERE <var> IS NULL .. RETURN.
+
+    Exactly one OPTIONAL MATCH and no WITH are required, and the query must
+    begin with a plain MATCH.
+    """
+    text = " ".join(str(query or "").split())
+    flags = re.IGNORECASE
+    if re.match(r"^OPTIONAL\s+MATCH\b", text, flags=flags) is not None:
+        return False
+    if len(re.findall(r"\bOPTIONAL\s+MATCH\b", text, flags=flags)) != 1:
+        return False
+    if re.search(r"\bWITH\b", text, flags=flags) is not None:
+        return False
+    shape = (
+        r"^\s*MATCH\b.+\bOPTIONAL\s+MATCH\b.+\bWHERE\s+"
+        r"[A-Za-z_][A-Za-z0-9_]*\s+IS\s+NULL\b.+\bRETURN\b"
+    )
+    return re.search(shape, text, flags=flags) is not None
+
+
+def _status(
+    sqlpgq: str | None,
+    category: str,
+    validation_status: str,
+    error: str,
+) -> Dict[str, Any]:
+    """Bundle the four ``oracle_*`` result fields that are merged into an output record."""
+    return dict(
+        oracle_sqlpgq=sqlpgq,
+        oracle_translation_category=category,
+        oracle_validation_status=validation_status,
+        oracle_validation_error=error,
+    )
+
+
+def graph_name_for(unit: DatabaseUnit, prefix: str) -> str:
+    """Build the sanitised graph name ``<prefix>_<split>_<database>`` for ``unit``."""
+    raw_name = f"{prefix}_{unit.split}_{unit.database}"
+    return OracleNameSanitizer.clean(raw_name, fallback="GRAPH")
+
+
+def write_jsonl(path: Path, records: List[Dict[str, Any]]) -> None:
+    """Overwrite ``path`` with one JSON line per record, creating parent directories."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "w", encoding="utf-8") as handle:
+        handle.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in records)
+
+
+# Call main() only when this module is executed as a script, not on import.
+if __name__ == "__main__":
+ main()
diff --git a/doc/en-us/development/oracle_sqlpgq_data_generation_workflow.md b/doc/en-us/development/oracle_sqlpgq_data_generation_workflow.md
new file mode 100644
index 0000000..b8fa6d6
--- /dev/null
+++ b/doc/en-us/development/oracle_sqlpgq_data_generation_workflow.md
@@ -0,0 +1,554 @@
+# Oracle SQL/PGQ Data Generation Workflow
+
+This guide explains the complete Oracle SQL/PGQ workflow currently implemented in the repository:
+
+1. Convert a framework/TuGraph-style schema into Oracle SQL/PGQ artifacts.
+2. Create and populate local Oracle property graphs.
+3. Generate deterministic Oracle SQL/PGQ query/question pairs from the Oracle manifest.
+4. Translate Cypher query templates into Oracle SQL/PGQ where the current Graph-IL subset supports them, and generalize seed Cypher path shapes into further Oracle queries over the Oracle manifest.
+5. Generate LLM-based Oracle SQL/PGQ corpus pairs and validate/repair them against Oracle.
+6. Combine generated corpora into a normalized dataset with metadata and train/dev/test splits.
+
+Oracle property graphs are represented by relational table DDL, `CREATE PROPERTY GRAPH` DDL, JSON manifests, and executable `GRAPH_TABLE` queries.
+
+Run commands from the repository root.
+
+## Prerequisites
+
+Install the Python environment:
+
+```bash
+poetry env use python3.10
+poetry install
+```
+
+If `poetry` is not on `PATH`, call it by its absolute path (for example, the Homebrew location on Apple Silicon macOS):
+
+```bash
+/opt/homebrew/bin/poetry --version
+```
+
+If Python hits a local `pyexpat` or `libexpat` issue on macOS, export:
+
+```bash
+export DYLD_LIBRARY_PATH="/opt/homebrew/opt/expat/lib:${DYLD_LIBRARY_PATH:-}"
+```
+
+## Environment
+
+Set Oracle connection variables:
+
+```bash
+export ORACLE_DSN="localhost:1521/FREEPDB1"
+export ORACLE_USER="SYSTEM"
+export ORACLE_PASSWORD="tiger"
+```
+
+Set OpenAI-compatible LLM variables. For OCI/LiteLLM-style endpoints, use:
+
+```bash
+export OPENAI_API_KEY=""
+export OPENAI_BASE_URL=""
+export OPENAI_EXTRA_HEADERS='{"client":"codex-cli","client-version":"0"}'
+export LLM_PLATFORM="openai"
+export LLM_MODEL=""
+```
+
+For local Ollama/OpenAI-compatible testing, use:
+
+```bash
+export NO_PROXY="localhost,127.0.0.1,::1,${NO_PROXY:-}"
+export no_proxy="localhost,127.0.0.1,::1,${no_proxy:-}"
+export OPENAI_BASE_URL="http://127.0.0.1:11434/v1"
+export OPENAI_API_KEY="ollama"
+export LLM_PLATFORM="openai"
+export LLM_MODEL="gemma4:latest"
+```
+
+Check local services:
+
+```bash
+nc -vz 127.0.0.1 1521
+poetry run pytest test/test_oracle_sqlpgq_live.py
+```
+
+The live test runs `SELECT 1 AS VALUE FROM dual` through the Oracle DB client.
+
+## Core Files
+
+Oracle implementation:
+
+```text
+app/impl/oracle_sqlpgq/schema/schema_parser.py
+app/impl/oracle_sqlpgq/translator/oracle_sqlpgq_query_translator.py
+app/impl/oracle_sqlpgq/ast_visitor/oracle_sqlpgq_ast_visitor.py
+app/impl/oracle_sqlpgq/db_client/oracle_db_client.py
+app/impl/oracle_sqlpgq/generator/template_instantiator.py
+app/impl/oracle_sqlpgq/generator/query_generalizer.py
+app/impl/oracle_sqlpgq/generator/corpus_combiner.py
+```
+
+Examples:
+
+```text
+examples/tugraph_to_oracle_sqlpgq.py
+examples/setup_oracle_sqlpgq_example_db.py
+examples/setup_oracle_sqlpgq_fraud_db.py
+examples/generate_oracle_sqlpgq_template_corpus.py
+examples/cypher2oracle_sqlpgq.py
+examples/generalize_oracle_sqlpgq_from_cypher.py
+examples/generate_corpus_oracle_sqlpgq.py
+examples/generate_corpus_oracle_sqlpgq_fraud.py
+examples/combine_oracle_sqlpgq_corpus.py
+```
+
+Generated outputs:
+
+```text
+examples/Oracle_SQLPGQ_Instance/
+examples/generated_corpus/
+```
+
+## Step 1: Generate Oracle SQL/PGQ Artifacts
+
+Convert a framework/TuGraph-style schema JSON into Oracle SQL/PGQ artifacts:
+
+```bash
+poetry run python examples/tugraph_to_oracle_sqlpgq.py
+```
+
+This reads:
+
+```text
+examples/generated_schemas/example_schema.json
+```
+
+And writes:
+
+```text
+examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_schema.json
+examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_tables.sql
+examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_property_graph.sql
+examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_loader.py
+```
+
+How it works:
+
+- `OracleSqlPgqSchemaParser` converts `SchemaGraph` nodes into relational vertex tables.
+- It converts graph edges into relational edge tables with `SRC_ID`, `DST_ID`, primary keys, and foreign keys.
+- It emits `CREATE OR REPLACE PROPERTY GRAPH ... VERTEX TABLES ... EDGE TABLES`.
+- It writes a JSON manifest used by deterministic generation and validation workflows.
+
+Why it matters:
+
+- Oracle SQL/PGQ property graphs are metadata over relational tables.
+- The manifest becomes the source of truth for schema-aware corpus generation.
+
+## Step 2: Create and Populate the Movie Graph
+
+Create the default local movie-style Oracle graph:
+
+```bash
+poetry run python examples/setup_oracle_sqlpgq_example_db.py
+```
+
+This script:
+
+1. Drops `TEXT2GQL_GRAPH` if it exists.
+2. Drops/recreates the example graph tables.
+3. Inserts deterministic generated rows.
+4. Creates the Oracle SQL property graph.
+
+The script is destructive: it drops and recreates the example property graph and its tables. It affects only those example graph objects in the configured Oracle schema.
+
+Use a larger dataset when generating bigger corpora:
+
+```bash
+poetry run python examples/setup_oracle_sqlpgq_example_db.py \
+ --users 100 \
+ --movies 500 \
+ --genres 20 \
+ --tags 200 \
+ --ratings 2000 \
+ --genre-edges 800 \
+ --tag-edges 1000 \
+ --friend-edges 500 \
+ --similarity-edges 1000 \
+ --seed 42 \
+ --batch-size 1000
+```
+
+Equivalent environment variable form:
+
+```bash
+export ORACLE_SQLPGQ_USERS=100
+export ORACLE_SQLPGQ_MOVIES=500
+export ORACLE_SQLPGQ_GENRES=20
+export ORACLE_SQLPGQ_TAGS=200
+export ORACLE_SQLPGQ_RATINGS=2000
+export ORACLE_SQLPGQ_GENRE_EDGES=800
+export ORACLE_SQLPGQ_TAG_EDGES=1000
+export ORACLE_SQLPGQ_FRIEND_EDGES=500
+export ORACLE_SQLPGQ_SIMILARITY_EDGES=1000
+export ORACLE_SQLPGQ_DATA_SEED=42
+export ORACLE_SQLPGQ_INSERT_BATCH_SIZE=1000
+
+poetry run python examples/setup_oracle_sqlpgq_example_db.py
+```
+
+Minimum sizes are enforced because the seed validation queries depend on known values:
+
+```text
+users >= 4
+movies >= 4
+genres >= 4
+tags >= 6
+ratings >= 4
+genre_edges >= 5
+tag_edges >= 6
+friend_edges >= 3
+similarity_edges >= 3
+```
+
+Sanity check the graph:
+
+```bash
+poetry run python -c 'from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient; import os; c=OracleDBClient({"dsn":os.environ["ORACLE_DSN"],"user":os.environ["ORACLE_USER"],"password":os.environ["ORACLE_PASSWORD"]}); q="SELECT * FROM GRAPH_TABLE (\"TEXT2GQL_GRAPH\" MATCH (m IS \"MOVIE\") COLUMNS (m.\"title\" AS title)) gt FETCH FIRST 3 ROWS ONLY"; r=c.execute_query(q); print(r.status_code); print(r.data or r.error); c.close()'
+```
+
+Expected status:
+
+```text
+200
+```
+
+## Step 3: Generate Deterministic Template Corpus
+
+Generate schema-derived Oracle SQL/PGQ pairs without using an LLM:
+
+```bash
+poetry run python examples/generate_oracle_sqlpgq_template_corpus.py \
+ --manifest examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_schema.json \
+ --target-size 50 \
+ --output examples/generated_corpus/oracle_sqlpgq_template_corpus.json
+```
+
+How it works:
+
+- `OracleSqlPgqTemplateInstantiator` reads the Oracle manifest.
+- It creates query/question pairs for:
+ - one-hop traversals
+ - graph element identifiers with `VERTEX_ID` and `EDGE_ID`
+ - aggregate counts
+ - two-hop traversals
+ - bounded path queries for recursive/self edges
+- It avoids literal predicates by default, so generated queries are schema-driven and portable across instances of the same graph.
+
+Why it matters:
+
+- This path is deterministic and reproducible.
+- It gives high-validity Oracle SQL/PGQ coverage without depending on an LLM.
+- It is useful as seed/context data for later LLM generation.
+
+Live-validate generated template queries:
+
+```bash
+poetry run python -c 'import json, os; from pathlib import Path; from app.core.validator.db_client import QueryStatus; from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient; path=Path("examples/generated_corpus/oracle_sqlpgq_template_corpus.json"); data=json.loads(path.read_text()); c=OracleDBClient({"dsn":os.environ["ORACLE_DSN"],"user":os.environ["ORACLE_USER"],"password":os.environ["ORACLE_PASSWORD"]}); ok=0; bad=[];
+for i,item in enumerate(data):
+ r=c.execute_query(item["query"])
+ ok += r.status_code == QueryStatus.SUCCESS
+ bad.append((i, r.error)) if r.status_code != QueryStatus.SUCCESS else None
+print(f"{ok}/{len(data)} passed")
+print(bad[:5])
+c.close()'
+```
+
+## Step 4: Translate Cypher Queries to Oracle SQL/PGQ
+
+Translate one supported Cypher query:
+
+```bash
+poetry run python examples/cypher2oracle_sqlpgq.py \
+ --query "MATCH (p:PERSON)-[a:ACTED_IN]->(m:MOVIE) RETURN m.title AS movie_title" \
+ --graph-name TEXT2GQL_GRAPH
+```
+
+Translate the Cypher template file:
+
+```bash
+poetry run python examples/cypher2oracle_sqlpgq.py \
+ --input examples/corpus_templates/corpus_templates.json \
+ --output examples/generated_corpus/cypher_templates_to_oracle_sqlpgq.json \
+ --graph-name TEXT2GQL_GRAPH
+```
+
+How it works:
+
+```text
+Cypher query -> Cypher grammar check -> Cypher AST visitor -> Graph-IL -> Oracle SQL/PGQ translator
+```
+
+The output category explains what happened:
+
+- `Graph-IL Translatable`: translated successfully.
+- `Not Comply with OpenCypher`: the source query failed the current Cypher grammar check.
+- `Graph-IL Not Support`: the query parsed as Cypher but is outside the current visitor/IR subset.
+- `No Related Oracle SQL/PGQ Standard`: translation completed but failed the current SQL/PGQ structural validator.
+
+Current supported subset:
+
+- basic `MATCH` paths
+- labels and relationship labels
+- simple property comparisons
+- `RETURN` items and aliases
+- aggregates supported by current Graph-IL extraction
+- `ORDER BY`, `SKIP`, `LIMIT`
+- `DISTINCT`
+- relationship hop ranges
+
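+Unsupported or only partially supported constructs include `shortestPath`, complex boolean logic, `IN`, `STARTS WITH`, `CONTAINS`, and some path-return forms. `OPTIONAL MATCH` and unions are handled only for a few specific query shapes that `examples/cypher2oracle_sqlpgq.py` rewrites before translation (for example a single leading `OPTIONAL MATCH`, a `MATCH ... OPTIONAL MATCH ... WHERE x IS NULL RETURN ...` anti-join, or `MATCH ... OPTIONAL MATCH ... WITH ... RETURN`); other uses of these constructs fall outside the supported subset.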
+
+Why it matters:
+
+- Existing Cypher corpora can be triaged into Oracle-translatable and unsupported categories.
+- The translated subset can be validated against Oracle and included in Oracle SQL/PGQ datasets.
+
+## Step 5: Generalize Oracle SQL/PGQ Queries From a Seed Shape
+
+Generate Oracle queries by taking a seed Cypher path shape and matching it over the Oracle manifest:
+
+```bash
+poetry run python examples/generalize_oracle_sqlpgq_from_cypher.py \
+ --query "MATCH (a)-[e]->(b) RETURN b" \
+ --manifest examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_schema.json \
+ --target-size 25 \
+ --output examples/generated_corpus/oracle_sqlpgq_generalized_queries.json
+```
+
+How it works:
+
+- The seed Cypher query is parsed into Graph-IL.
+- `OracleSqlPgqQueryGeneralizer` reads the Oracle manifest.
+- It finds schema paths with the same path length/hop shape.
+- It emits Oracle `GRAPH_TABLE` queries through `OracleSqlPgqQueryTranslator`.
+
+Why it matters:
+
+- A small number of seed query shapes can produce many Oracle-specific query variants.
+- This increases query diversity while staying schema-aware.
+
+Live-validate generated generalized queries:
+
+```bash
+poetry run python -c 'import json, os; from pathlib import Path; from app.core.validator.db_client import QueryStatus; from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient; path=Path("examples/generated_corpus/oracle_sqlpgq_generalized_queries.json"); data=json.loads(path.read_text()); c=OracleDBClient({"dsn":os.environ["ORACLE_DSN"],"user":os.environ["ORACLE_USER"],"password":os.environ["ORACLE_PASSWORD"]}); ok=0; bad=[];
+for i,item in enumerate(data):
+ r=c.execute_query(item["query"])
+ ok += r.status_code == QueryStatus.SUCCESS
+ bad.append((i, r.error)) if r.status_code != QueryStatus.SUCCESS else None
+print(f"{ok}/{len(data)} passed")
+print(bad[:5])
+c.close()'
+```
+
+## Step 6: Generate LLM Corpus and Validate/Repair
+
+Run a small test generation:
+
+```bash
+ORACLE_SQLPGQ_CORPUS_SIZE=3 \
+ORACLE_SQLPGQ_NUM_PER_ITERATION=1 \
+ORACLE_SQLPGQ_REPAIR_ATTEMPTS=1 \
+poetry run python examples/generate_corpus_oracle_sqlpgq.py
+```
+
+Run a larger generation:
+
+```bash
+ORACLE_SQLPGQ_CORPUS_SIZE=50 \
+ORACLE_SQLPGQ_NUM_PER_ITERATION=5 \
+ORACLE_SQLPGQ_REPAIR_ATTEMPTS=2 \
+poetry run python examples/generate_corpus_oracle_sqlpgq.py
+```
+
+How it works:
+
+1. Hardcoded seed queries in `examples/generate_corpus_oracle_sqlpgq.py` are executed against Oracle.
+2. Successful seed pairs are saved with result summaries.
+3. Seed pairs and result summaries are passed to the LLM as examples.
+4. The LLM generates new natural-language questions and Oracle SQL/PGQ queries.
+5. Raw candidates are saved before validation.
+6. Each generated query is executed against Oracle.
+7. Failed queries are sent back to the LLM with the Oracle error for repair.
+8. Valid and rejected pairs are saved separately.
+
+Outputs:
+
+```text
+examples/generated_corpus/oracle_sqlpgq_example_corpus_raw.json
+examples/generated_corpus/oracle_sqlpgq_example_corpus.json
+examples/generated_corpus/oracle_sqlpgq_example_corpus_validated_with_results.json
+examples/generated_corpus/oracle_sqlpgq_example_corpus_seed_validated_with_results.json
+examples/generated_corpus/oracle_sqlpgq_example_corpus_rejected.json
+```
+
+Why it matters:
+
+- The model sees real Oracle SQL/PGQ examples and real result summaries.
+- Live validation prevents invalid SQL/PGQ from entering the final corpus.
+- Repair loops improve final yield.
+
+## Step 7: Combine Corpus Files
+
+Combine deterministic, generalized, and LLM-generated corpora:
+
+```bash
+poetry run python examples/combine_oracle_sqlpgq_corpus.py \
+ --input \
+ examples/generated_corpus/oracle_sqlpgq_template_corpus.json \
+ examples/generated_corpus/oracle_sqlpgq_generalized_queries.json \
+ examples/generated_corpus/oracle_sqlpgq_example_corpus.json \
+ --output examples/generated_corpus/oracle_sqlpgq_combined_corpus.json \
+ --split
+```
+
+Combine and live-validate every final query against Oracle:
+
+```bash
+poetry run python examples/combine_oracle_sqlpgq_corpus.py \
+ --input \
+ examples/generated_corpus/oracle_sqlpgq_template_corpus.json \
+ examples/generated_corpus/oracle_sqlpgq_generalized_queries.json \
+ examples/generated_corpus/oracle_sqlpgq_example_corpus.json \
+ --output examples/generated_corpus/oracle_sqlpgq_combined_corpus_validated.json \
+ --split \
+ --validate-live
+```
+
+How it works:
+
+- `OracleSqlPgqCorpusCombiner` reads multiple JSON files.
+- It normalizes fields into a common format.
+- It deduplicates repeated question/query pairs.
+- It preserves source file, source index, category, labels, result summaries, and validation metadata.
+- With `--split`, it assigns train/dev/test splits.
+- With `--validate-live`, it executes each final query against Oracle and writes:
+ - `validation`: `passed` or `failed`
+ - `validation_status_code`
+ - `result` preview for passed queries
+ - `validation_error` preview for failed queries
+
+Why it matters:
+
+- The final file is benchmark/training friendly.
+- Records keep provenance, so you can trace whether a pair came from templates, generalization, or LLM generation.
+- Live validation on the combined output is the strongest final check because it runs after deduplication and normalization.
+
+## Step 8: Create the Fraud/Payments Graph
+
+The repo includes a second scalable Oracle SQL/PGQ domain:
+
+```text
+TEXT2GQL_FRAUD_GRAPH
+```
+
+Default size:
+
+```bash
+poetry run python examples/setup_oracle_sqlpgq_fraud_db.py
+```
+
+Scaled size:
+
+```bash
+poetry run python examples/setup_oracle_sqlpgq_fraud_db.py \
+ --customers 100 \
+ --accounts 200 \
+ --merchants 80 \
+ --transactions 5000 \
+ --devices 300 \
+ --cities 25 \
+ --seed 19 \
+ --batch-size 1000
+```
+
+Environment variable form:
+
+```bash
+export FRAUD_GRAPH_CUSTOMERS=100
+export FRAUD_GRAPH_ACCOUNTS=200
+export FRAUD_GRAPH_MERCHANTS=80
+export FRAUD_GRAPH_TRANSACTIONS=5000
+export FRAUD_GRAPH_DEVICES=300
+export FRAUD_GRAPH_CITIES=25
+export FRAUD_GRAPH_DATA_SEED=19
+export FRAUD_GRAPH_INSERT_BATCH_SIZE=1000
+
+poetry run python examples/setup_oracle_sqlpgq_fraud_db.py
+```
+
+Generated artifacts:
+
+```text
+examples/Oracle_SQLPGQ_Instance/TEXT2GQL_FRAUD_GRAPH_oracle_schema.json
+examples/Oracle_SQLPGQ_Instance/TEXT2GQL_FRAUD_GRAPH_oracle_tables.sql
+examples/Oracle_SQLPGQ_Instance/TEXT2GQL_FRAUD_GRAPH_oracle_property_graph.sql
+```
+
+Generate and validate fraud corpus:
+
+```bash
+FRAUD_ORACLE_SQLPGQ_CORPUS_SIZE=50 \
+FRAUD_ORACLE_SQLPGQ_NUM_PER_ITERATION=5 \
+FRAUD_ORACLE_SQLPGQ_REPAIR_ATTEMPTS=2 \
+poetry run python examples/generate_corpus_oracle_sqlpgq_fraud.py
+```
+
+Outputs:
+
+```text
+examples/generated_corpus/oracle_sqlpgq_fraud_corpus_raw.json
+examples/generated_corpus/oracle_sqlpgq_fraud_corpus.json
+examples/generated_corpus/oracle_sqlpgq_fraud_corpus_validated_with_results.json
+examples/generated_corpus/oracle_sqlpgq_fraud_corpus_seed_validated_with_results.json
+examples/generated_corpus/oracle_sqlpgq_fraud_corpus_rejected.json
+```
+
+## Step 9: Run Tests
+
+Offline Oracle unit tests:
+
+```bash
+poetry run pytest \
+ test/test_oracle_sqlpgq.py \
+ test/test_cypher2oracle_sqlpgq.py \
+ test/test_oracle_sqlpgq_template_instantiator.py \
+ test/test_oracle_sqlpgq_query_generalizer.py \
+ test/test_oracle_sqlpgq_corpus_combiner.py
+```
+
+Live Oracle smoke test:
+
+```bash
+poetry run pytest test/test_oracle_sqlpgq_live.py
+```
+
+## Current Capabilities
+
+Implemented:
+
+- Oracle SQL/PGQ schema artifact generation.
+- Movie graph setup and data population.
+- Fraud graph setup and data population.
+- Oracle DB client and live validation.
+- Graph-IL to Oracle SQL/PGQ translation.
+- Cypher to Oracle SQL/PGQ translation for the supported subset.
+- Deterministic Oracle SQL/PGQ template corpus generation.
+- Oracle query generalization from seed Graph-IL/Cypher shapes.
+- LLM corpus generation with live validation and repair.
+- Corpus combination with metadata, deduplication, and splits.
+
+Still limited:
+
+- Oracle SQL/PGQ parsing is subset-based, not a full ANTLR-backed Oracle SQL/PGQ grammar.
+- Deterministic templates currently avoid literal filters. Value-aware template instantiation can be added by sampling real values from Oracle.
+- Generic synthetic row generation for arbitrary new schemas is not complete. New domains still need domain-specific setup scripts, CSV loading, or a future generic data synthesizer.
diff --git a/examples/combine_oracle_sqlpgq_corpus.py b/examples/combine_oracle_sqlpgq_corpus.py
new file mode 100644
index 0000000..6ad0162
--- /dev/null
+++ b/examples/combine_oracle_sqlpgq_corpus.py
@@ -0,0 +1,69 @@
+import argparse
+import os
+from pathlib import Path
+
+from app.core.validator.validator import CorpusValidator
+from app.impl.oracle_sqlpgq.generator.corpus_combiner import OracleSqlPgqCorpusCombiner
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Combine Oracle SQL/PGQ corpus files with metadata and deduplication."
+ )
+ parser.add_argument(
+ "--input",
+ type=Path,
+ nargs="+",
+ default=[
+ Path("examples/generated_corpus/oracle_sqlpgq_template_corpus.json"),
+ Path("examples/generated_corpus/oracle_sqlpgq_generalized_queries.json"),
+ Path("examples/generated_corpus/oracle_sqlpgq_example_corpus.json"),
+ ],
+ help="Input JSON corpus files.",
+ )
+ parser.add_argument(
+ "--output",
+ type=Path,
+ default=Path("examples/generated_corpus/oracle_sqlpgq_combined_corpus.json"),
+ )
+ parser.add_argument("--split", action="store_true", help="Assign train/dev/test splits.")
+ parser.add_argument(
+ "--validate-live",
+ action="store_true",
+ help="Execute each combined query against Oracle and mark validation status.",
+ )
+ parser.add_argument(
+ "--result-preview-chars",
+ type=int,
+ default=500,
+ help="Maximum number of characters to keep from result/error previews.",
+ )
+ args = parser.parse_args()
+
+ validator = None
+ if args.validate_live:
+ validator = CorpusValidator(
+ backend="oracle_sqlpgq",
+ db_client_params={
+ "dsn": os.environ["ORACLE_DSN"],
+ "user": os.environ["ORACLE_USER"],
+ "password": os.environ["ORACLE_PASSWORD"],
+ },
+ )
+
+ records = OracleSqlPgqCorpusCombiner().write_combined(
+ args.input,
+ args.output,
+ split=args.split,
+ validator=validator,
+ result_preview_chars=args.result_preview_chars,
+ )
+ print(f"Saved {len(records)} combined Oracle SQL/PGQ records to {args.output}")
+ if args.validate_live:
+ passed = sum(1 for record in records if record.get("validation") == "passed")
+ failed = sum(1 for record in records if record.get("validation") == "failed")
+ print(f"Live validation: {passed} passed, {failed} failed")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/cypher2oracle_sqlpgq.py b/examples/cypher2oracle_sqlpgq.py
new file mode 100644
index 0000000..45aeed5
--- /dev/null
+++ b/examples/cypher2oracle_sqlpgq.py
@@ -0,0 +1,796 @@
+import argparse
+import contextlib
+import io
+import json
+from pathlib import Path
+import re
+
+from app.impl.oracle_sqlpgq.translator.oracle_sqlpgq_query_translator import (
+ OracleSqlPgqQueryTranslator,
+)
+from app.impl.tugraph_cypher.ast_visitor.tugraph_cypher_ast_visitor import (
+ TugraphCypherAstVisitor,
+)
+from app.impl.tugraph_cypher.translator.tugraph_cypher_query_translator import (
+ TugraphCypherQueryTranslator as CypherTranslator,
+)
+
+SUPPORTED_SCOPE = (
+ "Graph-IL subset: basic MATCH paths, simple property comparisons, RETURN items, "
+ "aliases, aggregates, ORDER BY, SKIP, LIMIT, DISTINCT, and relationship hop ranges."
+)
+
+
+def cypher2oracle_sqlpgq(
+ query: str,
+ graph_name: str = "GRAPH",
+ node_label_map: dict[str, list[str]] | None = None,
+ edge_label_map: dict[str, list[str]] | None = None,
+ property_type_map: dict[str, dict[str, str]] | None = None,
+ node_primary_key_map: dict[str, str] | None = None,
+ edge_primary_key_map: dict[str, str] | None = None,
+ strict_property_validation: bool = False,
+) -> tuple[str, str]:
+ """Translate a supported Cypher query into Oracle SQL/PGQ GRAPH_TABLE syntax."""
+
+ if _starts_with_optional_match(query):
+ optional_query = _translate_standalone_optional_match(
+ query,
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+ if optional_query is not None:
+ return optional_query
+ return "Unable to Translate to Oracle SQL/PGQ", "Graph-IL Not Support"
+
+ optional_null_query = _translate_optional_null_antijoin(
+ query,
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+ if optional_null_query is not None:
+ return optional_null_query
+
+ match_optional_with_query = _translate_match_optional_with(
+ query,
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+ if match_optional_with_query is not None:
+ return match_optional_with_query
+
+ union_query = _translate_union_query(
+ query,
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+ if union_query is not None:
+ return union_query
+
+ query_visitor = TugraphCypherAstVisitor()
+ cypher_translator = CypherTranslator()
+ oracle_translator = OracleSqlPgqQueryTranslator(
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+
+ if not cypher_translator.grammar_check(query):
+ return "Unable to Translate to Oracle SQL/PGQ", "Not Comply with OpenCypher"
+
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
+ success, query_pattern = query_visitor.get_query_pattern(query)
+ if not success:
+ return "Unable to Translate to Oracle SQL/PGQ", "Graph-IL Not Support"
+
+ try:
+ sqlpgq_query = oracle_translator.translate(query_pattern)
+ except Exception:
+ return "Unable to Translate to Oracle SQL/PGQ", "Graph-IL Not Support"
+
+ if oracle_translator.grammar_check(sqlpgq_query):
+ return sqlpgq_query, "Graph-IL Translatable"
+
+ return "Unable to Translate to Oracle SQL/PGQ", "No Related Oracle SQL/PGQ Standard"
+
+
+def _translate_optional_null_antijoin(
+ query: str,
+ graph_name: str,
+ node_label_map: dict[str, list[str]] | None,
+ edge_label_map: dict[str, list[str]] | None,
+ property_type_map: dict[str, dict[str, str]] | None,
+ node_primary_key_map: dict[str, str] | None,
+ edge_primary_key_map: dict[str, str] | None,
+ strict_property_validation: bool,
+) -> tuple[str, str] | None:
+ match = re.fullmatch(
+ r"\s*(?P MATCH\b.+?)\s+OPTIONAL\s+MATCH\s+"
+ r"(?P.+?)\s+WHERE\s+"
+ r"(?P[A-Za-z_][A-Za-z0-9_]*)\s+IS\s+NULL\s+"
+ r"(?PRETURN\b.+)\s*",
+ query,
+ flags=re.IGNORECASE | re.DOTALL,
+ )
+ if not match:
+ return None
+ optional_pattern = match.group("optional").strip()
+ transformed = (
+ f"{match.group('base').strip()} "
+ f"WHERE NOT EXISTS({optional_pattern}) "
+ f"{match.group('tail').strip()}"
+ )
+ return cypher2oracle_sqlpgq(
+ transformed,
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+
+
+def _translate_match_optional_with(
+ query: str,
+ graph_name: str,
+ node_label_map: dict[str, list[str]] | None,
+ edge_label_map: dict[str, list[str]] | None,
+ property_type_map: dict[str, dict[str, str]] | None,
+ node_primary_key_map: dict[str, str] | None,
+ edge_primary_key_map: dict[str, str] | None,
+ strict_property_validation: bool,
+) -> tuple[str, str] | None:
+ if not _is_supported_match_optional_with(query):
+ return None
+ optional_match = re.search(r"\bOPTIONAL\s+MATCH\b", query, flags=re.IGNORECASE)
+ if not optional_match:
+ return None
+ base_part = query[: optional_match.start()].strip()
+ optional_part = query[optional_match.start() :].strip()
+ with_match = re.search(r"\bWITH\b", optional_part, flags=re.IGNORECASE)
+ if not with_match:
+ return None
+ optional_pattern = optional_part[: with_match.start()]
+ base_variables = _declared_cypher_variables(base_part)
+ optional_variables = _declared_cypher_variables(optional_pattern)
+ join_variables = [var for var in base_variables if var in optional_variables]
+ if not join_variables:
+ return "Unable to Translate to Oracle SQL/PGQ", "Graph-IL Not Support"
+
+ transformed = f"{base_part} WITH {', '.join(base_variables)} {optional_part}"
+ match_query = re.sub(
+ r"\bOPTIONAL\s+MATCH\b",
+ "MATCH",
+ transformed,
+ count=1,
+ flags=re.IGNORECASE,
+ )
+ translated, category = cypher2oracle_sqlpgq(
+ match_query,
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+ if category != "Graph-IL Translatable":
+ return "Unable to Translate to Oracle SQL/PGQ", category
+ rewritten = _rewrite_optional_with_left_join(translated, join_variables)
+ if rewritten is None:
+ return "Unable to Translate to Oracle SQL/PGQ", "Graph-IL Not Support"
+ if OracleSqlPgqQueryTranslator(graph_name=graph_name).grammar_check(rewritten):
+ return rewritten, "Graph-IL Translatable"
+ return "Unable to Translate to Oracle SQL/PGQ", "No Related Oracle SQL/PGQ Standard"
+
+
+def _is_supported_match_optional_with(query: str) -> bool:
+ normalized = " ".join(str(query or "").split())
+ if _starts_with_optional_match(normalized):
+ return False
+ if len(re.findall(r"\bOPTIONAL\s+MATCH\b", normalized, flags=re.IGNORECASE)) != 1:
+ return False
+ if len(re.findall(r"\bWITH\b", normalized, flags=re.IGNORECASE)) != 1:
+ return False
+ return bool(
+ re.search(
+ r"^\s*MATCH\b.+\bOPTIONAL\s+MATCH\b.+\bWITH\b.+\bRETURN\b",
+ normalized,
+ flags=re.IGNORECASE,
+ )
+ )
+
+
+def _declared_cypher_variables(fragment: str) -> list[str]:
+ variables: list[str] = []
+ for pattern in (
+ r"\(\s*(?P[A-Za-z_][A-Za-z0-9_]*)\s*(?::|[){])",
+ r"\[\s*(?P[A-Za-z_][A-Za-z0-9_]*)\s*(?::|[\]{])",
+ ):
+ for match in re.finditer(pattern, fragment or ""):
+ variable = match.group("var")
+ if variable not in variables:
+ variables.append(variable)
+ return variables
+
+
+def _rewrite_optional_with_left_join(
+ translated: str,
+ carried_variables: list[str],
+) -> str | None:
+ rewritten = re.sub(
+ r"(?P[ \t]*)FROM stage_2\n(?P=indent)JOIN stage_1 ON (?P[^\n]+)",
+ r"\gFROM stage_1\n\gLEFT JOIN stage_2 ON \g",
+ translated,
+ count=1,
+ flags=re.IGNORECASE,
+ )
+ if rewritten == translated:
+ return None
+ for variable in carried_variables:
+ stage_2_alias = f"{variable}_VALUE"
+ stage_1_alias = f"stage_1_{variable}_VALUE"
+ rewritten = re.sub(
+ rf"(?P\bSELECT\s+){re.escape(stage_2_alias)}(?P\s*,)",
+ rf"\gstage_1.{stage_1_alias} AS {stage_2_alias}\g",
+ rewritten,
+ count=1,
+ flags=re.IGNORECASE,
+ )
+ rewritten = _replace_group_by_token(
+ rewritten,
+ stage_2_alias,
+ f"stage_1.{stage_1_alias}",
+ )
+ return rewritten
+
+
+def _replace_group_by_token(sql: str, token: str, replacement: str) -> str:
+ def replace_line(match: re.Match) -> str:
+ body = match.group("body")
+ parts = _split_top_level_commas(body)
+ parts = [replacement if part.strip() == token else part.strip() for part in parts]
+ return f"GROUP BY {', '.join(parts)}"
+
+ return re.sub(
+ r"GROUP BY (?P[^\n]+)",
+ replace_line,
+ sql,
+ flags=re.IGNORECASE,
+ )
+
+
+def _translate_optional_after_with_match(
+ query: str,
+ graph_name: str,
+ node_label_map: dict[str, list[str]] | None,
+ edge_label_map: dict[str, list[str]] | None,
+ property_type_map: dict[str, dict[str, str]] | None,
+ node_primary_key_map: dict[str, str] | None,
+ edge_primary_key_map: dict[str, str] | None,
+ strict_property_validation: bool,
+) -> tuple[str, str] | None:
+ if not _is_supported_optional_after_with_match(query):
+ return None
+
+ match_query = re.sub(
+ r"\bOPTIONAL\s+MATCH\b",
+ "MATCH",
+ query,
+ count=1,
+ flags=re.IGNORECASE,
+ )
+ translated, category = cypher2oracle_sqlpgq(
+ match_query,
+ graph_name=graph_name,
+ node_label_map=node_label_map,
+ edge_label_map=edge_label_map,
+ property_type_map=property_type_map,
+ node_primary_key_map=node_primary_key_map,
+ edge_primary_key_map=edge_primary_key_map,
+ strict_property_validation=strict_property_validation,
+ )
+ if category != "Graph-IL Translatable":
+ return "Unable to Translate to Oracle SQL/PGQ", category
+
+ rewritten = re.sub(
+ r"\nFROM stage_2\nJOIN stage_1 ON (?P[^\n]+)",
+ r"\nFROM stage_1\nLEFT JOIN stage_2 ON \g",
+ translated,
+ count=1,
+ flags=re.IGNORECASE,
+ )
+ if rewritten == translated:
+ return "Unable to Translate to Oracle SQL/PGQ", "Graph-IL Not Support"
+ if OracleSqlPgqQueryTranslator(graph_name=graph_name).grammar_check(rewritten):
+ return rewritten, "Graph-IL Translatable"
+ return "Unable to Translate to Oracle SQL/PGQ", "No Related Oracle SQL/PGQ Standard"
+
+
+def _is_supported_optional_after_with_match(query: str) -> bool:
+    """Tell whether *query* is ``MATCH ... WITH ... OPTIONAL MATCH ... RETURN``.
+
+    The query must not start with ``OPTIONAL MATCH``, must contain exactly one
+    ``OPTIONAL MATCH`` and exactly one ``WITH``, and must not contain
+    ``RETURN count(*)``.
+    """
+    flat = " ".join(str(query or "").split())
+    if _starts_with_optional_match(flat):
+        return False
+
+    def occurrences(pattern: str) -> int:
+        return len(re.findall(pattern, flat, flags=re.IGNORECASE))
+
+    if occurrences(r"\bOPTIONAL\s+MATCH\b") != 1:
+        return False
+    if occurrences(r"\bWITH\b") != 1:
+        return False
+    if re.search(r"\bRETURN\s+count\s*\(\s*\*\s*\)", flat, flags=re.IGNORECASE):
+        return False
+
+    clause_order = re.search(
+        r"\bMATCH\b.+\bWITH\b.+\bOPTIONAL\s+MATCH\b.+\bRETURN\b",
+        flat,
+        flags=re.IGNORECASE,
+    )
+    return clause_order is not None
+
+
+def _translate_standalone_optional_match(
+    query: str,
+    graph_name: str,
+    node_label_map: dict[str, list[str]] | None,
+    edge_label_map: dict[str, list[str]] | None,
+    property_type_map: dict[str, dict[str, str]] | None,
+    node_primary_key_map: dict[str, str] | None,
+    edge_primary_key_map: dict[str, str] | None,
+    strict_property_validation: bool,
+) -> tuple[str, str] | None:
+    """Translate a query made of a single leading ``OPTIONAL MATCH``.
+
+    The leading ``OPTIONAL MATCH`` is translated as a plain ``MATCH`` and the
+    resulting SQL is wrapped by ``_wrap_standalone_optional_sql``.
+
+    Returns:
+        ``None`` when the query shape is not supported, otherwise a
+        ``(sql_or_failure_message, category)`` tuple.
+    """
+    if not _is_supported_standalone_optional_match(query):
+        return None
+
+    failure = "Unable to Translate to Oracle SQL/PGQ"
+    plain_query = re.sub(
+        r"^\s*OPTIONAL\s+MATCH\b", "MATCH", query, count=1, flags=re.IGNORECASE
+    )
+    base_sql, base_category = cypher2oracle_sqlpgq(
+        plain_query,
+        graph_name=graph_name,
+        node_label_map=node_label_map,
+        edge_label_map=edge_label_map,
+        property_type_map=property_type_map,
+        node_primary_key_map=node_primary_key_map,
+        edge_primary_key_map=edge_primary_key_map,
+        strict_property_validation=strict_property_validation,
+    )
+    if base_category != "Graph-IL Translatable":
+        return failure, base_category
+
+    wrapped_sql = _wrap_standalone_optional_sql(base_sql)
+    if wrapped_sql is None:
+        return failure, "Graph-IL Not Support"
+
+    checker = OracleSqlPgqQueryTranslator(graph_name=graph_name)
+    if not checker.grammar_check(wrapped_sql):
+        return failure, "No Related Oracle SQL/PGQ Standard"
+    return wrapped_sql, "Graph-IL Translatable"
+
+
+def _is_supported_standalone_optional_match(query: str) -> bool:
+    """Tell whether *query* is a lone leading ``OPTIONAL MATCH`` query.
+
+    Supported queries start with ``OPTIONAL MATCH``, contain it exactly once,
+    have no ``WITH`` and do not contain ``RETURN count(*)``.
+    """
+    flat = " ".join(str(query or "").split())
+    if not _starts_with_optional_match(flat):
+        return False
+
+    optional_count = len(re.findall(r"\bOPTIONAL\s+MATCH\b", flat, flags=re.IGNORECASE))
+    has_with = re.search(r"\bWITH\b", flat, flags=re.IGNORECASE) is not None
+    returns_count_star = (
+        re.search(r"\bRETURN\s+count\s*\(\s*\*\s*\)", flat, flags=re.IGNORECASE)
+        is not None
+    )
+    return optional_count == 1 and not has_with and not returns_count_star
+
+
+def _starts_with_optional_match(query: str) -> bool:
+    """Return True when *query* begins (after whitespace) with ``OPTIONAL MATCH``."""
+    text = str(query or "")
+    leading = re.match(r"^\s*OPTIONAL\s+MATCH\b", text, flags=re.IGNORECASE)
+    return leading is not None
+
+
+def _contains_optional_match(query: str) -> bool:
+    """Return True when ``OPTIONAL MATCH`` occurs anywhere in *query*."""
+    text = str(query or "")
+    found = re.search(r"\bOPTIONAL\s+MATCH\b", text, flags=re.IGNORECASE)
+    return found is not None
+
+
+def _wrap_standalone_optional_sql(translated: str) -> str | None:
+    """Wrap *translated* so that it yields one fallback row when it is empty.
+
+    The query body (without trailing ORDER BY / OFFSET / FETCH) becomes the
+    ``optional_rows`` CTE.  ``optional_result`` is that CTE ``UNION ALL`` a
+    fallback row selected ``FROM DUAL`` only when ``optional_rows`` is empty.
+    The fallback value is ``0`` for expressions starting with ``COUNT(`` and
+    ``NULL`` otherwise.
+
+    Returns:
+        The wrapped SQL, or ``None`` when the select list cannot be read, an
+        item has no usable alias, or an item is ``COUNT(*)``.
+    """
+    body, tail = _split_trailing_order_and_paging(translated)
+    items = _top_level_select_items(body)
+    if not items:
+        return None
+
+    if items == ["*"]:
+        # SELECT * exposes exactly the aliases declared in COLUMNS (...).
+        aliases = _graph_table_column_aliases(body)
+        if not aliases:
+            return None
+        fallbacks = [f"NULL AS {alias}" for alias in aliases]
+    else:
+        aliases = []
+        fallbacks = []
+        for item in items:
+            alias = _select_item_alias(item)
+            if alias is None:
+                return None
+            expression = _select_item_expression(item)
+            if re.fullmatch(r"COUNT\s*\(\s*\*\s*\)", expression, flags=re.IGNORECASE):
+                return None
+            is_count = re.match(r"COUNT\s*\(", expression, flags=re.IGNORECASE) is not None
+            aliases.append(alias)
+            fallbacks.append(f"{'0' if is_count else 'NULL'} AS {alias}")
+
+    projection = ", ".join(aliases)
+    fallback_projection = ", ".join(fallbacks)
+    sql_lines = [
+        "WITH optional_rows AS (",
+        _indent_sql(body),
+        "),",
+        "optional_result AS (",
+        f" SELECT {projection}",
+        " FROM optional_rows",
+        " UNION ALL",
+        f" SELECT {fallback_projection}",
+        " FROM DUAL",
+        " WHERE NOT EXISTS (SELECT 1 FROM optional_rows)",
+        ")",
+        f"SELECT {projection}",
+        "FROM optional_result",
+    ]
+    return "\n".join(sql_lines) + tail
+
+
+def _split_trailing_order_and_paging(sql: str) -> tuple[str, str]:
+    """Split *sql* at the first line starting with ORDER BY, OFFSET or FETCH.
+
+    Returns:
+        ``(body, tail)``: the right-stripped text before the cut and the text
+        from the cut onwards (beginning with its newline).  ``tail`` is ``""``
+        when none of the clauses is present.
+    """
+    cut_points = []
+    for pattern in (r"\nORDER\s+BY\b", r"\nOFFSET\b", r"\nFETCH\b"):
+        found = re.search(pattern, sql, flags=re.IGNORECASE)
+        if found is not None:
+            cut_points.append(found.start())
+
+    if cut_points:
+        cut = min(cut_points)
+        return sql[:cut].rstrip(), sql[cut:]
+    return sql.rstrip(), ""
+
+
+def _top_level_select_items(sql: str) -> list[str]:
+    """Return the select-list items between the leading SELECT and the first FROM line.
+
+    An optional ``DISTINCT`` after ``SELECT`` is skipped.  An empty list is
+    returned when *sql* does not start with ``SELECT`` or has no line that
+    starts with ``FROM``.
+    """
+    head = re.match(r"\s*SELECT\s+(?:DISTINCT\s+)?", sql, flags=re.IGNORECASE)
+    if head is None:
+        return []
+    remainder = sql[head.end() :]
+    from_clause = re.search(r"\nFROM\b", remainder, flags=re.IGNORECASE)
+    if from_clause is None:
+        return []
+    return _split_top_level_commas(remainder[: from_clause.start()])
+
+
+def _select_item_alias(item: str) -> str | None:
+    """Return the output name of a select item, or ``None`` if it has none.
+
+    A trailing ``AS <identifier>`` wins; otherwise an item that is itself a
+    bare identifier is its own alias.
+    """
+    explicit = re.search(
+        r"\bAS\s+([A-Za-z_][A-Za-z0-9_$#]*)\s*$", item, flags=re.IGNORECASE
+    )
+    if explicit is not None:
+        return explicit.group(1)
+
+    bare = item.strip()
+    return bare if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$#]*", bare) else None
+
+
+def _select_item_expression(item: str) -> str:
+    """Return the expression part of a select item, without its trailing alias.
+
+    Only an ``AS <alias>`` clause at the very end of the item is removed, where
+    the alias is a plain or double-quoted identifier.  An ``AS`` inside the
+    expression, such as ``CAST(x AS NUMBER)``, is left untouched.  Items
+    without a trailing alias are returned stripped.
+    """
+    trailing_alias = re.search(
+        r'\bAS\s+(?:[A-Za-z_][A-Za-z0-9_$#]*|"[^"]+")\s*$',
+        item,
+        flags=re.IGNORECASE,
+    )
+    if trailing_alias is None:
+        return item.strip()
+    return item[: trailing_alias.start()].strip()
+
+
+def _translate_union_query(
+    query: str,
+    graph_name: str,
+    node_label_map: dict[str, list[str]] | None,
+    edge_label_map: dict[str, list[str]] | None,
+    property_type_map: dict[str, dict[str, str]] | None,
+    node_primary_key_map: dict[str, str] | None,
+    edge_primary_key_map: dict[str, str] | None,
+    strict_property_validation: bool,
+) -> tuple[str, str] | None:
+    """Translate a Cypher ``UNION`` / ``UNION ALL`` query branch by branch.
+
+    Every branch is translated on its own, then wrapped so that all branches
+    project the same alias list (missing aliases become ``NULL``), and the
+    wrapped branches are joined with the original operators.
+
+    Returns:
+        ``None`` when the query has no top-level union, otherwise a
+        ``(sql_or_failure_message, category)`` tuple.
+    """
+    branches, operators = _split_top_level_unions(query)
+    if len(branches) == 1:
+        return None
+
+    failure = "Unable to Translate to Oracle SQL/PGQ"
+    per_branch: list[tuple[str, list[str]]] = []
+    for branch in branches:
+        branch_sql, category = cypher2oracle_sqlpgq(
+            branch,
+            graph_name=graph_name,
+            node_label_map=node_label_map,
+            edge_label_map=edge_label_map,
+            property_type_map=property_type_map,
+            node_primary_key_map=node_primary_key_map,
+            edge_primary_key_map=edge_primary_key_map,
+            strict_property_validation=strict_property_validation,
+        )
+        if category != "Graph-IL Translatable":
+            return failure, category
+        columns = _graph_table_column_aliases(branch_sql)
+        if not columns:
+            return failure, "Graph-IL Not Support"
+        per_branch.append((branch_sql, columns))
+
+    # Ordered union of all aliases: first occurrence decides the position.
+    merged_aliases = list(
+        dict.fromkeys(alias for _, columns in per_branch for alias in columns)
+    )
+
+    wrapped = [
+        _wrap_union_branch(branch_sql, columns, merged_aliases, position)
+        for position, (branch_sql, columns) in enumerate(per_branch, start=1)
+    ]
+    pieces = [wrapped[0]]
+    for operator, wrapped_sql in zip(operators, wrapped[1:], strict=True):
+        pieces.append(f"\n{operator}\n{wrapped_sql}")
+    union_sql = "".join(pieces)
+
+    checker = OracleSqlPgqQueryTranslator(graph_name=graph_name)
+    if not checker.grammar_check(union_sql):
+        return failure, "No Related Oracle SQL/PGQ Standard"
+    return union_sql, "Graph-IL Translatable"
+
+
+def _wrap_union_branch(
+    translated: str,
+    branch_aliases: list[str],
+    output_aliases: list[str],
+    index: int,
+) -> str:
+    """Wrap one union branch in a SELECT that projects *output_aliases* in order.
+
+    Aliases the branch does not produce are emitted as ``NULL AS <alias>``.
+    The branch becomes an inline view named ``union_branch_<index>``.
+    """
+    projected = []
+    for alias in output_aliases:
+        if alias in branch_aliases:
+            projected.append(alias)
+        else:
+            projected.append(f"NULL AS {alias}")
+
+    inner = _indent_sql(translated)
+    return f"SELECT {', '.join(projected)}\nFROM (\n{inner}\n) union_branch_{index}"
+
+
+def _indent_sql(sql: str) -> str:
+    """Prefix every line of *sql* with the indent string and rejoin with newlines."""
+    prefixed = [" " + row for row in sql.splitlines()]
+    return "\n".join(prefixed)
+
+
+def _graph_table_column_aliases(sql: str) -> list[str]:
+    """Return the aliases declared in the first ``COLUMNS (...)`` clause of *sql*.
+
+    The clause body is delimited by scanning for the matching closing
+    parenthesis while ignoring parentheses inside single- or double-quoted
+    text.  An empty list is returned when there is no ``COLUMNS (`` or the
+    clause is never closed.
+    """
+    opener = re.search(r"\bCOLUMNS\s*\(", sql, flags=re.IGNORECASE)
+    if opener is None:
+        return []
+
+    body_start = opener.end()
+    cursor = body_start
+    nesting = 1
+    quote = ""  # "" outside quotes, otherwise the quote character that is open
+    size = len(sql)
+    while cursor < size:
+        ch = sql[cursor]
+        if ch == "'" and quote != '"':
+            if sql[cursor + 1 : cursor + 2] == "'":
+                # A doubled single quote is skipped as a pair.
+                cursor += 2
+                continue
+            quote = "" if quote == "'" else "'"
+        elif ch == '"' and quote != "'":
+            quote = "" if quote == '"' else '"'
+        elif not quote:
+            if ch == "(":
+                nesting += 1
+            elif ch == ")":
+                nesting -= 1
+                if nesting == 0:
+                    return _projection_aliases(sql[body_start:cursor])
+        cursor += 1
+    return []
+
+
+def _projection_aliases(columns_body: str) -> list[str]:
+    """Return the ``AS <alias>`` name of every projection in *columns_body*.
+
+    The result is all-or-nothing: if any projection lacks a trailing
+    ``AS <identifier>``, an empty list is returned.
+    """
+    collected = []
+    for raw in _split_top_level_commas(columns_body):
+        found = re.search(
+            r"\bAS\s+([A-Za-z_][A-Za-z0-9_$#]*)\s*$",
+            raw.strip(),
+            flags=re.IGNORECASE,
+        )
+        if found is None:
+            return []
+        collected.append(found.group(1))
+    return collected
+
+
+def _split_top_level_commas(text: str) -> list[str]:
+    """Split *text* on commas that are outside parentheses and quotes.
+
+    Each part is stripped.  Empty parts between two commas are kept, but an
+    empty final part is dropped.
+    """
+    pieces: list[str] = []
+    piece_start = 0
+    nesting = 0
+    quote = ""  # "" outside quotes, otherwise the quote character that is open
+    cursor = 0
+    size = len(text)
+    while cursor < size:
+        ch = text[cursor]
+        if ch == "'" and quote != '"':
+            if text[cursor + 1 : cursor + 2] == "'":
+                # A doubled single quote is skipped as a pair.
+                cursor += 2
+                continue
+            quote = "" if quote == "'" else "'"
+        elif ch == '"' and quote != "'":
+            quote = "" if quote == '"' else '"'
+        elif not quote:
+            if ch == "(":
+                nesting += 1
+            elif ch == ")":
+                nesting -= 1
+            elif ch == "," and nesting == 0:
+                pieces.append(text[piece_start:cursor].strip())
+                piece_start = cursor + 1
+        cursor += 1
+
+    last = text[piece_start:].strip()
+    if last:
+        pieces.append(last)
+    return pieces
+
+
+def _split_top_level_unions(query: str) -> tuple[list[str], list[str]]:
+    """Split *query* on ``UNION`` / ``UNION ALL`` keywords at nesting depth zero.
+
+    Keywords inside single quotes, double quotes, backticks or any of
+    ``()``, ``[]``, ``{}`` are ignored.
+
+    Returns:
+        ``(branches, operators)``.  When no union is found, or any branch would
+        be empty, the result is ``([query], [])``.
+    """
+    branches: list[str] = []
+    operators: list[str] = []
+    segment_start = 0
+    cursor = 0
+    nesting = 0
+    quote = ""  # "" outside quoting, otherwise the open quote character
+    size = len(query)
+    while cursor < size:
+        ch = query[cursor]
+        if ch in "'\"`" and quote in ("", ch):
+            if ch == "'" and query[cursor + 1 : cursor + 2] == "'":
+                # A doubled single quote is skipped as a pair.
+                cursor += 2
+                continue
+            quote = "" if quote else ch
+            cursor += 1
+            continue
+        if quote:
+            cursor += 1
+            continue
+        if ch in "([{":
+            nesting += 1
+        elif ch in ")]}":
+            nesting -= 1
+        elif nesting == 0 and _keyword_at(query, cursor, "UNION"):
+            operator, resume = _union_operator_at(query, cursor)
+            branches.append(query[segment_start:cursor].strip())
+            operators.append(operator)
+            segment_start = cursor = resume
+            continue
+        cursor += 1
+
+    if not operators:
+        return [query], []
+    branches.append(query[segment_start:].strip())
+    if not all(branches):
+        return [query], []
+    return branches, operators
+
+
+def _union_operator_at(query: str, index: int) -> tuple[str, int]:
+    """Read the union operator whose ``UNION`` keyword starts at *index*.
+
+    Returns:
+        ``("UNION ALL", end)`` when ``ALL`` follows the whitespace after
+        ``UNION``, else ``("UNION", end)``.  ``end`` is the position right
+        after the consumed text (including the skipped whitespace).
+    """
+    cursor = index + len("UNION")
+    size = len(query)
+    while cursor < size and query[cursor].isspace():
+        cursor += 1
+    if not _keyword_at(query, cursor, "ALL"):
+        return "UNION", cursor
+    return "UNION ALL", cursor + len("ALL")
+
+
+def _keyword_at(text: str, index: int, keyword: str) -> bool:
+    """Return True when *keyword* stands at *index* of *text* as a whole word.
+
+    The comparison upper-cases the text slice, so *keyword* is expected in
+    upper case.  The characters directly before and after must not be
+    identifier characters.
+    """
+    stop = index + len(keyword)
+    if text[index:stop].upper() != keyword:
+        return False
+    preceding = text[index - 1] if index > 0 else ""
+    following = text[stop : stop + 1]
+    return not _is_identifier_char(preceding) and not _is_identifier_char(following)
+
+
+def _is_identifier_char(char: str) -> bool:
+    """Return True for an alphanumeric character or one of ``_``, ``$``, ``#``."""
+    if not char:
+        # The empty string is "in" every string, so it must be rejected first.
+        return False
+    return char.isalnum() or char in "_$#"
+
+
+def _translate_queries(queries: list[str], graph_name: str) -> list[dict[str, str]]:
+    """Translate every Cypher query and return one record per query.
+
+    Each record holds the original ``cypher`` text, the ``oracle_sqlpgq``
+    translation and the translation ``category``.
+    """
+    return [
+        {
+            "cypher": cypher,
+            "oracle_sqlpgq": translated_query,
+            "category": category,
+        }
+        for cypher in queries
+        for translated_query, category in [
+            cypher2oracle_sqlpgq(cypher, graph_name=graph_name)
+        ]
+    ]
+
+
+def main() -> None:
+    """Command-line entry point for the Cypher to Oracle SQL/PGQ translator.
+
+    With ``--query`` a single query is translated and printed as JSON.
+    Otherwise ``--input`` must name a JSON file holding a list of Cypher
+    queries; the translation records are written to ``--output``.
+    """
+    parser = argparse.ArgumentParser(
+        description=(
+            "Translate supported Cypher queries into Oracle SQL/PGQ GRAPH_TABLE syntax. "
+            f"Supported scope: {SUPPORTED_SCOPE}"
+        )
+    )
+    parser.add_argument("--query", help="Single Cypher query to translate.")
+    parser.add_argument("--input", type=Path, help="JSON file containing a list of Cypher queries.")
+    parser.add_argument(
+        "--output",
+        type=Path,
+        default=Path("test_oracle_sqlpgq_query.json"),
+        help="Output JSON file when --input is used.",
+    )
+    parser.add_argument("--graph-name", default="GRAPH", help="Oracle property graph name.")
+    args = parser.parse_args()
+
+    if args.query:
+        translated_query, category = cypher2oracle_sqlpgq(args.query, graph_name=args.graph_name)
+        print(json.dumps({"query": translated_query, "category": category}, indent=2))
+        return
+
+    if not args.input:
+        parser.error("Provide either --query or --input.")
+
+    with open(args.input, encoding="utf-8") as file:
+        query_list = json.load(file)
+
+    output_query_list = _translate_queries(query_list, graph_name=args.graph_name)
+    # Create the destination directory so a missing folder does not discard
+    # the finished translations at write time.
+    args.output.parent.mkdir(parents=True, exist_ok=True)
+    with open(args.output, "w", encoding="utf-8") as file:
+        json.dump(output_query_list, file, ensure_ascii=False, indent=4)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/generalize_oracle_sqlpgq_from_cypher.py b/examples/generalize_oracle_sqlpgq_from_cypher.py
new file mode 100644
index 0000000..19cfa87
--- /dev/null
+++ b/examples/generalize_oracle_sqlpgq_from_cypher.py
@@ -0,0 +1,60 @@
+import argparse
+import json
+from pathlib import Path
+
+from app.impl.oracle_sqlpgq.generator.query_generalizer import (
+ OracleSqlPgqQueryGeneralizer,
+)
+from app.impl.tugraph_cypher.ast_visitor.tugraph_cypher_ast_visitor import (
+ TugraphCypherAstVisitor,
+)
+from app.impl.tugraph_cypher.translator.tugraph_cypher_query_translator import (
+ TugraphCypherQueryTranslator,
+)
+
+
+def main() -> None:
+    """Generalize a seed Cypher query over a manifest and save the result as JSON.
+
+    Raises:
+        ValueError: when the seed query fails the Cypher grammar check or the
+            AST visitor cannot extract a query pattern from it.
+    """
+    cli = argparse.ArgumentParser(
+        description=(
+            "Generalize a Cypher path shape over an Oracle SQL/PGQ manifest and "
+            "emit Oracle GRAPH_TABLE queries."
+        )
+    )
+    cli.add_argument(
+        "--query",
+        default="MATCH (a)-[e]->(b) RETURN b",
+        help="Seed Cypher query whose path length/hop ranges define the pattern.",
+    )
+    cli.add_argument(
+        "--manifest",
+        type=Path,
+        default=Path("examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_schema.json"),
+    )
+    cli.add_argument(
+        "--output",
+        type=Path,
+        default=Path("examples/generated_corpus/oracle_sqlpgq_generalized_queries.json"),
+    )
+    cli.add_argument("--target-size", type=int, default=25)
+    options = cli.parse_args()
+    seed_query = options.query
+
+    if not TugraphCypherQueryTranslator().grammar_check(seed_query):
+        raise ValueError("Seed query does not pass the current Cypher grammar check.")
+
+    recognized, pattern = TugraphCypherAstVisitor().get_query_pattern(seed_query)
+    if not recognized:
+        raise ValueError("Seed query is outside the current Graph-IL visitor subset.")
+
+    queries = OracleSqlPgqQueryGeneralizer.from_file(options.manifest).generalize_dicts(
+        pattern, target_size=options.target_size
+    )
+
+    destination = options.output
+    destination.parent.mkdir(parents=True, exist_ok=True)
+    with open(destination, "w", encoding="utf-8") as handle:
+        json.dump(queries, handle, indent=2, ensure_ascii=False)
+    print(f"Saved {len(queries)} generalized Oracle SQL/PGQ queries to {destination}")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/examples/generate_corpus_oracle_sqlpgq.py b/examples/generate_corpus_oracle_sqlpgq.py
new file mode 100644
index 0000000..7ba253f
--- /dev/null
+++ b/examples/generate_corpus_oracle_sqlpgq.py
@@ -0,0 +1,285 @@
+"""
+Generate an Oracle SQL/PGQ corpus with live Oracle validation.
+
+Required environment variables:
+- ORACLE_DSN
+- ORACLE_USER
+- ORACLE_PASSWORD
+"""
+
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Any
+
+from app.core.generator.corpus_generator import CorpusGenerator
+from app.core.llm.llm_client import LlmClient
+from app.core.validator.db_client import QueryResult, QueryStatus
+from app.core.validator.validator import CorpusValidator
+
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger("OracleSqlPgqCorpusGenerator")
+
+
+def summarize_result(result: QueryResult) -> str:
+    """Return ``str(result.data)``, cut to 500 characters plus ``...`` when longer."""
+    text = str(result.data)
+    return text if len(text) <= 500 else text[:500] + "..."
+
+
+def save_json(data: list[dict[str, Any]], file_path: Path) -> None:
+    """Write *data* to *file_path* as indented UTF-8 JSON and log the record count."""
+    with open(file_path, "w", encoding="utf-8") as handle:
+        json.dump(data, handle, indent=2, ensure_ascii=False)
+    record_count = len(data)
+    logger.info("Saved %s records to %s", record_count, file_path)
+
+
+def validate_and_repair_pairs(
+    *,
+    generator: CorpusGenerator,
+    validator: CorpusValidator,
+    schema_json: str,
+    raw_pairs: list[dict[str, str]],
+    max_repair_attempts: int,
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    """Execute each question/query pair and try to repair the ones that fail.
+
+    A pair whose query executes successfully is kept together with a summary
+    of its result.  A failing pair is sent back to the generator with the
+    error as context, up to *max_repair_attempts* times; if it still fails, or
+    the generator returns nothing, the last failure is recorded as rejected.
+
+    Returns:
+        ``(valid_pairs, rejected_pairs)``.  Valid entries have ``question``,
+        ``query`` and ``result``; rejected entries have ``question``,
+        ``query``, ``status_code``, ``error`` and ``attempt``.
+    """
+    client = validator._get_client()
+    valid_pairs: list[dict[str, Any]] = []
+    rejected_pairs: list[dict[str, Any]] = []
+    # A negative limit would skip the loop and silently drop every pair;
+    # treat it as "no repairs" so each pair is still validated once.
+    repair_limit = max(max_repair_attempts, 0)
+
+    for pair in raw_pairs:
+        candidate = pair
+        for attempt in range(repair_limit + 1):
+            # Read both fields defensively: a repaired candidate comes from
+            # the generator and may lack a key.
+            question = candidate.get("question", "")
+            query = candidate.get("query", "")
+            result = client.execute_query(query)
+            if result.status_code == QueryStatus.SUCCESS:
+                valid_pairs.append(
+                    {
+                        "question": question,
+                        "query": query,
+                        "result": summarize_result(result),
+                    }
+                )
+                break
+
+            failure = {
+                "question": question,
+                "query": query,
+                "status_code": result.status_code,
+                "error": result.error,
+                "attempt": attempt,
+            }
+            if attempt >= repair_limit:
+                rejected_pairs.append(failure)
+                break
+
+            error_context = (
+                "Previous query failed Oracle validation. Correct it.\n"
+                f"Previous query:\n{query}\n"
+                f"Oracle status code: {result.status_code}\n"
+                f"Oracle error:\n{result.error}\n"
+            )
+            repaired = generator.generate_translation_batch(
+                schema_json=schema_json,
+                questions=[question],
+                error_context=error_context,
+            )
+            if not repaired:
+                rejected_pairs.append(failure)
+                break
+            candidate = repaired[0]
+
+    return valid_pairs, rejected_pairs
+
+
+def main():
+ """Validate the seed queries, then generate, repair and save an Oracle SQL/PGQ corpus."""
+ graph_name = "TEXT2GQL_GRAPH"
+ schema_file = Path("examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_schema.json")
+ output_file = Path("examples/generated_corpus/oracle_sqlpgq_example_corpus.json")
+ # The sibling output files reuse the corpus file's stem with a suffix appended.
+ raw_output_file = output_file.with_name(f"{output_file.stem}_raw{output_file.suffix}")
+ rejected_output_file = output_file.with_name(
+ f"{output_file.stem}_rejected{output_file.suffix}"
+ )
+ validated_output_file = output_file.with_name(
+ f"{output_file.stem}_validated_with_results{output_file.suffix}"
+ )
+ seed_validation_output_file = output_file.with_name(
+ f"{output_file.stem}_seed_validated_with_results{output_file.suffix}"
+ )
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+
+ with open(schema_file, encoding="utf-8") as file:
+ schema_json = json.dumps(json.load(file), ensure_ascii=False)
+
+ # Connection settings are read with os.environ[...], so a missing variable raises KeyError.
+ validator = CorpusValidator(
+ backend="oracle_sqlpgq",
+ db_client_params={
+ "dsn": os.environ["ORACLE_DSN"],
+ "user": os.environ["ORACLE_USER"],
+ "password": os.environ["ORACLE_PASSWORD"],
+ },
+ )
+ generator = CorpusGenerator(
+ llm_client=LlmClient(
+ model=os.getenv("LLM_MODEL", "qwen3-coder-plus-2025-07-22"),
+ platform=os.getenv("LLM_PLATFORM", ""),
+ ),
+ query_language="oracle_sqlpgq",
+ graph_name=graph_name,
+ )
+
+ # Hand-written question/query pairs that seed the generation loop.
+ seed_queries = [
+ {
+ "question": "Show a small sample of connected graph elements.",
+ "query": (
+ 'SELECT * FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (a)-[e]->(b) COLUMNS (VERTEX_ID(a) AS A_ID, VERTEX_ID(b) AS B_ID)) gt '
+ "FETCH FIRST 5 ROWS ONLY"
+ ),
+ },
+ {
+ "question": "Which movies belong to the Science Fiction genre?",
+ "query": (
+ 'SELECT * FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (m IS "MOVIE")-[b IS "BELONGS_TO"]->(g IS "GENRE") '
+ 'WHERE g."name" = \'Science Fiction\' '
+ 'COLUMNS (m."title" AS movie_title, g."name" AS genre_name)) gt'
+ ),
+ },
+ {
+ "question": "Which users gave ratings above 4.5, and what movies were rated?",
+ "query": (
+ 'SELECT * FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (u IS "USER")-[ur IS "RATES"]->(r IS "RATING")-[rm IS "RATES"]->(m IS "MOVIE") '
+ 'WHERE r."score" > 4.5 '
+ 'COLUMNS (u."USER_id" AS user_id, m."title" AS movie_title, r."score" AS score)) gt'
+ ),
+ },
+ {
+ "question": "Which movie was tagged as classic?",
+ "query": (
+ 'SELECT * FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (u IS "USER")-[ut IS "TAGS"]->(t IS "TAG")-[tm IS "TAGS"]->(m IS "MOVIE") '
+ 'WHERE t."text_content" = \'classic\' '
+ 'COLUMNS (u."USER_id" AS user_id, m."title" AS movie_title, t."text_content" AS tag_text)) gt'
+ ),
+ },
+ {
+ "question": "Which users are friends with user 1?",
+ "query": (
+ 'SELECT * FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (u IS "USER")-[f IS "FRIENDS_WITH"]->(friend IS "USER") '
+ 'WHERE u."USER_id" = 1 '
+ 'COLUMNS (friend."USER_id" AS friend_id, f."connection_strength" AS connection_strength)) gt'
+ ),
+ },
+ {
+ "question": "Which movies are similar to Inception?",
+ "query": (
+ 'SELECT * FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (m1 IS "MOVIE")-[s IS "SIMILAR_TO"]->(m2 IS "MOVIE") '
+ 'WHERE m1."title" = \'Inception\' '
+ 'COLUMNS (m2."title" AS similar_movie, s."similarity_score" AS similarity_score)) gt'
+ ),
+ },
+ {
+ "question": "Which Science Fiction movies found through the graph can be joined back to the movie table with their release year?",
+ "query": (
+ "WITH graph_movies AS ("
+ 'SELECT gt.movie_id, gt.movie_title FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (m IS "MOVIE")-[b IS "BELONGS_TO"]->(g IS "GENRE") '
+ 'WHERE g."name" = \'Science Fiction\' '
+ 'COLUMNS (m."MOVIE_id" AS movie_id, m."title" AS movie_title)) gt'
+ ") "
+ 'SELECT gm.movie_title, m."release_year" AS release_year '
+ 'FROM graph_movies gm JOIN "MOVIE" m ON m."MOVIE_id" = gm.movie_id '
+ 'ORDER BY m."release_year" DESC'
+ ),
+ },
+ {
+ "question": "Show movies that come either from the base movie table after 2010 or from the graph as Action movies.",
+ "query": (
+ 'SELECT \'TABLE\' AS source_type, m."title" AS movie_title FROM "MOVIE" m '
+ 'WHERE m."release_year" >= 2010 '
+ "UNION ALL "
+ 'SELECT \'GRAPH\' AS source_type, gt.movie_title FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (m IS "MOVIE")-[b IS "BELONGS_TO"]->(g IS "GENRE") '
+ 'WHERE g."name" = \'Action\' '
+ 'COLUMNS (m."title" AS movie_title)) gt'
+ ),
+ },
+ {
+ "question": "Rank the highest graph ratings and show the user, movie, score, and rank.",
+ "query": (
+ 'SELECT gt.user_id, gt.movie_title, gt.score, '
+ "ROW_NUMBER() OVER (ORDER BY gt.score DESC) AS score_rank "
+ 'FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (u IS "USER")-[ur IS "RATES"]->(r IS "RATING")-[rm IS "RATES"]->(m IS "MOVIE") '
+ 'COLUMNS (u."USER_id" AS user_id, m."title" AS movie_title, r."score" AS score)) gt '
+ "ORDER BY score_rank FETCH FIRST 5 ROWS ONLY"
+ ),
+ },
+ {
+ "question": "Which movies have at least one rating, and what is their average score?",
+ "query": (
+ 'SELECT gt.movie_title, AVG(gt.score) AS average_score, COUNT(*) AS rating_count '
+ 'FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (u IS "USER")-[ur IS "RATES"]->(r IS "RATING")-[rm IS "RATES"]->(m IS "MOVIE") '
+ 'COLUMNS (m."title" AS movie_title, r."score" AS score)) gt '
+ "GROUP BY gt.movie_title HAVING COUNT(*) >= 1 ORDER BY average_score DESC"
+ ),
+ },
+ {
+ "question": "Show one row per friendship step for paths up to three hops between users.",
+ "query": (
+ 'SELECT * FROM GRAPH_TABLE ("TEXT2GQL_GRAPH" '
+ 'MATCH (start_user IS "USER")-[friend_path IS "FRIENDS_WITH"]->{1,3}(end_user IS "USER") '
+ "ONE ROW PER STEP (step_source, step_edge, step_target) "
+ 'COLUMNS (MATCHNUM() AS match_number, ELEMENT_NUMBER(step_edge) AS step_number, '
+ 'step_source."USER_id" AS source_user_id, EDGE_ID(step_edge) AS edge_id, '
+ 'step_target."USER_id" AS target_user_id)) gt FETCH FIRST 10 ROWS ONLY'
+ ),
+ },
+ ]
+ seed_context = validator.execute_with_results(seed_queries)
+ if not seed_context:
+ raise RuntimeError("No seed queries validated successfully; cannot generate corpus.")
+ save_json(seed_context, seed_validation_output_file)
+
+ # A corpus size of zero or less stops here, after the seed validation file is written.
+ target_size = int(os.getenv("ORACLE_SQLPGQ_CORPUS_SIZE", "10"))
+ if target_size <= 0:
+ logger.info("Seed validation only: %s seed queries validated successfully.", len(seed_context))
+ return
+
+ num_per_iteration = int(os.getenv("ORACLE_SQLPGQ_NUM_PER_ITERATION", "3"))
+ max_repair_attempts = int(os.getenv("ORACLE_SQLPGQ_REPAIR_ATTEMPTS", "1"))
+ raw_pairs = generator.run_generation_loop(
+ schema_json=schema_json,
+ seeds_corpus_with_context=seed_context,
+ num_per_iteration=num_per_iteration,
+ complexity_corpus_size=target_size,
+ )
+ save_json(raw_pairs, raw_output_file)
+ valid_pairs, rejected_pairs = validate_and_repair_pairs(
+ generator=generator,
+ validator=validator,
+ schema_json=schema_json,
+ raw_pairs=raw_pairs,
+ max_repair_attempts=max_repair_attempts,
+ )
+ save_json(rejected_pairs, rejected_output_file)
+ save_json(valid_pairs, validated_output_file)
+ # The main corpus file keeps only the question and query fields of the valid pairs.
+ save_json(
+ [{"question": item["question"], "query": item["query"]} for item in valid_pairs],
+ output_file,
+ )
+ logger.info(
+ "Validated Oracle SQL/PGQ corpus: %s valid, %s rejected",
+ len(valid_pairs),
+ len(rejected_pairs),
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/generate_corpus_oracle_sqlpgq_fraud.py b/examples/generate_corpus_oracle_sqlpgq_fraud.py
new file mode 100644
index 0000000..1142a15
--- /dev/null
+++ b/examples/generate_corpus_oracle_sqlpgq_fraud.py
@@ -0,0 +1,227 @@
+"""
+Generate an Oracle SQL/PGQ corpus for the fraud/payments graph with live validation.
+
+Required environment variables:
+- ORACLE_DSN
+- ORACLE_USER
+- ORACLE_PASSWORD
+"""
+
+import json
+import logging
+import os
+from pathlib import Path
+
+from app.core.generator.corpus_generator import CorpusGenerator
+from app.core.llm.llm_client import LlmClient
+from app.core.validator.validator import CorpusValidator
+from examples.generate_corpus_oracle_sqlpgq import save_json, validate_and_repair_pairs
+from examples.setup_oracle_sqlpgq_fraud_db import GRAPH_NAME, build_manifest
+
+
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger("OracleSqlPgqFraudCorpusGenerator")
+
+
+def env_int(name: str, default: int) -> int:
+    """Read an integer setting from the environment.
+
+    *name* is looked up first, then *name* with every ``FRAUD_`` removed;
+    *default* is used when neither variable is set.
+    """
+    raw = os.getenv(name)
+    if raw is None:
+        raw = os.getenv(name.replace("FRAUD_", ""), str(default))
+    return int(raw)
+
+
+def main() -> None:
+ """Validate the fraud seed queries, then generate, repair and save the fraud corpus."""
+ output_file = Path("examples/generated_corpus/oracle_sqlpgq_fraud_corpus.json")
+ # The sibling output files reuse the corpus file's stem with a suffix appended.
+ raw_output_file = output_file.with_name(f"{output_file.stem}_raw{output_file.suffix}")
+ rejected_output_file = output_file.with_name(
+ f"{output_file.stem}_rejected{output_file.suffix}"
+ )
+ validated_output_file = output_file.with_name(
+ f"{output_file.stem}_validated_with_results{output_file.suffix}"
+ )
+ seed_validation_output_file = output_file.with_name(
+ f"{output_file.stem}_seed_validated_with_results{output_file.suffix}"
+ )
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+
+ # The schema comes from the "schema" entry of the manifest built by build_manifest().
+ schema_json = json.dumps(build_manifest()["schema"], ensure_ascii=False)
+ # Connection settings are read with os.environ[...], so a missing variable raises KeyError.
+ validator = CorpusValidator(
+ backend="oracle_sqlpgq",
+ db_client_params={
+ "dsn": os.environ["ORACLE_DSN"],
+ "user": os.environ["ORACLE_USER"],
+ "password": os.environ["ORACLE_PASSWORD"],
+ },
+ )
+ generator = CorpusGenerator(
+ llm_client=LlmClient(
+ model=os.getenv("LLM_MODEL", "qwen3-coder-plus-2025-07-22"),
+ platform=os.getenv("LLM_PLATFORM", ""),
+ ),
+ query_language="oracle_sqlpgq",
+ graph_name=GRAPH_NAME,
+ )
+
+ # Hand-written question/query pairs that seed the generation loop.
+ seed_queries = [
+ {
+ "question": "Show a small sample of connected fraud graph elements.",
+ "query": (
+ f'SELECT * FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (a)-[e]->(b) COLUMNS (VERTEX_ID(a) AS source_id, EDGE_ID(e) AS edge_id, VERTEX_ID(b) AS destination_id)) gt '
+ "FETCH FIRST 5 ROWS ONLY"
+ ),
+ },
+ {
+ "question": "Which high-risk customers initiated fraudulent transactions above 1000?",
+ "query": (
+ f'SELECT * FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (c IS "CUSTOMER")-[o IS "OWNS"]->(a IS "ACCOUNT")-[i IS "INITIATED"]->(t IS "TRANSACTION") '
+ 'WHERE c."RISK_SCORE" > 0.7 AND t."AMOUNT" > 1000 AND t."FRAUD_FLAG" = 1 '
+ 'COLUMNS (c."NAME" AS customer_name, t."TRANSACTION_ID" AS transaction_id, t."AMOUNT" AS amount)) gt'
+ ),
+ },
+ {
+ "question": "Which merchants received online transactions over 500?",
+ "query": (
+ f'SELECT * FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (a IS "ACCOUNT")-[i IS "INITIATED"]->(t IS "TRANSACTION")-[p IS "PAID"]->(m IS "MERCHANT") '
+ 'WHERE t."CHANNEL" = \'online\' AND t."AMOUNT" > 500 '
+ 'COLUMNS (m."NAME" AS merchant_name, m."CATEGORY" AS category, t."AMOUNT" AS amount)) gt'
+ ),
+ },
+ {
+ "question": "Which transactions used an untrusted device?",
+ "query": (
+ f'SELECT * FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (t IS "TRANSACTION")-[u IS "USED_DEVICE"]->(d IS "DEVICE") '
+ 'WHERE d."TRUSTED" = 0 '
+ 'COLUMNS (t."TRANSACTION_ID" AS transaction_id, t."AMOUNT" AS amount, d."DEVICE_TYPE" AS device_type)) gt'
+ ),
+ },
+ {
+ "question": "Which customers live in Casablanca?",
+ "query": (
+ f'SELECT * FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (c IS "CUSTOMER")-[l IS "LIVES_IN"]->(city IS "CITY") '
+ 'WHERE city."NAME" = \'Casablanca\' '
+ 'COLUMNS (c."NAME" AS customer_name, c."SEGMENT" AS segment)) gt'
+ ),
+ },
+ {
+ "question": "What merchant categories have fraudulent transactions?",
+ "query": (
+ f'SELECT gt.category, COUNT(*) AS fraud_transaction_count FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (t IS "TRANSACTION")-[p IS "PAID"]->(m IS "MERCHANT") '
+ 'WHERE t."FRAUD_FLAG" = 1 '
+ 'COLUMNS (m."CATEGORY" AS category, t."TRANSACTION_ID" AS transaction_id)) gt '
+ "GROUP BY gt.category"
+ ),
+ },
+ {
+ "question": "Which fraudulent graph transactions belong to customers whose base table risk score is above 0.7?",
+ "query": (
+ "WITH graph_tx AS ("
+ f'SELECT gt.customer_id, gt.customer_name, gt.transaction_id, gt.amount FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (c IS "CUSTOMER")-[o IS "OWNS"]->(a IS "ACCOUNT")-[i IS "INITIATED"]->(t IS "TRANSACTION") '
+ 'WHERE t."FRAUD_FLAG" = 1 '
+ 'COLUMNS (c."CUSTOMER_ID" AS customer_id, c."NAME" AS customer_name, '
+ 't."TRANSACTION_ID" AS transaction_id, t."AMOUNT" AS amount)) gt'
+ ") "
+ 'SELECT graph_tx.customer_name, c."RISK_SCORE" AS risk_score, graph_tx.transaction_id, graph_tx.amount '
+ 'FROM graph_tx JOIN "CUSTOMER" c ON c."CUSTOMER_ID" = graph_tx.customer_id '
+ 'WHERE c."RISK_SCORE" > 0.7'
+ ),
+ },
+ {
+ "question": "Combine high-risk customers from the customer table with customers found through fraudulent graph transactions.",
+ "query": (
+ 'SELECT \'HIGH_RISK_TABLE\' AS source_type, c."NAME" AS customer_name FROM "CUSTOMER" c '
+ 'WHERE c."RISK_SCORE" > 0.8 '
+ "UNION ALL "
+ f'SELECT \'FRAUD_GRAPH\' AS source_type, gt.customer_name FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (c IS "CUSTOMER")-[o IS "OWNS"]->(a IS "ACCOUNT")-[i IS "INITIATED"]->(t IS "TRANSACTION") '
+ 'WHERE t."FRAUD_FLAG" = 1 '
+ 'COLUMNS (c."NAME" AS customer_name)) gt'
+ ),
+ },
+ {
+ "question": "Rank fraudulent graph transactions by amount within each merchant category.",
+ "query": (
+ "SELECT gt.category, gt.transaction_id, gt.amount, "
+ "ROW_NUMBER() OVER (PARTITION BY gt.category ORDER BY gt.amount DESC) AS category_amount_rank "
+ f'FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (t IS "TRANSACTION")-[p IS "PAID"]->(m IS "MERCHANT") '
+ 'WHERE t."FRAUD_FLAG" = 1 '
+ 'COLUMNS (m."CATEGORY" AS category, t."TRANSACTION_ID" AS transaction_id, t."AMOUNT" AS amount)) gt '
+ "ORDER BY gt.category, category_amount_rank"
+ ),
+ },
+ {
+ "question": "Which merchants have at least one online graph transaction and also exist in the merchant table outside the graph?",
+ "query": (
+ f'SELECT m."NAME" AS merchant_name, m."CATEGORY" AS category '
+ 'FROM "MERCHANT" m '
+ "WHERE EXISTS ("
+ f'SELECT 1 FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (t IS "TRANSACTION")-[p IS "PAID"]->(merchant IS "MERCHANT") '
+ 'WHERE t."CHANNEL" = \'online\' '
+ 'COLUMNS (merchant."MERCHANT_ID" AS merchant_id)) gt '
+ 'WHERE gt.merchant_id = m."MERCHANT_ID")'
+ ),
+ },
+ {
+ "question": "Show one row per step along the customer-account-transaction-merchant payment path.",
+ "query": (
+ f'SELECT * FROM GRAPH_TABLE ("{GRAPH_NAME}" '
+ 'MATCH (c IS "CUSTOMER")-[o IS "OWNS"]->(a IS "ACCOUNT")-[i IS "INITIATED"]->'
+ '(t IS "TRANSACTION")-[p IS "PAID"]->(m IS "MERCHANT") '
+ "ONE ROW PER STEP (step_source, step_edge, step_target) "
+ "COLUMNS (MATCHNUM() AS match_number, ELEMENT_NUMBER(step_edge) AS step_number, "
+ "VERTEX_ID(step_source) AS source_id, EDGE_ID(step_edge) AS edge_id, "
+ "VERTEX_ID(step_target) AS target_id)) gt FETCH FIRST 10 ROWS ONLY"
+ ),
+ },
+ ]
+ seed_context = validator.execute_with_results(seed_queries)
+ if not seed_context:
+ raise RuntimeError("No fraud seed queries validated successfully; cannot generate corpus.")
+ save_json(seed_context, seed_validation_output_file)
+
+ # A corpus size of zero or less stops here, after the seed validation file is written.
+ target_size = env_int("FRAUD_ORACLE_SQLPGQ_CORPUS_SIZE", 10)
+ if target_size <= 0:
+ logger.info(
+ "Fraud seed validation only: %s seed queries validated successfully.",
+ len(seed_context),
+ )
+ return
+
+ num_per_iteration = env_int("FRAUD_ORACLE_SQLPGQ_NUM_PER_ITERATION", 3)
+ max_repair_attempts = env_int("FRAUD_ORACLE_SQLPGQ_REPAIR_ATTEMPTS", 1)
+ raw_pairs = generator.run_generation_loop(
+ schema_json=schema_json,
+ seeds_corpus_with_context=seed_context,
+ num_per_iteration=num_per_iteration,
+ complexity_corpus_size=target_size,
+ )
+ save_json(raw_pairs, raw_output_file)
+ valid_pairs, rejected_pairs = validate_and_repair_pairs(
+ generator=generator,
+ validator=validator,
+ schema_json=schema_json,
+ raw_pairs=raw_pairs,
+ max_repair_attempts=max_repair_attempts,
+ )
+ save_json(rejected_pairs, rejected_output_file)
+ save_json(valid_pairs, validated_output_file)
+ # The main corpus file keeps only the question and query fields of the valid pairs.
+ save_json(
+ [{"question": item["question"], "query": item["query"]} for item in valid_pairs],
+ output_file,
+ )
+ logger.info(
+ "Validated fraud Oracle SQL/PGQ corpus: %s valid, %s rejected",
+ len(valid_pairs),
+ len(rejected_pairs),
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/generate_oracle_sqlpgq_template_corpus.py b/examples/generate_oracle_sqlpgq_template_corpus.py
new file mode 100644
index 0000000..d249c52
--- /dev/null
+++ b/examples/generate_oracle_sqlpgq_template_corpus.py
@@ -0,0 +1,48 @@
+import argparse
+import json
+from pathlib import Path
+
+from app.impl.oracle_sqlpgq.generator.template_instantiator import (
+ OracleSqlPgqTemplateInstantiator,
+)
+
+
+def main() -> None:
+    """Build template corpus pairs from a schema manifest and save them as JSON.
+
+    Command-line options select the manifest file, the output file, the
+    requested number of pairs, and whether metadata is included in each pair.
+    The output file's parent directory is created when missing, and a one-line
+    summary is printed after the file has been written.
+    """
+    cli = argparse.ArgumentParser(
+        description="Generate deterministic Oracle SQL/PGQ corpus pairs from a schema manifest."
+    )
+    cli.add_argument(
+        "--manifest",
+        type=Path,
+        default=Path("examples/Oracle_SQLPGQ_Instance/TEXT2GQL_GRAPH_oracle_schema.json"),
+        help="Oracle SQL/PGQ schema manifest JSON.",
+    )
+    cli.add_argument(
+        "--output",
+        type=Path,
+        default=Path("examples/generated_corpus/oracle_sqlpgq_template_corpus.json"),
+        help="Output JSON file.",
+    )
+    cli.add_argument("--target-size", type=int, default=50)
+    cli.add_argument(
+        "--without-metadata",
+        action="store_true",
+        help="Write only question/query fields.",
+    )
+    options = cli.parse_args()
+    destination = options.output
+
+    # The flag is phrased negatively on the CLI, so invert it for the API call.
+    generated = OracleSqlPgqTemplateInstantiator.from_file(options.manifest).generate_dicts(
+        target_size=options.target_size,
+        include_metadata=not options.without_metadata,
+    )
+
+    destination.parent.mkdir(parents=True, exist_ok=True)
+    with destination.open("w", encoding="utf-8") as handle:
+        json.dump(generated, handle, indent=2, ensure_ascii=False)
+    print(f"Saved {len(generated)} Oracle SQL/PGQ template pairs to {destination}")
+
+
+# Script entry point: call main() only when this file is executed directly,
+# not when it is imported.
+if __name__ == "__main__":
+    main()
+
diff --git a/examples/generated_corpus/cypher_templates_to_oracle_sqlpgq.json b/examples/generated_corpus/cypher_templates_to_oracle_sqlpgq.json
new file mode 100644
index 0000000..518102c
--- /dev/null
+++ b/examples/generated_corpus/cypher_templates_to_oracle_sqlpgq.json
@@ -0,0 +1,202 @@
+[
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.prop_1 = \"value_1\" RETURN n.prop_1 LIMIT 10",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") WHERE n.\"prop_1\" = \"value_1\" COLUMNS (n.\"prop_1\" AS prop_1)\n) gt\nFETCH FIRST 10 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.prop_1 = \"value_1\" RETURN n.id, n.prop_1 LIMIT 10",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") WHERE n.\"prop_1\" = \"value_1\" COLUMNS (n.\"id\" AS id, n.\"prop_1\" AS prop_1)\n) gt\nFETCH FIRST 10 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.age > 20 RETURN n.name, n.age, n.email",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") WHERE n.\"age\" > 20 COLUMNS (n.\"name\" AS name, n.\"age\" AS age, n.\"email\" AS email)\n) gt",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.price >= 100 AND n.price <= 500 RETURN n.product_id, n.price",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.name STARTS WITH \"A\" RETURN n.id, n.name",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.description CONTAINS \"tech\" RETURN n.description LIMIT 5",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.email IS NOT NULL RETURN n.id, n.email",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.status IN [\"active\", \"pending\"] RETURN n.id, n.status",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (a:label_1)-[r:edge_1]->(b:label_2) RETURN a.name, b.title LIMIT 20",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (a IS \"label_1\")-[r IS \"edge_1\"]->(b IS \"label_2\") COLUMNS (a.\"name\" AS name, b.\"title\" AS title)\n) gt\nFETCH FIRST 20 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (a:label_1)-[:edge_1]->(b:label_2) WHERE b.status = \"active\" RETURN a.id, a.name",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (a IS \"label_1\")-[e1 IS \"edge_1\"]->(b IS \"label_2\") WHERE b.\"status\" = \"active\" COLUMNS (a.\"id\" AS id, a.\"name\" AS name)\n) gt",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (a:label_1)-[r:edge_1]->(b:label_2) WHERE r.weight > 0.8 RETURN a.id, r.weight, b.id",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (a IS \"label_1\")-[r IS \"edge_1\"]->(b IS \"label_2\") WHERE r.\"weight\" > 0.8 COLUMNS (a.\"id\" AS id, r.\"weight\" AS weight, b.\"id\" AS id)\n) gt",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (a:label_1)--(b:label_2) WHERE a.name = \"value_1\" RETURN b.id, b.name LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (a IS \"label_1\")-[e1]-(b IS \"label_2\") WHERE a.\"name\" = \"value_1\" COLUMNS (b.\"id\" AS id, b.\"name\" AS name)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.type = \"A\" OR n.type = \"B\" RETURN n.id, n.type",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE NOT n.status = \"banned\" RETURN n.id, n.status LIMIT 10",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") WHERE n.\"status\" = \"banned\" COLUMNS (n.\"id\" AS id, n.\"status\" AS status)\n) gt\nFETCH FIRST 10 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.prop_1 = \"value_1\" RETURN count(n.id)",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") WHERE n.\"prop_1\" = \"value_1\" COLUMNS (COUNT(n.\"id\") AS id)\n) gt",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.score > 60 RETURN n.name, n.score ORDER BY n.score DESC LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") WHERE n.\"score\" > 60 COLUMNS (n.\"name\" AS name, n.\"score\" AS score)\n) gt\nORDER BY score DESC\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) RETURN DISTINCT n.category",
+ "oracle_sqlpgq": "SELECT DISTINCT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") COLUMNS (n.\"category\" AS category)\n) gt",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n:label_1) WHERE n.prop_1 = \"value_1\" RETURN n.prop_1 LIMIT 1",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n IS \"label_1\") WHERE n.\"prop_1\" = \"value_1\" COLUMNS (n.\"prop_1\" AS prop_1)\n) gt\nFETCH FIRST 1 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH p = (n1:label_1)-[e1]-(x)-[e2]-(n2:label_1) WHERE n1.prop_1 = \"value_1\" AND n2.prop_1 <> \"value_1\" RETURN p LIMIT 1",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[e*1..3]-(n2:label_2) WHERE n1.prop_1 = \"value_1\" RETURN n1, n2 LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e]-{1,3}(n2 IS \"label_2\") WHERE n1.\"prop_1\" = \"value_1\" COLUMNS (VERTEX_ID(n1) AS n1_VALUE, VERTEX_ID(n2) AS n2_VALUE)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[e*1..3]-(n2:label_2) WHERE n1.prop_1 = \"value_1\" RETURN n1.prop_1, n2.prop_1 LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e]-{1,3}(n2 IS \"label_2\") WHERE n1.\"prop_1\" = \"value_1\" COLUMNS (n1.\"prop_1\" AS prop_1, n2.\"prop_1\" AS prop_1)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[*1..3]->(n2:label_2) WHERE n1.prop_1 = \"value_1\" RETURN n2.id, n2.name LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1]->{1,3}(n2 IS \"label_2\") WHERE n1.\"prop_1\" = \"value_1\" COLUMNS (n2.\"id\" AS id, n2.\"name\" AS name)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[*2..5]->(n2:label_2) WHERE n1.id = \"start_node\" RETURN n2.name LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1]->{2,5}(n2 IS \"label_2\") WHERE n1.\"id\" = \"start_node\" COLUMNS (n2.\"name\" AS name)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[*3]->(n2:label_2) WHERE n1.prop_1 = \"value_1\" RETURN n2.id, n2.prop_2 LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1]->{3}(n2 IS \"label_2\") WHERE n1.\"prop_1\" = \"value_1\" COLUMNS (n2.\"id\" AS id, n2.\"prop_2\" AS prop_2)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[:edge_1*1..3]->(n2:label_2) WHERE n1.name = \"Alice\" RETURN n2.name LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1 IS \"edge_1\"]->{1,3}(n2 IS \"label_2\") WHERE n1.\"name\" = \"Alice\" COLUMNS (n2.\"name\" AS name)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[:edge_1|edge_2*1..4]->(n2:label_2) RETURN n1.id, n2.id LIMIT 10",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1 IS \"edge_1\"]->{1,4}(n2 IS \"label_2\") COLUMNS (n1.\"id\" AS id, n2.\"id\" AS id)\n) gt\nFETCH FIRST 10 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[*1..3]->(n2:label_2) WHERE n1.status = \"active\" AND n2.age > 30 RETURN n1.id, n2.id LIMIT 5",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[*1..2]-(n2:label_2) WHERE n1.id = \"user_123\" RETURN n2.id, n2.name LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1]-{1,2}(n2 IS \"label_2\") WHERE n1.\"id\" = \"user_123\" COLUMNS (n2.\"id\" AS id, n2.\"name\" AS name)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[*1..3]->(n2:label_2) WHERE n1.prop_1 = \"root\" RETURN count(n2.id) LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1]->{1,3}(n2 IS \"label_2\") WHERE n1.\"prop_1\" = \"root\" COLUMNS (COUNT(n2.\"id\") AS id)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[*1..3]->(n2:label_2) WHERE n1.prop_1 = \"root\" RETURN count(DISTINCT n2.id) LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1]->{1,3}(n2 IS \"label_2\") WHERE n1.\"prop_1\" = \"root\" COLUMNS (COUNT(n2.\"id\") AS id)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[e:label_2]->(n2:label_3) WHERE n1.prop_1 = \"value_1\" RETURN n1, e, n2 LIMIT 10",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e IS \"label_2\"]->(n2 IS \"label_3\") WHERE n1.\"prop_1\" = \"value_1\" COLUMNS (VERTEX_ID(n1) AS n1_VALUE, EDGE_ID(e) AS e_VALUE, VERTEX_ID(n2) AS n2_VALUE)\n) gt\nFETCH FIRST 10 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH p=(n1:label_1)-[e:label_2*1..3]->(n2:label_1) WHERE n1.prop_1 = \"value_1\" AND n2.prop_1 = \"value_3\" RETURN n2.prop_1 LIMIT 5",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH p = (n1:label_1)-[e1]-(x)-[e2]-(n2:label_1) WHERE n1.prop_1 = \"value_1\" AND n2.prop_1 <> \"value_1\" RETURN p LIMIT 5 UNION ALL MATCH p = (n1:label_1)-[e1]-(x)-[e2]-(y)-[e3]-(n2:label_1) WHERE n1.prop_1 = \"value_1\" AND n2.prop_1 <> \"value_1\" RETURN p LIMIT 5",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[e:label_2|label_3]->(n2:label_4) RETURN e, n2 LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e IS \"label_2\"]->(n2 IS \"label_4\") COLUMNS (EDGE_ID(e) AS e_VALUE, VERTEX_ID(n2) AS n2_VALUE)\n) gt\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n2:label_4)<-[e4:type_edge]-(n4:type_node) WHERE n4.name = \"Alice\" RETURN COUNT(e4), COUNT(n4), e4.type ORDER BY COUNT(e4) DESC SKIP 0 LIMIT 5",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n2 IS \"label_4\")<-[e4 IS \"type_edge\"]-(n4 IS \"type_node\") WHERE n4.\"name\" = \"Alice\" COLUMNS (COUNT(EDGE_ID(e4)) AS e4_VALUE, COUNT(VERTEX_ID(n4)) AS n4_VALUE, e4.\"type\" AS type)\n) gt\nORDER BY e4_VALUE DESC\nOFFSET 0 ROWS\nFETCH FIRST 5 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (n1:label_1)-[]->(n2:label_2) WHERE n1.prop_1 = \"value_1\" RETURN COUNT(DISTINCT n2) as count_alias",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (n1 IS \"label_1\")-[e1]->(n2 IS \"label_2\") WHERE n1.\"prop_1\" = \"value_1\" COLUMNS (COUNT(VERTEX_ID(n2)) AS count_alias)\n) gt",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH p = shortestPath((a:label_1)-[:label_2*1..10]->(b:label_3)) WHERE a.prop_1 = \"value_1\" RETURN p, length(p) AS depth ORDER BY depth ASC LIMIT 5",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Not Comply with OpenCypher"
+ },
+ {
+ "cypher": "MATCH (a:label_1)-[:label_2]->(b:label_3) WHERE a.prop_1 = \"value_1\" OPTIONAL MATCH (b)<-[:label_4]-(c:label_5) OPTIONAL MATCH (b)<-[:label_6]-(d:label_7) RETURN DISTINCT c.prop_2, d.prop_3 LIMIT 10",
+ "oracle_sqlpgq": "SELECT DISTINCT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (a IS \"label_1\")-[e1 IS \"label_2\"]->(b IS \"label_3\"), (b)<-[e2 IS \"label_4\"]-(c IS \"label_5\"), (b)<-[e3 IS \"label_6\"]-(d IS \"label_7\") WHERE a.\"prop_1\" = \"value_1\" COLUMNS (c.\"prop_2\" AS prop_2, d.\"prop_3\" AS prop_3)\n) gt\nFETCH FIRST 10 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (nodes_1:label_1) WHERE nodes_1.prop_1 = \"value_1\" WITH nodes_1 MATCH (nodes_1)-[edges_1:label_2]->(nodes_2:label_3) RETURN nodes_2, nodes_1, edges_1 LIMIT 10",
+ "oracle_sqlpgq": "SELECT *\nFROM GRAPH_TABLE (\n \"TEXT2GQL_GRAPH\" MATCH (nodes_1)-[edges_1 IS \"label_2\"]->(nodes_2 IS \"label_3\") COLUMNS (VERTEX_ID(nodes_2) AS nodes_2_VALUE, VERTEX_ID(nodes_1) AS nodes_1_VALUE, EDGE_ID(edges_1) AS edges_1_VALUE)\n) gt\nFETCH FIRST 10 ROWS ONLY",
+ "category": "Graph-IL Translatable"
+ },
+ {
+ "cypher": "MATCH (a:label_1 {prop_1: \"value_1\"})-[e:label_2]->(n:label_3) WHERE e.prop_2 = \"value_2\" AND e.prop_4 > \"value_3\" RETURN n, e.prop_5 LIMIT 10",
+ "oracle_sqlpgq": "Unable to Translate to Oracle SQL/PGQ",
+ "category": "Graph-IL Not Support"
+ }
+]
\ No newline at end of file
diff --git a/examples/setup_oracle_sqlpgq_example_db.py b/examples/setup_oracle_sqlpgq_example_db.py
new file mode 100644
index 0000000..f0f2739
--- /dev/null
+++ b/examples/setup_oracle_sqlpgq_example_db.py
@@ -0,0 +1,544 @@
+"""
+Create and populate the example Oracle SQL/PGQ graph used for corpus generation.
+
+Required environment variables:
+- ORACLE_DSN
+- ORACLE_USER
+- ORACLE_PASSWORD
+
+Size can be controlled with CLI flags or matching environment variables, for example:
+
+ORACLE_SQLPGQ_USERS=100 ORACLE_SQLPGQ_MOVIES=500 ORACLE_SQLPGQ_RATINGS=2000 \
+ python examples/setup_oracle_sqlpgq_example_db.py
+"""
+
+import argparse
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+import os
+from pathlib import Path
+import random
+from typing import Any
+
+from app.core.validator.db_client import QueryStatus
+from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient
+
+
+# Name of the property graph: reset_objects() drops it and main() reports it.
+GRAPH_NAME = "TEXT2GQL_GRAPH"
+# Directory holding the SQL scripts that main() runs through run_script().
+ARTIFACT_DIR = Path("examples/Oracle_SQLPGQ_Instance")
+# Script main() runs to create the tables, before any rows are inserted.
+TABLE_DDL_PATH = ARTIFACT_DIR / "TEXT2GQL_GRAPH_oracle_tables.sql"
+# Script main() runs to create the property graph, after the rows are inserted.
+GRAPH_DDL_PATH = ARTIFACT_DIR / "TEXT2GQL_GRAPH_oracle_property_graph.sql"
+# One row of positional bind values for a single INSERT execution.
+Row = tuple[Any, ...]
+
+# Tables dropped by reset_objects(), in this order. The first seven are the
+# tables load_data() fills through "SRC_ID"/"DST_ID" columns; the last five are
+# the tables it fills through a "<TABLE>_id" column.
+# NOTE(review): the DROP statement uses CASCADE CONSTRAINTS, so this ordering
+# is presumably for clarity rather than strictly required -- confirm.
+TABLE_DROP_ORDER = [
+    "USER_FRIENDS_WITH_USER",
+    "MOVIE_SIMILAR_TO_MOVIE",
+    "TAG_TAGS_MOVIE",
+    "USER_TAGS_TAG",
+    "MOVIE_BELONGS_TO_GENRE",
+    "RATING_RATES_MOVIE",
+    "USER_RATES_RATING",
+    "TAG",
+    "RATING",
+    "GENRE",
+    "MOVIE",
+    "USER",
+]
+
+# Hand-written movies used before any synthetic ones are generated.
+# Tuple layout: (title, release_year, duration, plot_summary, language);
+# generate_movies() prepends the id (10 + position in this list).
+BASE_MOVIES = [
+    ("Inception", 2010, 148, "A thief enters dreams to steal secrets.", "English"),
+    ("The Matrix", 1999, 136, "A hacker discovers a simulated reality.", "English"),
+    ("Arrival", 2016, 116, "A linguist communicates with alien visitors.", "English"),
+    ("Spirited Away", 2001, 125, "A girl enters a world of spirits.", "Japanese"),
+]
+# Hand-written genres. Tuple layout: (name, description);
+# generate_genres() prepends the id (20 + position in this list).
+BASE_GENRES = [
+    ("Science Fiction", "Speculative stories about science and technology"),
+    ("Action", "Fast-paced stories with intense conflict"),
+    ("Drama", "Character-driven serious stories"),
+    ("Animation", "Animated feature films"),
+    ("Comedy", "Humorous stories and situations"),
+    ("Thriller", "Suspenseful stories with tension"),
+    ("Fantasy", "Stories with magic or imaginary worlds"),
+    ("Documentary", "Non-fiction film storytelling"),
+    ("Romance", "Stories centered on relationships"),
+    ("Horror", "Stories intended to frighten or unsettle"),
+    ("Adventure", "Journeys, exploration, and discovery"),
+]
+# Hand-written tag texts; generate_tags() adds the id (200 + position), a
+# timestamp and a relevance score to each.
+BASE_TAGS = [
+    "mind-bending",
+    "classic",
+    "time travel",
+    "animated",
+    "Sci-Fi",
+    "action-packed",
+    "thoughtful",
+    "dystopian",
+    "alien contact",
+    "dream logic",
+    "cyberpunk",
+    "family",
+    "award-winning",
+    "cult favorite",
+    "visual effects",
+    "slow burn",
+]
+
+
+@dataclass(frozen=True)
+class DataConfig:
+    """Immutable size and seeding options for the generated example data.
+
+    Built by parse_args() and checked by validate_config(). The field order is
+    the order in which print_config() lists the values.
+    """
+
+    # Number of rows generated for the "USER" table.
+    users: int
+    # Number of rows generated for the "MOVIE" table.
+    movies: int
+    # Number of rows generated for the "GENRE" table.
+    genres: int
+    # Number of rows generated for the "TAG" table.
+    tags: int
+    # Number of ratings; each also yields one USER_RATES_RATING row and one
+    # RATING_RATES_MOVIE row.
+    ratings: int
+    # Number of MOVIE_BELONGS_TO_GENRE rows.
+    genre_edges: int
+    # Number of rows generated for each of USER_TAGS_TAG and TAG_TAGS_MOVIE.
+    tag_edges: int
+    # Number of USER_FRIENDS_WITH_USER rows.
+    friend_edges: int
+    # Number of MOVIE_SIMILAR_TO_MOVIE rows.
+    similarity_edges: int
+    # Seed for the random.Random instance created in load_data().
+    seed: int
+    # Maximum number of rows sent per executemany() call in insert_rows().
+    batch_size: int
+
+
+def env_int(name: str, default: int) -> int:
+    """Read an integer from the environment variable ``name``.
+
+    An unset variable yields ``default``. A variable that is set but empty or
+    whitespace-only is also treated as unset, which matches how
+    env_int_optional() handles an empty value instead of failing on ``int("")``.
+
+    Args:
+        name: Environment variable to read.
+        default: Value returned when the variable is unset or blank.
+
+    Returns:
+        The parsed integer, or ``default``.
+
+    Raises:
+        ValueError: If the variable holds text that is not an integer; the
+            message names the offending variable.
+    """
+    raw = os.getenv(name)
+    if raw is None or not raw.strip():
+        return default
+    try:
+        return int(raw)
+    except ValueError as exc:
+        raise ValueError(
+            f"Environment variable {name} must be an integer, got {raw!r}"
+        ) from exc
+
+
+def env_int_optional(name: str) -> int | None:
+    """Read an optional integer from the environment variable ``name``.
+
+    Returns ``None`` when the variable is unset or set to the empty string;
+    otherwise returns ``int(value)``, which raises ``ValueError`` for text that
+    is not an integer.
+    """
+    raw = os.getenv(name)
+    if raw is None or raw == "":
+        return None
+    return int(raw)
+
+
+def positive_int(value: str) -> int:
+    """Convert ``value`` to an int for argparse, accepting only values above zero.
+
+    Raises:
+        argparse.ArgumentTypeError: If the parsed number is zero or negative.
+        ValueError: If ``value`` is not an integer literal (raised by ``int``).
+    """
+    number = int(value)
+    if number > 0:
+        return number
+    raise argparse.ArgumentTypeError("value must be positive")
+
+
+def parse_args() -> DataConfig:
+    """Parse the command line into a validated DataConfig.
+
+    Every flag falls back to a matching ``ORACLE_SQLPGQ_*`` environment
+    variable and then to a built-in default. ``--genre-edges`` and
+    ``--tag-edges`` have no fixed default: when they are not supplied they are
+    derived from the movie and tag counts.
+
+    Returns:
+        The configuration, after validate_config() has accepted it.
+
+    Raises:
+        ValueError: Propagated from validate_config() when a size is below the
+            required minimum, or from the environment helpers for bad values.
+    """
+    parser = argparse.ArgumentParser()
+
+    # Row-count flags share one shape: (flag, environment variable, fallback).
+    for flag, env_name, fallback in (
+        ("--users", "ORACLE_SQLPGQ_USERS", 4),
+        ("--movies", "ORACLE_SQLPGQ_MOVIES", 4),
+        ("--genres", "ORACLE_SQLPGQ_GENRES", 4),
+        ("--tags", "ORACLE_SQLPGQ_TAGS", 6),
+        ("--ratings", "ORACLE_SQLPGQ_RATINGS", 4),
+    ):
+        parser.add_argument(flag, type=positive_int, default=env_int(env_name, fallback))
+
+    parser.add_argument(
+        "--genre-edges",
+        type=positive_int,
+        default=env_int_optional("ORACLE_SQLPGQ_GENRE_EDGES"),
+        help="MOVIE->GENRE BELONGS_TO edges; default is max(movies + 1, 5)",
+    )
+    parser.add_argument(
+        "--tag-edges",
+        type=positive_int,
+        default=env_int_optional("ORACLE_SQLPGQ_TAG_EDGES"),
+        help="USER->TAG and TAG->MOVIE edge pairs; default is max(tags, 6)",
+    )
+    parser.add_argument(
+        "--friend-edges",
+        type=positive_int,
+        default=env_int("ORACLE_SQLPGQ_FRIEND_EDGES", 3),
+    )
+    parser.add_argument(
+        "--similarity-edges",
+        type=positive_int,
+        default=env_int("ORACLE_SQLPGQ_SIMILARITY_EDGES", 3),
+    )
+    parser.add_argument("--seed", type=int, default=env_int("ORACLE_SQLPGQ_DATA_SEED", 7))
+    parser.add_argument(
+        "--batch-size",
+        type=positive_int,
+        default=env_int("ORACLE_SQLPGQ_INSERT_BATCH_SIZE", 1000),
+    )
+    options = parser.parse_args()
+
+    # A missing (or falsy) edge count falls back to a size derived from the
+    # corresponding vertex count.
+    genre_edges = options.genre_edges if options.genre_edges else max(options.movies + 1, 5)
+    tag_edges = options.tag_edges if options.tag_edges else max(options.tags, 6)
+
+    config = DataConfig(
+        users=options.users,
+        movies=options.movies,
+        genres=options.genres,
+        tags=options.tags,
+        ratings=options.ratings,
+        genre_edges=genre_edges,
+        tag_edges=tag_edges,
+        friend_edges=options.friend_edges,
+        similarity_edges=options.similarity_edges,
+        seed=options.seed,
+        batch_size=options.batch_size,
+    )
+    validate_config(config)
+    return config
+
+
+def validate_config(config: DataConfig) -> None:
+    """Reject configurations smaller than the hard-coded base rows require.
+
+    Raises:
+        ValueError: Listing every field that is below its minimum, together
+            with the value that was supplied.
+    """
+    minimums = {
+        "users": 4,
+        "movies": 4,
+        "genres": 4,
+        "tags": 6,
+        "ratings": 4,
+        "genre_edges": 5,
+        "tag_edges": 6,
+        "friend_edges": 3,
+        "similarity_edges": 3,
+    }
+    supplied = vars(config)
+    violations = []
+    for field_name, floor in minimums.items():
+        actual = supplied[field_name]
+        if actual < floor:
+            violations.append(f"{field_name}>={floor} (got {actual})")
+    if not violations:
+        return
+    raise ValueError(
+        "The example corpus seed queries require these minimum sizes: "
+        + ", ".join(violations)
+    )
+
+
+def timestamp(month: int, day_offset: int, hour: int = 9) -> datetime:
+    """Return the first day of ``month`` in 2025 at ``hour``:00, shifted by ``day_offset`` days."""
+    anchor = datetime(year=2025, month=month, day=1, hour=hour, minute=0, second=0)
+    shift = timedelta(days=day_offset)
+    return anchor + shift
+
+
+def safe_execute(client: OracleDBClient, sql: str, ignore_errors: tuple[str, ...] = ()) -> None:
+    """Run one statement and raise unless it succeeds or fails tolerably.
+
+    Args:
+        client: Database client whose ``execute_query`` runs the statement.
+        sql: Statement text to execute.
+        ignore_errors: Substrings; a failure whose error text contains any of
+            them is ignored silently.
+
+    Raises:
+        RuntimeError: For any other failure, carrying the statement and error.
+    """
+    outcome = client.execute_query(sql)
+    if outcome.status_code == QueryStatus.SUCCESS:
+        return
+    message = outcome.error
+    tolerated = bool(message) and any(token in message for token in ignore_errors)
+    if not tolerated:
+        raise RuntimeError(f"Failed SQL:\n{sql}\n\nError:\n{message}")
+
+
+def reset_objects(client: OracleDBClient) -> None:
+    """Drop the property graph and then every table in TABLE_DROP_ORDER.
+
+    Failures whose error text contains one of the listed ORA codes are ignored.
+    NOTE(review): those codes presumably mean "object does not exist", which
+    would make a first run on an empty schema succeed -- confirm.
+    """
+    # Build the full drop plan first: (statement, tolerated error substrings).
+    drop_plan = [
+        (
+            f'DROP PROPERTY GRAPH "{GRAPH_NAME}"',
+            ("ORA-42421", "ORA-04043", "ORA-00942"),
+        )
+    ]
+    drop_plan.extend(
+        (f'DROP TABLE "{table_name}" CASCADE CONSTRAINTS PURGE', ("ORA-00942",))
+        for table_name in TABLE_DROP_ORDER
+    )
+    for statement, tolerated in drop_plan:
+        safe_execute(client, statement, ignore_errors=tolerated)
+
+
+def run_script(client: OracleDBClient, script_path: Path) -> None:
+    """Execute the SQL script at ``script_path`` and stop at the first failure.
+
+    Raises:
+        RuntimeError: As soon as one result reported by
+            ``client.execute_script`` is not successful.
+    """
+    script_text = script_path.read_text(encoding="utf-8")
+    for outcome in client.execute_script(script_text):
+        if outcome.status_code == QueryStatus.SUCCESS:
+            continue
+        raise RuntimeError(f"Failed script {script_path}: {outcome.error}")
+
+
+def insert_rows(
+    client: OracleDBClient,
+    statement: str,
+    rows: list[Row],
+    label: str,
+    batch_size: int,
+) -> None:
+    """Insert ``rows`` with ``statement`` in slices of at most ``batch_size``.
+
+    Args:
+        client: Database client whose ``executemany`` performs the inserts.
+        statement: INSERT statement with positional bind placeholders.
+        rows: Bind-value tuples, one per row.
+        label: Name used in the error and progress messages.
+        batch_size: Maximum number of rows per ``executemany`` call.
+
+    Raises:
+        RuntimeError: When a batch does not report success.
+    """
+    batches = (
+        rows[offset : offset + batch_size] for offset in range(0, len(rows), batch_size)
+    )
+    for batch in batches:
+        outcome = client.executemany(statement, batch)
+        if outcome.status_code != QueryStatus.SUCCESS:
+            raise RuntimeError(f"Failed loading {label}: {outcome.error}")
+    print(f"Loaded {len(rows)} rows into {label}")
+
+
+def generate_users(config: DataConfig) -> list[Row]:
+    """Return ``config.users`` user rows: four fixed users, then synthetic ones.
+
+    Row layout: (id, age, gender, occupation, zip_code, registration_date).
+    Synthetic values are derived arithmetically from the user id, so the
+    result does not depend on any random state.
+    """
+    seeded = [
+        (1, 34, "F", "data scientist", "10001", datetime(2024, 1, 10, 9, 0, 0)),
+        (2, 41, "M", "teacher", "94105", datetime(2024, 2, 12, 12, 30, 0)),
+        (3, 28, "F", "designer", "60601", datetime(2024, 3, 3, 16, 45, 0)),
+        (4, 52, "M", "Engineer", "02139", datetime(2024, 4, 22, 8, 15, 0)),
+    ]
+    job_titles = ["analyst", "writer", "doctor", "artist", "developer", "researcher"]
+    first_day = datetime(2024, 1, 1, 8, 0, 0)
+
+    users = seeded[: config.users]
+    # Ids continue right after the fixed users that were kept.
+    users.extend(
+        (
+            uid,
+            18 + (uid * 7) % 55,
+            "F" if uid % 2 else "M",
+            job_titles[uid % len(job_titles)],
+            f"{10000 + uid:05d}",
+            first_day + timedelta(days=uid),
+        )
+        for uid in range(len(users) + 1, config.users + 1)
+    )
+    return users
+
+
+def generate_movies(config: DataConfig) -> list[Row]:
+    """Return ``config.movies`` movie rows: BASE_MOVIES first, then synthetic titles.
+
+    Row layout: (id, title, release_year, duration, plot_summary, language),
+    with ids starting at 10.
+    """
+    adjectives = ["Hidden", "Silent", "Neon", "Lost", "Quantum", "Midnight", "Parallel"]
+    nouns = ["Signal", "Journey", "Archive", "Horizon", "City", "Memory", "Protocol"]
+
+    movies = [
+        (10 + position, *details)
+        for position, details in enumerate(BASE_MOVIES[: config.movies])
+    ]
+    for position in range(len(movies), config.movies):
+        adjective = adjectives[position % len(adjectives)]
+        noun = nouns[position % len(nouns)]
+        title = f"{adjective} {noun} {position + 1}"
+        # Every fifth synthetic position (position % 5 == 0) is Spanish.
+        language = "English" if position % 5 else "Spanish"
+        movies.append(
+            (
+                10 + position,
+                title,
+                1980 + (position * 3) % 45,
+                85 + (position * 11) % 70,
+                f"Synthetic plot summary for {title}.",
+                language,
+            )
+        )
+    return movies
+
+
+def generate_genres(config: DataConfig) -> list[Row]:
+    """Return ``config.genres`` genre rows: BASE_GENRES first, then placeholders.
+
+    Row layout: (id, name, description), with ids starting at 20.
+    """
+    genres = [
+        (20 + position, *details)
+        for position, details in enumerate(BASE_GENRES[: config.genres])
+    ]
+    genres += [
+        (20 + position, f"Genre {position + 1}", "Synthetic generated genre")
+        for position in range(len(genres), config.genres)
+    ]
+    return genres
+
+
+def generate_tags(config: DataConfig) -> list[Row]:
+    """Return ``config.tags`` tag rows: BASE_TAGS first, then placeholders.
+
+    Row layout: (id, text_content, timestamp, relevance_score), with ids
+    starting at 200. Base tags get a score that falls by 0.01 per position
+    from 0.96; placeholder tags all get 0.5.
+    """
+    tags = [
+        (200 + position, text, timestamp(2, position), round(0.96 - position * 0.01, 3))
+        for position, text in enumerate(BASE_TAGS[: config.tags])
+    ]
+    tags += [
+        (200 + position, f"tag-{position + 1}", timestamp(2, position), 0.5)
+        for position in range(len(tags), config.tags)
+    ]
+    return tags
+
+
+def generate_ratings(config: DataConfig) -> tuple[list[Row], list[Row], list[Row]]:
+    """Build the rating rows plus the two edge lists that attach them.
+
+    Users and movies are assigned round-robin, so rating ``i`` belongs to user
+    ``1 + i % users`` and movie ``10 + i % movies``.
+
+    Returns:
+        ``(ratings, user_rating_edges, rating_movie_edges)``, each with
+        ``config.ratings`` rows; rating ids start at 100.
+    """
+    scores = [4.8, 4.6, 4.4, 4.9, 3.9, 4.2, 4.7, 3.5, 4.1, 4.0]
+    reviews = [
+        "Inventive and visually impressive",
+        "A defining cyberpunk classic",
+        "Quiet and thoughtful science fiction",
+        "Beautiful and imaginative",
+        "Good pacing and strong performances",
+    ]
+    rating_rows: list = []
+    user_links: list = []
+    movie_links: list = []
+    for position in range(config.ratings):
+        rating_id = 100 + position
+        owner_id = 1 + position % config.users
+        film_id = 10 + position % config.movies
+        score = scores[position % len(scores)]
+        rated_at = timestamp(1, position, hour=10 + position % 10)
+        # Both edges of one rating carry the same weight.
+        weight = round(0.75 + (position % 20) / 100, 3)
+        rating_rows.append((rating_id, score, rated_at, reviews[position % len(reviews)]))
+        user_links.append((owner_id, rating_id, rated_at, score, weight))
+        movie_links.append((rating_id, film_id, rated_at, score, weight))
+    return rating_rows, user_links, movie_links
+
+
+def generate_genre_edges(config: DataConfig, rng: random.Random) -> list[Row]:
+    """Return movie-to-genre rows: five fixed ones, then random ones.
+
+    Row layout: (movie_id, genre_id, primary_classification, strength_score).
+    Random rows are added until the list holds ``config.genre_edges`` rows; the
+    five fixed rows are always present.
+    NOTE(review): random rows may repeat a (movie, genre) pair -- confirm the
+    table allows duplicates.
+    """
+    edges = [
+        (10, 20, 1, 0.99),
+        (11, 20, 1, 0.94),
+        (11, 21, 0, 0.88),
+        (12, 22, 1, 0.91),
+        (13, 23, 1, 0.99),
+    ]
+    missing = config.genre_edges - len(edges)
+    for _ in range(missing):
+        # Draw in a fixed order so a given seed always yields the same rows.
+        movie_id = 10 + rng.randrange(config.movies)
+        genre_id = 20 + rng.randrange(config.genres)
+        is_primary = 1 if rng.random() > 0.35 else 0
+        strength = round(0.55 + rng.random() * 0.44, 3)
+        edges.append((movie_id, genre_id, is_primary, strength))
+    return edges
+
+
+def generate_tag_edges(config: DataConfig, rng: random.Random) -> tuple[list[Row], list[Row]]:
+    """Return paired user-to-tag and tag-to-movie rows.
+
+    Each of the ``config.tag_edges`` iterations yields one row for each list,
+    and both rows share the same timestamp, confidence and visibility. The
+    first six (user, tag, movie) triples are fixed; later ones are random.
+
+    Returns:
+        ``(user_tag_edges, tag_movie_edges)``.
+    """
+    fixed_triples = [
+        (1, 200, 10),
+        (2, 201, 11),
+        (3, 202, 12),
+        (4, 203, 13),
+        (1, 204, 11),
+        (2, 205, 11),
+    ]
+    user_links: list = []
+    movie_links: list = []
+    for position in range(config.tag_edges):
+        if position >= len(fixed_triples):
+            # Draw in a fixed order so a given seed always yields the same rows.
+            triple = (
+                1 + rng.randrange(config.users),
+                200 + rng.randrange(config.tags),
+                10 + rng.randrange(config.movies),
+            )
+        else:
+            triple = fixed_triples[position]
+        owner_id, tag_id, film_id = triple
+        tagged_at = timestamp(2, position, hour=9 + position % 8)
+        confidence = round(0.6 + rng.random() * 0.39, 3)
+        visibility = 1 if rng.random() > 0.2 else 0
+        shared = (tagged_at, confidence, visibility)
+        user_links.append((owner_id, tag_id, *shared))
+        movie_links.append((tag_id, film_id, *shared))
+    return user_links, movie_links
+
+
+def generate_similarity_edges(config: DataConfig, rng: random.Random) -> list[Row]:
+    """Return movie-to-movie similarity rows: three fixed ones, then random ones.
+
+    Row layout: (src_movie, dst_movie, similarity_score, algorithm_version,
+    update_timestamp). When more than one movie exists, random rows never link
+    a movie to itself.
+    """
+    computed_at = timestamp(3, 0, hour=0)
+    edges = [
+        (10, 11, 0.87, "v1", computed_at),
+        (10, 12, 0.76, "v1", computed_at),
+        (11, 10, 0.87, "v1", computed_at),
+    ]
+    position = len(edges)
+    while position < config.similarity_edges:
+        source = 10 + rng.randrange(config.movies)
+        target = 10 + rng.randrange(config.movies)
+        # Re-draw the target until it differs; impossible with a single movie.
+        while config.movies > 1 and target == source:
+            target = 10 + rng.randrange(config.movies)
+        score = round(0.45 + rng.random() * 0.5, 3)
+        edges.append((source, target, score, "v1", timestamp(3, position)))
+        position += 1
+    return edges
+
+
+def generate_friend_edges(config: DataConfig, rng: random.Random) -> list[Row]:
+    """Return user-to-user friendship rows: three fixed ones, then random ones.
+
+    Row layout: (src_user, dst_user, connection_strength,
+    mutual_interests_count, last_interaction). When more than one user exists,
+    random rows never link a user to themselves.
+    """
+    edges = [
+        (1, 2, 0.72, 3, timestamp(3, 9, hour=18)),
+        (2, 3, 0.64, 2, timestamp(3, 10, hour=18)),
+        (3, 4, 0.68, 4, timestamp(3, 11, hour=18)),
+    ]
+    position = len(edges)
+    while position < config.friend_edges:
+        source = 1 + rng.randrange(config.users)
+        target = 1 + rng.randrange(config.users)
+        # Re-draw the target until it differs; impossible with a single user.
+        while config.users > 1 and target == source:
+            target = 1 + rng.randrange(config.users)
+        strength = round(0.3 + rng.random() * 0.69, 3)
+        shared_interests = 1 + rng.randrange(8)
+        last_seen = timestamp(3, 12 + position, hour=18)
+        edges.append((source, target, strength, shared_interests, last_seen))
+        position += 1
+    return edges
+
+
+def load_data(client: OracleDBClient, config: DataConfig) -> None:
+    """Generate every row set, insert it table by table, then commit.
+
+    The five "<TABLE>_id" tables are filled before the "SRC_ID"/"DST_ID"
+    tables. All random edge generators share one ``random.Random`` seeded with
+    ``config.seed`` and are called in a fixed order.
+
+    Raises:
+        RuntimeError: Propagated from insert_rows() or safe_execute() on failure.
+    """
+    rng = random.Random(config.seed)
+
+    rows_by_table: dict[str, list] = {}
+    rows_by_table["USER"] = generate_users(config)
+    rows_by_table["MOVIE"] = generate_movies(config)
+    rows_by_table["GENRE"] = generate_genres(config)
+    rows_by_table["TAG"] = generate_tags(config)
+    (
+        rows_by_table["RATING"],
+        rows_by_table["USER_RATES_RATING"],
+        rows_by_table["RATING_RATES_MOVIE"],
+    ) = generate_ratings(config)
+    # The generators below consume the shared RNG; keep this call order.
+    rows_by_table["MOVIE_BELONGS_TO_GENRE"] = generate_genre_edges(config, rng)
+    rows_by_table["USER_TAGS_TAG"], rows_by_table["TAG_TAGS_MOVIE"] = generate_tag_edges(
+        config, rng
+    )
+    rows_by_table["MOVIE_SIMILAR_TO_MOVIE"] = generate_similarity_edges(config, rng)
+    rows_by_table["USER_FRIENDS_WITH_USER"] = generate_friend_edges(config, rng)
+
+    # (table label, INSERT statement), in the order the tables are loaded.
+    statements: tuple[tuple[str, str], ...] = (
+        (
+            "USER",
+            'INSERT INTO "USER" ("USER_id", "age", "gender", "occupation", "zip_code", "registration_date") VALUES (:1, :2, :3, :4, :5, :6)',
+        ),
+        (
+            "MOVIE",
+            'INSERT INTO "MOVIE" ("MOVIE_id", "title", "release_year", "duration", "plot_summary", "language") VALUES (:1, :2, :3, :4, :5, :6)',
+        ),
+        (
+            "GENRE",
+            'INSERT INTO "GENRE" ("GENRE_id", "name", "description") VALUES (:1, :2, :3)',
+        ),
+        (
+            "RATING",
+            'INSERT INTO "RATING" ("RATING_id", "score", "timestamp", "review_text") VALUES (:1, :2, :3, :4)',
+        ),
+        (
+            "TAG",
+            'INSERT INTO "TAG" ("TAG_id", "text_content", "timestamp", "relevance_score") VALUES (:1, :2, :3, :4)',
+        ),
+        (
+            "USER_RATES_RATING",
+            'INSERT INTO "USER_RATES_RATING" ("SRC_ID", "DST_ID", "rating_timestamp", "rating_value", "confidence_weight") VALUES (:1, :2, :3, :4, :5)',
+        ),
+        (
+            "RATING_RATES_MOVIE",
+            'INSERT INTO "RATING_RATES_MOVIE" ("SRC_ID", "DST_ID", "rating_timestamp", "rating_value", "confidence_weight") VALUES (:1, :2, :3, :4, :5)',
+        ),
+        (
+            "MOVIE_BELONGS_TO_GENRE",
+            'INSERT INTO "MOVIE_BELONGS_TO_GENRE" ("SRC_ID", "DST_ID", "primary_classification", "strength_score") VALUES (:1, :2, :3, :4)',
+        ),
+        (
+            "USER_TAGS_TAG",
+            'INSERT INTO "USER_TAGS_TAG" ("SRC_ID", "DST_ID", "tag_timestamp", "user_confidence", "public_visibility") VALUES (:1, :2, :3, :4, :5)',
+        ),
+        (
+            "TAG_TAGS_MOVIE",
+            'INSERT INTO "TAG_TAGS_MOVIE" ("SRC_ID", "DST_ID", "tag_timestamp", "user_confidence", "public_visibility") VALUES (:1, :2, :3, :4, :5)',
+        ),
+        (
+            "MOVIE_SIMILAR_TO_MOVIE",
+            'INSERT INTO "MOVIE_SIMILAR_TO_MOVIE" ("SRC_ID", "DST_ID", "similarity_score", "algorithm_version", "update_timestamp") VALUES (:1, :2, :3, :4, :5)',
+        ),
+        (
+            "USER_FRIENDS_WITH_USER",
+            'INSERT INTO "USER_FRIENDS_WITH_USER" ("SRC_ID", "DST_ID", "connection_strength", "mutual_interests_count", "last_interaction") VALUES (:1, :2, :3, :4, :5)',
+        ),
+    )
+    for table, statement in statements:
+        insert_rows(client, statement, rows_by_table[table], table, config.batch_size)
+    safe_execute(client, "COMMIT")
+
+
+def print_config(config: DataConfig) -> None:
+    """Print a header line followed by one indented ``name: value`` line per field."""
+    report = ["Data size configuration:"]
+    report.extend(f"  {field_name}: {setting}" for field_name, setting in vars(config).items())
+    print("\n".join(report))
+
+
+def main() -> None:
+    """Rebuild the example graph: reset objects, create tables, load rows, create graph.
+
+    Connection settings come from the ORACLE_DSN, ORACLE_USER and
+    ORACLE_PASSWORD environment variables. The client is closed even when a
+    step fails.
+
+    Raises:
+        KeyError: If one of the three environment variables is missing.
+        RuntimeError: If the connection could not be established or a step fails.
+    """
+    config = parse_args()
+    credentials = {
+        key: os.environ[f"ORACLE_{key.upper()}"] for key in ("dsn", "user", "password")
+    }
+    client = OracleDBClient(credentials)
+    if not client.connection:
+        raise RuntimeError("Oracle connection failed.")
+
+    # Each step is announced right before it runs.
+    steps = (
+        (f"Resetting Oracle objects for {GRAPH_NAME}", lambda: reset_objects(client)),
+        ("Creating Oracle tables", lambda: run_script(client, TABLE_DDL_PATH)),
+        ("Inserting generated sample rows", lambda: load_data(client, config)),
+        ("Creating Oracle SQL property graph", lambda: run_script(client, GRAPH_DDL_PATH)),
+    )
+    try:
+        print_config(config)
+        for announcement, action in steps:
+            print(announcement)
+            action()
+        print(f"Ready: {GRAPH_NAME}")
+    finally:
+        client.close()
+
+
+# Script entry point: call main() only when this file is executed directly,
+# not when it is imported.
+if __name__ == "__main__":
+    main()
diff --git a/examples/setup_oracle_sqlpgq_fraud_db.py b/examples/setup_oracle_sqlpgq_fraud_db.py
new file mode 100644
index 0000000..00655a2
--- /dev/null
+++ b/examples/setup_oracle_sqlpgq_fraud_db.py
@@ -0,0 +1,494 @@
+"""
+Create and populate a scalable Oracle SQL/PGQ fraud/payments graph.
+
+Required environment variables:
+- ORACLE_DSN
+- ORACLE_USER
+- ORACLE_PASSWORD
+
+This example is intentionally separate from the movie graph so SQL/PGQ corpus
+generation can be tested across different graph domains.
+"""
+
+import argparse
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+import json
+import os
+from pathlib import Path
+import random
+from typing import Any
+
+from app.core.schema.edge import Edge
+from app.core.schema.node import Node
+from app.core.schema.schema_graph import SchemaGraph
+from app.core.validator.db_client import QueryStatus
+from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient
+from app.impl.oracle_sqlpgq.schema.schema_parser import OracleSqlPgqSchemaParser
+
+
+# Name of the Oracle property graph; also passed to the schema parser as db_id.
+GRAPH_NAME = "TEXT2GQL_FRAUD_GRAPH"
+# Directory that receives the generated schema JSON and DDL files.
+ARTIFACT_DIR = Path("examples/Oracle_SQLPGQ_Instance")
+# File-name prefix shared by the generated artifacts.
+BASE_NAME = GRAPH_NAME
+# One row of positional bind values for an executemany call.
+Row = tuple[Any, ...]
+
+
+@dataclass(frozen=True)
+class FraudConfig:
+    """Immutable size and generation settings for the fraud sample graph."""
+
+    customers: int  # number of CUSTOMER rows to generate
+    accounts: int  # number of ACCOUNT rows (each also gets one ownership edge)
+    merchants: int  # number of MERCHANT rows
+    transactions: int  # number of TRANSACTION rows (each also gets three edges)
+    devices: int  # number of DEVICE rows
+    cities: int  # number of CITY rows
+    seed: int  # seed for the random.Random used while generating rows
+    batch_size: int  # maximum number of rows sent per executemany call
+
+
+def env_int(name: str, default: int) -> int:
+    """Return the integer value of environment variable ``name``.
+
+    Args:
+        name: Environment variable to read.
+        default: Value returned when the variable is unset or blank.
+
+    Returns:
+        The parsed integer, or ``default``.
+
+    Raises:
+        ValueError: If the variable holds something that is not an integer.
+            The message names the variable so the bad setting is easy to find.
+    """
+    raw = os.getenv(name)
+    # An exported-but-empty variable is treated like an unset one instead of
+    # crashing on int("").
+    if raw is None or not raw.strip():
+        return default
+    try:
+        return int(raw)
+    except ValueError:
+        raise ValueError(
+            f"Environment variable {name} must be an integer, got {raw!r}"
+        ) from None
+
+
+def positive_int(value: str) -> int:
+    """argparse ``type`` converter that accepts only integers greater than zero.
+
+    Raises:
+        argparse.ArgumentTypeError: If the parsed integer is zero or negative.
+        ValueError: If ``value`` is not an integer literal (raised by ``int``).
+    """
+    number = int(value)
+    if number > 0:
+        return number
+    raise argparse.ArgumentTypeError("value must be positive")
+
+
+def parse_args() -> FraudConfig:
+    """Parse the command line into a validated FraudConfig.
+
+    Every option defaults to a FRAUD_GRAPH_* environment variable and falls
+    back to a built-in value when that variable is unset.
+
+    Returns:
+        The configuration, after ``validate_config`` has accepted it.
+
+    Raises:
+        SystemExit: Raised by argparse for invalid options, including a
+            non-positive batch size taken from the environment.
+        ValueError: Raised by ``validate_config`` for sizes below the minimum.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--customers", type=positive_int, default=env_int("FRAUD_GRAPH_CUSTOMERS", 6))
+    parser.add_argument("--accounts", type=positive_int, default=env_int("FRAUD_GRAPH_ACCOUNTS", 8))
+    parser.add_argument("--merchants", type=positive_int, default=env_int("FRAUD_GRAPH_MERCHANTS", 6))
+    parser.add_argument(
+        "--transactions", type=positive_int, default=env_int("FRAUD_GRAPH_TRANSACTIONS", 24)
+    )
+    parser.add_argument("--devices", type=positive_int, default=env_int("FRAUD_GRAPH_DEVICES", 6))
+    parser.add_argument("--cities", type=positive_int, default=env_int("FRAUD_GRAPH_CITIES", 4))
+    parser.add_argument("--seed", type=int, default=env_int("FRAUD_GRAPH_DATA_SEED", 19))
+    parser.add_argument(
+        "--batch-size",
+        type=positive_int,
+        default=env_int("FRAUD_GRAPH_INSERT_BATCH_SIZE", 1000),
+    )
+    options = parser.parse_args()
+    # argparse applies ``type`` only to command-line values and to string
+    # defaults, so an integer default read from the environment bypasses
+    # positive_int. The entity counts are covered by validate_config below;
+    # the batch size has no other check, so reject bad values here.
+    if options.batch_size <= 0:
+        parser.error("batch size must be positive (check FRAUD_GRAPH_INSERT_BATCH_SIZE)")
+    config = FraudConfig(**vars(options))
+    validate_config(config)
+    return config
+
+
+def validate_config(config: FraudConfig) -> None:
+    """Reject configurations whose entity counts are below the required minimums.
+
+    Raises:
+        ValueError: Listing every field that is too small, with its minimum
+            and the value that was supplied.
+    """
+    required = (
+        ("customers", 4),
+        ("accounts", 4),
+        ("merchants", 4),
+        ("transactions", 8),
+        ("devices", 4),
+        ("cities", 4),
+    )
+    supplied = config.__dict__
+    problems = []
+    for field_name, floor in required:
+        current = supplied[field_name]
+        if current < floor:
+            problems.append(f"{field_name}>={floor} (got {current})")
+    if not problems:
+        return
+    raise ValueError(
+        "The fraud graph seed queries require these minimum sizes: " + ", ".join(problems)
+    )
+
+
+def build_schema_graph() -> SchemaGraph:
+    """Describe the fraud graph: six node labels followed by six edge labels.
+
+    The definitions are kept as plain tables and converted to ``Node`` and
+    ``Edge`` objects in declaration order.
+    """
+
+    def as_properties(columns: dict[str, str]) -> list[dict[str, str]]:
+        # Expand {"NAME": "STRING"} into the {"name": ..., "type": ...} form.
+        return [{"name": column, "type": kind} for column, kind in columns.items()]
+
+    # (label, primary key, {property: type})
+    node_specs = (
+        (
+            "CUSTOMER",
+            "CUSTOMER_ID",
+            {
+                "CUSTOMER_ID": "INT64",
+                "NAME": "STRING",
+                "RISK_SCORE": "FLOAT",
+                "SEGMENT": "STRING",
+            },
+        ),
+        (
+            "ACCOUNT",
+            "ACCOUNT_ID",
+            {
+                "ACCOUNT_ID": "INT64",
+                "ACCOUNT_TYPE": "STRING",
+                "BALANCE": "FLOAT",
+                "OPENED_AT": "DATETIME",
+            },
+        ),
+        (
+            "MERCHANT",
+            "MERCHANT_ID",
+            {
+                "MERCHANT_ID": "INT64",
+                "NAME": "STRING",
+                "CATEGORY": "STRING",
+                "COUNTRY": "STRING",
+            },
+        ),
+        (
+            "TRANSACTION",
+            "TRANSACTION_ID",
+            {
+                "TRANSACTION_ID": "INT64",
+                "AMOUNT": "FLOAT",
+                "STATUS": "STRING",
+                "EVENT_TIME": "DATETIME",
+                "CHANNEL": "STRING",
+                "FRAUD_FLAG": "BOOL",
+            },
+        ),
+        (
+            "DEVICE",
+            "DEVICE_ID",
+            {"DEVICE_ID": "INT64", "DEVICE_TYPE": "STRING", "TRUSTED": "BOOL"},
+        ),
+        (
+            "CITY",
+            "CITY_ID",
+            {"CITY_ID": "INT64", "NAME": "STRING", "COUNTRY": "STRING"},
+        ),
+    )
+    # (label, source node, destination node, {property: type})
+    edge_specs = (
+        ("OWNS", "CUSTOMER", "ACCOUNT", {"OWNERSHIP_TYPE": "STRING", "SINCE": "DATETIME"}),
+        ("INITIATED", "ACCOUNT", "TRANSACTION", {"IP_ADDRESS": "STRING", "AUTH_METHOD": "STRING"}),
+        ("PAID", "TRANSACTION", "MERCHANT", {"PAYMENT_METHOD": "STRING", "APPROVAL_CODE": "STRING"}),
+        ("USED_DEVICE", "TRANSACTION", "DEVICE", {"DEVICE_CONFIDENCE": "FLOAT"}),
+        ("LIVES_IN", "CUSTOMER", "CITY", {"SINCE": "DATETIME"}),
+        ("LOCATED_IN", "MERCHANT", "CITY", {"ACTIVE": "BOOL"}),
+    )
+
+    graph = SchemaGraph(GRAPH_NAME)
+    for label, primary_key, columns in node_specs:
+        graph.add_node(Node(label=label, primary=primary_key, properties=as_properties(columns)))
+    for label, source, destination, columns in edge_specs:
+        graph.add_edge(
+            Edge(
+                label=label,
+                src_dst_list=[[source, destination]],
+                properties=as_properties(columns),
+            )
+        )
+    return graph
+
+
+def build_manifest() -> dict[str, Any]:
+    """Return the parser's manifest for the fraud schema graph."""
+    schema_parser = OracleSqlPgqSchemaParser(db_id=GRAPH_NAME)
+    schema_graph = build_schema_graph()
+    return schema_parser.build_manifest(schema_graph, graph_name=GRAPH_NAME)
+
+
+def write_artifacts(manifest: dict[str, Any]) -> None:
+    """Write the manifest JSON, table DDL and property-graph DDL to ARTIFACT_DIR.
+
+    The directory is created when missing. Files are written in that order
+    as UTF-8 and are named ``<BASE_NAME>_oracle_*``.
+    """
+    ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
+
+    def emit(suffix: str, text: str) -> None:
+        # All three artifacts share the directory, name prefix and encoding.
+        (ARTIFACT_DIR / f"{BASE_NAME}{suffix}").write_text(text, encoding="utf-8")
+
+    emit("_oracle_schema.json", json.dumps(manifest, indent=2, ensure_ascii=False))
+    emit("_oracle_tables.sql", manifest["table_ddl"])
+    emit("_oracle_property_graph.sql", manifest["property_graph_ddl"])
+
+
+def safe_execute(client: OracleDBClient, sql: str, ignore_errors: tuple[str, ...] = ()) -> None:
+    """Run one statement and raise unless it succeeds or fails tolerably.
+
+    Args:
+        client: Client used to execute the statement.
+        sql: Statement text.
+        ignore_errors: Substrings; a failure whose error text contains any of
+            them is ignored.
+
+    Raises:
+        RuntimeError: For any other failure, quoting the SQL and the error.
+    """
+    outcome = client.execute_query(sql)
+    if outcome.status_code == QueryStatus.SUCCESS:
+        return
+    tolerated = bool(outcome.error) and any(
+        marker in outcome.error for marker in ignore_errors
+    )
+    if not tolerated:
+        raise RuntimeError(f"Failed SQL:\n{sql}\n\nError:\n{outcome.error}")
+
+
+def run_script(client: OracleDBClient, script: str) -> None:
+    """Execute ``script`` through the client and stop at the first failed result.
+
+    Raises:
+        RuntimeError: Carrying the error of the first non-successful result.
+    """
+    for outcome in client.execute_script(script):
+        if outcome.status_code == QueryStatus.SUCCESS:
+            continue
+        raise RuntimeError(outcome.error)
+
+
+def reset_objects(client: OracleDBClient, manifest: dict[str, Any]) -> None:
+    """Drop the property graph, then every manifest table in reverse load order.
+
+    Failures whose error text contains one of the listed ORA codes are
+    ignored. NOTE(review): these look like "object does not exist" codes, so
+    a first run against an empty schema would pass -- confirm.
+    """
+    drop_graph = f'DROP PROPERTY GRAPH "{GRAPH_NAME}"'
+    safe_execute(client, drop_graph, ignore_errors=("ORA-42421", "ORA-04043", "ORA-00942"))
+    for table_name in reversed(manifest["load_order"]):
+        drop_table = f'DROP TABLE "{table_name}" CASCADE CONSTRAINTS PURGE'
+        safe_execute(client, drop_table, ignore_errors=("ORA-00942",))
+
+
+def insert_rows(
+    client: OracleDBClient,
+    statement: str,
+    rows: list[Row],
+    label: str,
+    batch_size: int,
+) -> None:
+    """Insert ``rows`` with ``statement`` in chunks of at most ``batch_size``.
+
+    Args:
+        client: Client whose ``executemany`` performs each batch.
+        statement: INSERT statement with positional bind placeholders.
+        rows: Bind-value tuples, one per row.
+        label: Name used in the progress and error messages.
+        batch_size: Maximum number of rows per ``executemany`` call.
+
+    Raises:
+        ValueError: If ``batch_size`` is not positive.
+        RuntimeError: If any batch does not report success.
+    """
+    # range() rejects a zero step, and a negative step would run no batches
+    # while still printing the success line below, so fail fast instead.
+    if batch_size <= 0:
+        raise ValueError(f"batch_size must be positive, got {batch_size}")
+    for start in range(0, len(rows), batch_size):
+        result = client.executemany(statement, rows[start : start + batch_size])
+        if result.status_code != QueryStatus.SUCCESS:
+            raise RuntimeError(f"Failed loading {label}: {result.error}")
+    print(f"Loaded {len(rows)} rows into {label}")
+
+
+def dt(days: int, hour: int = 9) -> datetime:
+    """Return 2025-01-01 at ``hour``:00:00, shifted forward by ``days`` days."""
+    anchor = datetime(2025, 1, 1, hour, 0, 0)
+    offset = timedelta(days=days)
+    return anchor + offset
+
+
+def generate_rows(config: FraudConfig) -> dict[str, list[Row]]:
+    """Build the sample rows for every node and edge table, keyed by table name.
+
+    Output is reproducible for a given ``config.seed``: the random generator is
+    consumed in a fixed order (customer risk scores, account balances, then
+    device confidences). IDs start at 1 for customers, 1000 for accounts,
+    2000 for merchants, 3000 for devices, 4000 for cities and 5000 for
+    transactions. Edge rows lead with the source id and destination id.
+    """
+    rng = random.Random(config.seed)
+
+    def pick(cycle: list, position: int):
+        # Walk a fixed list round-robin.
+        return cycle[position % len(cycle)]
+
+    # Customers: the first four risk scores are fixed, later ones are random.
+    fixed_risks = [0.82, 0.35, 0.64, 0.91]
+    known_names = ["Alice", "Bob", "Carol", "Daniel", "Eve", "Fatima", "Omar", "Lina"]
+    segment_cycle = ["retail", "premium", "small_business", "student"]
+    customer_rows = []
+    for position in range(config.customers):
+        if position < 4:
+            risk_score = fixed_risks[position]
+        else:
+            risk_score = round(rng.random(), 3)
+        if position < len(known_names):
+            display_name = known_names[position]
+        else:
+            display_name = f"Customer {position + 1}"
+        customer_rows.append(
+            (1 + position, display_name, risk_score, pick(segment_cycle, position))
+        )
+
+    # Accounts are assigned to customers round-robin; every third one is joint.
+    type_cycle = ["checking", "savings", "business", "wallet"]
+    account_rows = []
+    ownership_rows = []
+    for position in range(config.accounts):
+        account_id = 1000 + position
+        owner_id = 1 + position % config.customers
+        balance = round(250 + rng.random() * 25000, 2)
+        account_rows.append(
+            (account_id, pick(type_cycle, position), balance, dt(position, hour=8))
+        )
+        role = "primary" if position % 3 else "joint"
+        ownership_rows.append((owner_id, account_id, role, dt(position)))
+
+    # Merchants: named samples first, generic placeholders afterwards.
+    named_merchants = [
+        ("ElectroHub", "Electronics", "US"),
+        ("GroceryMart", "Grocery", "US"),
+        ("TravelNow", "Travel", "FR"),
+        ("CryptoX", "Crypto", "MA"),
+        ("BookBarn", "Books", "UK"),
+        ("HealthPlus", "Healthcare", "US"),
+    ]
+    merchant_rows = []
+    for position in range(config.merchants):
+        if position < len(named_merchants):
+            details = named_merchants[position]
+        else:
+            details = (f"Merchant {position + 1}", "General", "US")
+        merchant_rows.append((2000 + position, *details))
+
+    # Devices at positions 1 and 4 are the untrusted ones.
+    kind_cycle = ["mobile", "desktop", "tablet", "pos_terminal"]
+    device_rows = [
+        (3000 + position, pick(kind_cycle, position), int(position not in {1, 4}))
+        for position in range(config.devices)
+    ]
+
+    named_cities = [
+        ("Casablanca", "MA"),
+        ("New York", "US"),
+        ("Paris", "FR"),
+        ("London", "UK"),
+    ]
+    city_rows = []
+    for position in range(config.cities):
+        if position < len(named_cities):
+            place = named_cities[position]
+        else:
+            place = (f"City {position + 1}", "US")
+        city_rows.append((4000 + position, *place))
+
+    lives_in_rows = []
+    for position in range(config.customers):
+        lives_in_rows.append((1 + position, 4000 + position % config.cities, dt(position + 30)))
+    located_in_rows = []
+    for position in range(config.merchants):
+        located_in_rows.append((2000 + position, 4000 + position % config.cities, 1))
+
+    # Transactions and their three edges are produced together, one per index.
+    amount_cycle = [1299.99, 42.50, 860.00, 2200.00, 120.25, 510.75, 75.00, 1450.30]
+    status_cycle = ["APPROVED", "APPROVED", "DECLINED", "APPROVED"]
+    channel_cycle = ["online", "card", "mobile", "branch"]
+    transaction_rows = []
+    initiated_rows = []
+    paid_rows = []
+    used_device_rows = []
+    for position in range(config.transactions):
+        transaction_id = 5000 + position
+        flagged = int(position % 7 in {0, 3})
+        source_account = 1000 + position % config.accounts
+        target_merchant = 2000 + position % config.merchants
+        target_device = 3000 + position % config.devices
+        occurred_at = dt(position, hour=10 + position % 10)
+        transaction_rows.append(
+            (
+                transaction_id,
+                pick(amount_cycle, position),
+                pick(status_cycle, position),
+                occurred_at,
+                pick(channel_cycle, position),
+                flagged,
+            )
+        )
+        ip_address = f"10.0.{position % 255}.{(position * 17) % 255}"
+        auth_method = "otp" if position % 3 else "password"
+        initiated_rows.append((source_account, transaction_id, ip_address, auth_method))
+        payment_method = "card" if position % 2 else "wallet"
+        paid_rows.append(
+            (transaction_id, target_merchant, payment_method, f"APP{transaction_id}")
+        )
+        confidence = round(0.35 + rng.random() * 0.64, 3)
+        used_device_rows.append((transaction_id, target_device, confidence))
+
+    return {
+        "CUSTOMER": customer_rows,
+        "ACCOUNT": account_rows,
+        "MERCHANT": merchant_rows,
+        "TRANSACTION": transaction_rows,
+        "DEVICE": device_rows,
+        "CITY": city_rows,
+        "CUSTOMER_OWNS_ACCOUNT": ownership_rows,
+        "ACCOUNT_INITIATED_TRANSACTION": initiated_rows,
+        "TRANSACTION_PAID_MERCHANT": paid_rows,
+        "TRANSACTION_USED_DEVICE_DEVICE": used_device_rows,
+        "CUSTOMER_LIVES_IN_CITY": lives_in_rows,
+        "MERCHANT_LOCATED_IN_CITY": located_in_rows,
+    }
+
+
+def load_data(client: OracleDBClient, config: FraudConfig) -> None:
+    """Generate the sample rows, insert them table by table, then COMMIT.
+
+    Node tables are listed before the edge tables, and inserts run in that
+    listing order. Bind positions follow the column order of each row tuple.
+    """
+    generated = generate_rows(config)
+    insert_sql = {
+        "CUSTOMER": 'INSERT INTO "CUSTOMER" ("CUSTOMER_ID", "NAME", "RISK_SCORE", "SEGMENT") VALUES (:1, :2, :3, :4)',
+        "ACCOUNT": 'INSERT INTO "ACCOUNT" ("ACCOUNT_ID", "ACCOUNT_TYPE", "BALANCE", "OPENED_AT") VALUES (:1, :2, :3, :4)',
+        "MERCHANT": 'INSERT INTO "MERCHANT" ("MERCHANT_ID", "NAME", "CATEGORY", "COUNTRY") VALUES (:1, :2, :3, :4)',
+        "TRANSACTION": 'INSERT INTO "TRANSACTION" ("TRANSACTION_ID", "AMOUNT", "STATUS", "EVENT_TIME", "CHANNEL", "FRAUD_FLAG") VALUES (:1, :2, :3, :4, :5, :6)',
+        "DEVICE": 'INSERT INTO "DEVICE" ("DEVICE_ID", "DEVICE_TYPE", "TRUSTED") VALUES (:1, :2, :3)',
+        "CITY": 'INSERT INTO "CITY" ("CITY_ID", "NAME", "COUNTRY") VALUES (:1, :2, :3)',
+        "CUSTOMER_OWNS_ACCOUNT": 'INSERT INTO "CUSTOMER_OWNS_ACCOUNT" ("SRC_ID", "DST_ID", "OWNERSHIP_TYPE", "SINCE") VALUES (:1, :2, :3, :4)',
+        "ACCOUNT_INITIATED_TRANSACTION": 'INSERT INTO "ACCOUNT_INITIATED_TRANSACTION" ("SRC_ID", "DST_ID", "IP_ADDRESS", "AUTH_METHOD") VALUES (:1, :2, :3, :4)',
+        "TRANSACTION_PAID_MERCHANT": 'INSERT INTO "TRANSACTION_PAID_MERCHANT" ("SRC_ID", "DST_ID", "PAYMENT_METHOD", "APPROVAL_CODE") VALUES (:1, :2, :3, :4)',
+        "TRANSACTION_USED_DEVICE_DEVICE": 'INSERT INTO "TRANSACTION_USED_DEVICE_DEVICE" ("SRC_ID", "DST_ID", "DEVICE_CONFIDENCE") VALUES (:1, :2, :3)',
+        "CUSTOMER_LIVES_IN_CITY": 'INSERT INTO "CUSTOMER_LIVES_IN_CITY" ("SRC_ID", "DST_ID", "SINCE") VALUES (:1, :2, :3)',
+        "MERCHANT_LOCATED_IN_CITY": 'INSERT INTO "MERCHANT_LOCATED_IN_CITY" ("SRC_ID", "DST_ID", "ACTIVE") VALUES (:1, :2, :3)',
+    }
+    chunk_size = config.batch_size
+    for table_name in insert_sql:
+        insert_rows(client, insert_sql[table_name], generated[table_name], table_name, chunk_size)
+    safe_execute(client, "COMMIT")
+
+
+def print_config(config: FraudConfig) -> None:
+    """Print a header followed by one ``name: value`` line per config field."""
+    print("Fraud graph data size configuration:")
+    settings = config.__dict__
+    for field_name in settings:
+        print(f" {field_name}: {settings[field_name]}")
+
+
+def main() -> None:
+    """Generate the fraud-graph artifacts and rebuild the graph in Oracle.
+
+    The manifest is built and written to disk before any connection is
+    attempted. ORACLE_DSN, ORACLE_USER and ORACLE_PASSWORD are then read from
+    the environment (a missing variable raises KeyError). After that the old
+    objects are dropped, the table DDL is run, the rows are loaded and the
+    property-graph DDL is run. The client is closed even when a step fails.
+
+    Raises:
+        RuntimeError: If the client reports no connection.
+    """
+    config = parse_args()
+    manifest = build_manifest()
+    write_artifacts(manifest)
+    credentials = {
+        "dsn": os.environ["ORACLE_DSN"],
+        "user": os.environ["ORACLE_USER"],
+        "password": os.environ["ORACLE_PASSWORD"],
+    }
+    client = OracleDBClient(credentials)
+    if not client.connection:
+        raise RuntimeError("Oracle connection failed.")
+
+    try:
+        print_config(config)
+        # Each step is announced right before it runs, in this fixed order.
+        steps = (
+            (f"Resetting Oracle objects for {GRAPH_NAME}", lambda: reset_objects(client, manifest)),
+            ("Creating Oracle tables", lambda: run_script(client, manifest["table_ddl"])),
+            ("Inserting generated fraud sample rows", lambda: load_data(client, config)),
+            (
+                "Creating Oracle SQL property graph",
+                lambda: run_script(client, manifest["property_graph_ddl"]),
+            ),
+        )
+        for announcement, run_step in steps:
+            print(announcement)
+            run_step()
+        print(f"Ready: {GRAPH_NAME}")
+    finally:
+        client.close()
+
+
+# Run the setup only when this file is executed as a script, not when imported.
+if __name__ == "__main__":
+    main()
diff --git a/examples/tugraph_to_oracle_sqlpgq.py b/examples/tugraph_to_oracle_sqlpgq.py
new file mode 100644
index 0000000..5129138
--- /dev/null
+++ b/examples/tugraph_to_oracle_sqlpgq.py
@@ -0,0 +1,51 @@
+"""
+Convert a framework/TuGraph-style schema JSON into Oracle SQL/PGQ artifacts.
+
+Generated files:
+- *_oracle_tables.sql
+- *_oracle_property_graph.sql
+- *_oracle_schema.json
+- *_oracle_loader.py
+"""
+
+import logging
+from pathlib import Path
+import sys
+
+# Put the parent of this file's directory on sys.path so the ``app`` import
+# below resolves when the file is run directly.
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from app.impl.oracle_sqlpgq.schema.schema_parser import OracleSqlPgqSchemaParser
+
+# Configure root logging at INFO with timestamped messages.
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def convert_schema_to_oracle_sqlpgq(
+    input_json_path: str,
+    output_dir: str,
+    graph_name: str = "TEXT2GQL_GRAPH",
+) -> str:
+    """Generate Oracle SQL/PGQ artifacts from a schema JSON file.
+
+    Args:
+        input_json_path: Path of the schema JSON to convert.
+        output_dir: Directory handed to the parser's ``save_schema_to_file``.
+        graph_name: Used as the parser ``db_id`` and as the saved ``domain``.
+
+    Returns:
+        Whatever ``save_schema_to_file`` returns (logged as the manifest).
+
+    Raises:
+        FileNotFoundError: If ``input_json_path`` does not exist.
+    """
+    source = Path(input_json_path)
+    if not source.exists():
+        raise FileNotFoundError(f"Input schema file not found: {source}")
+
+    schema_parser = OracleSqlPgqSchemaParser(db_id=graph_name, instance_path=str(source))
+    graph = schema_parser.get_schema_graph()
+    destination = Path(output_dir)
+    save_options = {"domain": graph_name, "subdomain": ""}
+    manifest = schema_parser.save_schema_to_file(destination, graph, **save_options)
+    for template, value in (
+        ("Generated Oracle SQL/PGQ artifacts under %s", destination),
+        ("Manifest: %s", manifest),
+    ):
+        logger.info(template, value)
+    return manifest
+
+
+# Example run with fixed relative paths; they resolve against the current
+# working directory.
+if __name__ == "__main__":
+    convert_schema_to_oracle_sqlpgq(
+        input_json_path="examples/generated_schemas/example_schema.json",
+        output_dir="examples/Oracle_SQLPGQ_Instance",
+        graph_name="TEXT2GQL_GRAPH",
+    )
diff --git a/pyproject.toml b/pyproject.toml
index fb5c788..85e1bcd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ tugraphclient = "^2.0"
openai = "==2.9.0"
numpy = "2.2.6"
neo4j = "6.0.3"
+oracledb = "^3.1.0"
[[tool.poetry.source]]
name = "PyPI"
@@ -83,5 +84,6 @@ python_files = ["test_*.py"]
addopts = "-v"
asyncio_mode = "auto" # Enable asyncio mode
markers = [
- "asyncio: mark test as async"
-]
\ No newline at end of file
+ "asyncio: mark test as async",
+ "oracle: tests that require a live Oracle Database"
+]
diff --git a/test/test_cypher2oracle_sqlpgq.py b/test/test_cypher2oracle_sqlpgq.py
new file mode 100644
index 0000000..3c0117f
--- /dev/null
+++ b/test/test_cypher2oracle_sqlpgq.py
@@ -0,0 +1,2609 @@
+from app.core.clauses.match_clause import MatchClause
+from app.impl.tugraph_cypher.ast_visitor.tugraph_cypher_ast_visitor import (
+ TugraphCypherAstVisitor,
+)
+from examples.cypher2oracle_sqlpgq import cypher2oracle_sqlpgq
+
+
+def _translate(cypher: str) -> str:
+    """Translate ``cypher`` against MOVIE_GRAPH and sanity-check the result.
+
+    Asserts that the translation is categorised as translatable, starts with
+    SELECT, and mentions both GRAPH_TABLE and the quoted graph name.
+    """
+    sql, verdict = cypher2oracle_sqlpgq(cypher, graph_name="MOVIE_GRAPH")
+
+    assert verdict == "Graph-IL Translatable"
+    assert sql.startswith("SELECT")
+    for marker in ("FROM GRAPH_TABLE", '"MOVIE_GRAPH"'):
+        assert marker in sql
+    return sql
+
+
+def _translate_with_types(
+    cypher: str,
+    property_type_map: dict[str, dict[str, str]],
+) -> str:
+    """Translate ``cypher`` against MOVIE_GRAPH using the given property types.
+
+    Only the translation category is asserted here.
+    """
+    options = {"graph_name": "MOVIE_GRAPH", "property_type_map": property_type_map}
+    sql, verdict = cypher2oracle_sqlpgq(cypher, **options)
+
+    assert verdict == "Graph-IL Translatable"
+    return sql
+
+
+def _translate_sql(cypher: str) -> str:
+    """Like ``_translate`` but without requiring the SQL to start with SELECT."""
+    sql, verdict = cypher2oracle_sqlpgq(cypher, graph_name="MOVIE_GRAPH")
+
+    assert verdict == "Graph-IL Translatable"
+    for marker in ("FROM GRAPH_TABLE", '"MOVIE_GRAPH"'):
+        assert marker in sql
+    return sql
+
+
+def test_cypher2oracle_sqlpgq_translates_simple_node_return_property():
+    """A labelled node with an aliased property yields MATCH and COLUMNS clauses."""
+    cypher = "MATCH (p:PERSON) RETURN p.name AS person_name"
+    sql = _translate(cypher)
+
+    for fragment in ('MATCH (p IS "PERSON")', 'COLUMNS (p."name" AS person_name)'):
+        assert fragment in sql
+
+
+def test_cypher_ast_marks_optional_match_clause():
+    """The first MatchClause is flagged optional only for OPTIONAL MATCH."""
+    visitor = TugraphCypherAstVisitor()
+    cases = (
+        ("OPTIONAL MATCH (a)-->(b) RETURN b", True),
+        ("MATCH (a)-->(b) RETURN b", False),
+    )
+
+    for cypher, expect_optional in cases:
+        success, pattern = visitor.get_query_pattern(cypher)
+        assert success
+        match_clause = next(clause for clause in pattern if isinstance(clause, MatchClause))
+        assert bool(match_clause.optional) is expect_optional
+
+
+def test_cypher2oracle_sqlpgq_translates_directed_edge_and_where():
+    """A directed edge pattern with a WHERE filter keeps direction, labels and predicate."""
+    cypher = (
+        "MATCH (p:PERSON)-[a:ACTED_IN]->(m:MOVIE) "
+        "WHERE p.name = 'Tom Hanks' "
+        "RETURN m.title AS movie_title"
+    )
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        'MATCH (p IS "PERSON")-[a IS "ACTED_IN"]->(m IS "MOVIE")',
+        "WHERE p.\"name\" = 'Tom Hanks'",
+        'COLUMNS (m."title" AS movie_title)',
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_order_skip_limit():
+    """ORDER BY / SKIP / LIMIT surface as ORDER BY, OFFSET and FETCH FIRST."""
+    cypher = "MATCH (p:PERSON) RETURN p.name AS person_name ORDER BY p.name ASC SKIP 5 LIMIT 10"
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        "ORDER BY person_name ASC",
+        "OFFSET 5 ROWS",
+        "FETCH FIRST 10 ROWS ONLY",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_wraps_distinct_hidden_sort_columns():
+    """DISTINCT with an unreturned sort key is wrapped in an outer SELECT."""
+    cypher = (
+        "MATCH (r:RULE)<-[mr:MATCHED_RULE]-(t:TRANSACTION) "
+        "RETURN DISTINCT r ORDER BY r.created_at ASC"
+    )
+    sql = _translate_sql(cypher)
+
+    expected_fragments = (
+        "SELECT DISTINCT r_VALUE, created_at",
+        "SELECT r_VALUE\nFROM (",
+        "ORDER BY created_at ASC",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_variable_length_relationship():
+    """``*..3`` becomes a ``{1,3}`` quantifier on a generated edge variable."""
+    cypher = "MATCH (person:PERSON)-[:KNOWS*..3]->(friend:PERSON) RETURN friend.name AS friend_name"
+    sql = _translate(cypher)
+
+    for fragment in ('[e1 IS "KNOWS"]->{1,3}', 'COLUMNS (friend."name" AS friend_name)'):
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_broad_bounded_variable_length_relationship():
+    """An unlabelled, undirected ``*..10`` hop is translated with a ``{1,10}`` bound."""
+    cypher = (
+        "MATCH (a:ACCOUNT {account_id: 'A000000'})-[*..10]-(t:TRANSACTION)"
+        "-[:GovernedBy]->(c:COMPLIANCE_RULE {regulation_standard: 'GDPR'}) "
+        "RETURN a.account_id, t.transaction_id, c.rule_id LIMIT 1"
+    )
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        '(a IS "ACCOUNT")-[e1]-{1,10}(t IS "TRANSACTION")',
+        '[e2 IS "GovernedBy"]->(c IS "COMPLIANCE_RULE")',
+        'a."account_id" AS account_id',
+        't."transaction_id" AS transaction_id',
+        'c."rule_id" AS rule_id',
+        "FETCH FIRST 1 ROWS ONLY",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_whole_node_return_to_vertex_id():
+    """Returning a whole node projects its VERTEX_ID under ``<var>_VALUE``."""
+    cypher = "MATCH (p:PERSON) RETURN p"
+    sql = _translate(cypher)
+
+    assert "COLUMNS (VERTEX_ID(p) AS p_VALUE)" in sql
+
+
+def test_cypher2oracle_sqlpgq_expands_named_path_return_to_element_ids():
+    """Returning a named path projects one id column per vertex and edge."""
+    cypher = (
+        "MATCH p = (n1:ACCOUNT)-[e1]-(x)-[e2]-(n2:ACCOUNT) "
+        "WHERE n1.account_id = 'A000000' AND n2.account_id <> 'A000000' "
+        "RETURN p LIMIT 1"
+    )
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        'MATCH (n1 IS "ACCOUNT")-[e1]-(x)-[e2]-(n2 IS "ACCOUNT")',
+        "VERTEX_ID(n1) AS p_n1_ID",
+        "EDGE_ID(e1) AS p_e1_ID",
+        "VERTEX_ID(x) AS p_x_ID",
+        "EDGE_ID(e2) AS p_e2_ID",
+        "VERTEX_ID(n2) AS p_n2_ID",
+        "FETCH FIRST 1 ROWS ONLY",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_expands_variable_length_named_path_return():
+    """A variable-length named path aggregates its edge ids with JSON_ARRAYAGG."""
+    cypher = "MATCH p = (a:ACCOUNT)-[e*1..3]->(b:ACCOUNT) RETURN p LIMIT 5"
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        'MATCH (a IS "ACCOUNT")-[e]->{1,3}(b IS "ACCOUNT")',
+        "VERTEX_ID(a) AS p_a_ID",
+        "JSON_ARRAYAGG(EDGE_ID(e)) AS p_e_IDS",
+        "VERTEX_ID(b) AS p_b_ID",
+        "FETCH FIRST 5 ROWS ONLY",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_rejects_open_ended_variable_length_path():
+    """A ``*1..`` hop without an upper bound is reported as unsupported."""
+    cypher = "MATCH (person:PERSON)-[:KNOWS*1..]->(friend:PERSON) RETURN friend"
+    translated, verdict = cypher2oracle_sqlpgq(cypher, graph_name="MOVIE_GRAPH")
+
+    assert translated == "Unable to Translate to Oracle SQL/PGQ"
+    assert verdict == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_rejects_quantified_relationship_property_map():
+    """A property map on a quantified relationship is reported as unsupported."""
+    cypher = "MATCH (a)-[:CONNECTS_TO*1..2 {connectionType:'WiFi'}]->(b) RETURN count(b)"
+    translated, verdict = cypher2oracle_sqlpgq(cypher)
+
+    assert translated == "Unable to Translate to Oracle SQL/PGQ"
+    assert verdict == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_translates_union_all_path_branches():
+    """UNION ALL branches with different path shapes are padded with NULL columns."""
+    cypher = (
+        "MATCH p = (n1:ACCOUNT)-[e1:BelongsTo]-(x:FINANCIAL_PERIOD) "
+        'WHERE n1.account_id = "A000000" RETURN p LIMIT 5 '
+        "UNION ALL "
+        "MATCH p = (n1:ACCOUNT)-[e1:BelongsTo]-(x:FINANCIAL_PERIOD)-[e2]-(y) "
+        'WHERE n1.account_id = "A000000" RETURN p LIMIT 5'
+    )
+    sql = _translate(cypher)
+
+    assert "\nUNION ALL\n" in sql
+    assert sql.count("FROM GRAPH_TABLE") == 2
+    for padding in ("NULL AS p_e2_ID", "NULL AS p_y_ID"):
+        assert padding in sql
+    assert sql.count("FETCH FIRST 5 ROWS ONLY") == 2
+
+
+def test_cypher2oracle_sqlpgq_rejects_invalid_cypher():
+    """Syntactically broken Cypher is never categorised as translatable."""
+    broken_cypher = "MATCH (p:PERSON RETURN p"
+    translated, verdict = cypher2oracle_sqlpgq(broken_cypher)
+
+    assert translated == "Unable to Translate to Oracle SQL/PGQ"
+    assert verdict != "Graph-IL Translatable"
+
+
+def test_cypher2oracle_sqlpgq_translates_multi_condition_where():
+    """Several AND-ed comparisons are carried over with quoted property names."""
+    cypher = (
+        "MATCH (m:Movie) "
+        "WHERE m.released >= 1990 AND m.released <= 2000 AND m.votes > 5000 "
+        "RETURN m.title, m.released, m.votes"
+    )
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        'm."released" >= 1990 AND m."released" <= 2000',
+        'm."votes" > 5000',
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_null_predicate():
+    """IS NOT NULL is preserved on the quoted property."""
+    cypher = "MATCH (c:Character) WHERE c.song IS NOT NULL RETURN c.name AS character_name"
+    sql = _translate(cypher)
+
+    assert 'WHERE c."song" IS NOT NULL' in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_count_star():
+    """count(*) is computed in the outer SELECT over a dummy projected column."""
+    cypher = "MATCH (m:Movie) RETURN count(*)"
+    sql = _translate(cypher)
+
+    assert sql.startswith("SELECT COUNT(*) AS COUNT_VALUE")
+    assert "COLUMNS (1 AS dummy_value)" in sql
+
+
+def test_cypher2oracle_sqlpgq_coalesces_sum_but_not_other_aggregates():
+    """Only SUM is wrapped in COALESCE(..., 0); the other aggregates are left bare."""
+    cypher = (
+        "MATCH (m:Movie) "
+        "RETURN sum(m.budget) AS total_budget, count(m) AS movie_count, "
+        "avg(m.budget) AS average_budget, min(m.budget) AS min_budget, "
+        "max(m.budget) AS max_budget"
+    )
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        "COALESCE(SUM(budget), 0) AS total_budget",
+        "COUNT(m_VALUE) AS movie_count",
+        "AVG(budget) AS average_budget",
+        "MIN(budget) AS min_budget",
+        "MAX(budget) AS max_budget",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_backtick_identifiers_and_column_compare():
+    """Backtick-quoted names are normalised and property-to-property compares survive."""
+    cypher = (
+        "MATCH (a:`voice-actors`)-[:MOVIE]->(c:characters) "
+        "WHERE a.movie = c.movie_title "
+        "RETURN a.`voice-actor` AS actor"
+    )
+    sql = _translate(cypher)
+
+    for fragment in ('a."movie" = c."movie_title"', 'a."voice_actor" AS actor'):
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_double_quoted_property_map_strings():
+    """Double-quoted Cypher strings are emitted as single-quoted SQL literals."""
+    cypher = 'MATCH (u1:User {label: "inchristbl.bsky.social"}) RETURN u1.label'
+    sql = _translate(cypher)
+
+    assert "u1.\"label\" = 'inchristbl.bsky.social'" in sql
+    assert '"inchristbl.bsky.social"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_escaped_single_quote_literals():
+    """A backslash-escaped quote in Cypher becomes a doubled quote in SQL."""
+    cypher = (
+        "MATCH (p:Product {productName: 'Chef Anton\\'s Cajun Seasoning'})"
+        "-[:PART_OF]->(c:Category) "
+        "MATCH (c)<-[:PART_OF]-(otherProducts:Product) "
+        "MATCH (otherProducts)<-[:SUPPLIES]-(suppliers:Supplier) "
+        "RETURN DISTINCT suppliers.companyName"
+    )
+    sql = _translate_sql(cypher)
+
+    assert "Chef Anton''s Cajun Seasoning" in sql
+    assert "Anton\\'s" not in sql
+
+
+def test_cypher2oracle_sqlpgq_uses_out_of_line_where_for_property_maps():
+    """Inline property maps move to WHERE, and the ``user`` variable is renamed."""
+    cypher = (
+        "MATCH (target:User {label: 'dwither.bsky.social'})"
+        "<-[:INTERACTED]-(user:User) "
+        "RETURN user.x, user.y LIMIT 3"
+    )
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        'MATCH (user_var IS "User")-[e1 IS "INTERACTED"]->(target IS "User")',
+        'COLUMNS (user_var."x" AS x, user_var."y" AS y)',
+        "WHERE target.\"label\" = 'dwither.bsky.social'",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_projects_hidden_order_by_property():
+    """A sort property that is not returned is projected but left out of SELECT."""
+    cypher = "MATCH (u:User) WHERE u.size < 2.0 RETURN u ORDER BY u.size DESC LIMIT 5"
+    sql = _translate(cypher)
+
+    assert sql.startswith("SELECT u_VALUE")
+    expected_fragments = (
+        'COLUMNS (VERTEX_ID(u) AS u_VALUE, u."size" AS size_VALUE)',
+        "ORDER BY size_VALUE DESC",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_projects_hidden_order_by_scalar_function():
+    """A scalar-function sort key is projected under a generated alias."""
+    cypher = "MATCH (u:User) WHERE u.x IS NOT NULL RETURN u ORDER BY abs(u.x) LIMIT 3"
+    sql = _translate(cypher)
+
+    assert sql.startswith("SELECT u_VALUE")
+    for fragment in ('abs(u."x") AS abs_x', "ORDER BY abs_x"):
+        assert fragment in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_cypher_date_function_to_oracle_literal():
+    """``date('...')`` becomes an Oracle ``DATE '...'`` literal."""
+    cypher = "MATCH (a:ACCOUNT) WHERE a.opening_date < date('2020-01-01') RETURN a.status"
+    sql = _translate(cypher)
+
+    assert "a.\"opening_date\" < DATE '2020-01-01'" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_cypher_current_date_function():
+    """Argument-less ``date()`` becomes ``TRUNC(CURRENT_DATE)``."""
+    cypher = "MATCH (cr:COMPLIANCE_RULE) WHERE cr.expiry_date >= date() RETURN cr.rule_id"
+    sql = _translate(cypher)
+
+    assert 'cr."expiry_date" >= TRUNC(CURRENT_DATE)' in sql
+    assert "date()" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_cypher_date_weekday_extractors():
+    """``.weekday`` and ``.dayOfWeek`` become ISO-week date arithmetic (0- and 1-based)."""
+    weekday_sql = _translate("MATCH (m:Movie) WHERE date(m.released).weekday = 5 RETURN m.title")
+    day_of_week_sql = _translate(
+        "MATCH (m:Movie) WHERE date(m.released).dayOfWeek = 5 RETURN m.title"
+    )
+
+    assert '(TRUNC(m."released") - TRUNC(m."released", \'IW\')) = 5' in weekday_sql
+    assert '(TRUNC(m."released") - TRUNC(m."released", \'IW\') + 1) = 5' in day_of_week_sql
+    assert ".weekday" not in weekday_sql
+    assert ".dayOfWeek" not in day_of_week_sql
+
+
+def test_cypher2oracle_sqlpgq_coerces_string_backed_date_literals():
+    """Against a VARCHAR2 property, ``date('...')`` is compared as a plain string."""
+    cypher = "MATCH (m:Movie) WHERE m.release_date >= date('1990-01-01') RETURN m.title"
+    property_types = {"Movie": {"release_date": "VARCHAR2(4000)"}}
+    sql = _translate_with_types(cypher, property_types)
+
+    assert "m.\"release_date\" >= '1990-01-01'" in sql
+
+
+def test_cypher2oracle_sqlpgq_coerces_string_property_numeric_literals():
+    """A numeric literal compared with a VARCHAR2 property is quoted."""
+    cypher = "MATCH (u:User {id: 1}) RETURN u"
+    property_types = {"User": {"id": "VARCHAR2(4000)"}}
+    sql = _translate_with_types(cypher, property_types)
+
+    assert "u.\"id\" = '1'" in sql
+
+
+def test_cypher2oracle_sqlpgq_coerces_string_property_boolean_literals():
+    """``true`` compared with a VARCHAR2 property becomes ``'true'``, not ``1``."""
+    cypher = "MATCH (s:Supplier)-[:SUPPLIES]->(p:Product) WHERE p.discontinued = true RETURN s"
+    property_types = {"Product": {"discontinued": "VARCHAR2(4000)"}}
+    sql = _translate_with_types(cypher, property_types)
+
+    assert "p.\"discontinued\" = 'true'" in sql
+    assert 'p."discontinued" = 1' not in sql
+
+
+def test_cypher2oracle_sqlpgq_coerces_string_property_map_boolean_literals():
+    """``false`` in a property map on a VARCHAR2 property becomes ``'false'``, not ``'0'``."""
+    cypher = "MATCH (p:Product {discontinued: false}) RETURN p"
+    property_types = {"Product": {"discontinued": "VARCHAR2(4000)"}}
+    sql = _translate_with_types(cypher, property_types)
+
+    assert "p.\"discontinued\" = 'false'" in sql
+    assert "p.\"discontinued\" = '0'" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_date_property_extractors():
+    """``date(prop).year`` on a VARCHAR2 property extracts from a guarded TO_DATE."""
+    cypher = "MATCH (m:Movie) WHERE date(m.release_date).year = 1995 RETURN m.title"
+    property_types = {"Movie": {"release_date": "VARCHAR2(4000)"}}
+    sql = _translate_with_types(cypher, property_types)
+
+    expected = (
+        'EXTRACT(YEAR FROM TO_DATE(m."release_date" DEFAULT NULL ON CONVERSION ERROR, '
+        "'YYYY-MM-DD')) = 1995"
+    )
+    assert expected in sql
+
+
+def test_cypher2oracle_sqlpgq_coerces_string_property_date_function_calls():
+    """Both sides of a date comparison on VARCHAR2 properties use guarded TO_DATE."""
+    cypher = (
+        "MATCH (a:Actor)-[:ACTED_IN]->(m:Movie) "
+        "WHERE date(m.released) < a.born RETURN DISTINCT a.name"
+    )
+    property_types = {
+        "Actor": {"name": "VARCHAR2(4000)", "born": "VARCHAR2(4000)"},
+        "Movie": {"released": "VARCHAR2(4000)"},
+    }
+    sql = _translate_with_types(cypher, property_types)
+
+    expected = (
+        "TO_DATE(m.\"released\" DEFAULT NULL ON CONVERSION ERROR, 'YYYY-MM-DD') "
+        "< TO_DATE(a.\"born\" DEFAULT NULL ON CONVERSION ERROR, 'YYYY-MM-DD')"
+    )
+    assert expected in sql
+    assert 'CAST(m."released" AS DATE)' not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_property_date_extractors_and_modulo():
+    """``prop.year % 4`` becomes MOD over an EXTRACT from a guarded TO_DATE."""
+    cypher = "MATCH (m:Movie) WHERE m.release_date.year % 4 = 0 RETURN m.title"
+    property_types = {"Movie": {"release_date": "VARCHAR2(4000)"}}
+    sql = _translate_with_types(cypher, property_types)
+
+    expected = (
+        'MOD(EXTRACT(YEAR FROM TO_DATE(m."release_date" DEFAULT NULL ON CONVERSION ERROR, '
+        "'YYYY-MM-DD')), 4) = 0"
+    )
+    assert expected in sql
+
+
+def test_cypher2oracle_sqlpgq_disambiguates_duplicate_projection_aliases():
+    """Two returned properties with the same name get variable-based aliases."""
+    cypher = "MATCH (dc:DataConsumer)-[:Consumes]->(da:DataAsset) RETURN dc.name, da.name"
+    sql = _translate(cypher)
+
+    for fragment in ('dc."name" AS dc', 'da."name" AS da'):
+        assert fragment in sql
+    assert " AS name" not in sql
+
+
+def test_cypher2oracle_sqlpgq_disambiguates_complex_aggregate_property_aliases():
+    """A property used inside an aggregate gets its own alias, distinct from a returned one."""
+    cypher = (
+        "MATCH (dc:DataConsumer)-[c:Consumes {critical_dependency: true}]->(da:DataAsset) "
+        "RETURN dc.name, COUNT(c) AS critical_dependencies, "
+        "AVG(TOFLOAT(SIZE(da.name))) AS avg_usage_frequency"
+    )
+    sql = _translate(cypher)
+
+    expected_fragments = (
+        'dc."name" AS name',
+        'da."name" AS da_name',
+        "AVG(TO_NUMBER(LENGTH(da_name))) AS avg_usage_frequency",
+    )
+    for fragment in expected_fragments:
+        assert fragment in sql
+    assert "AVG(TO_NUMBER(LENGTH(name)))" not in sql
+
+
+def test_cypher2oracle_sqlpgq_orders_by_aggregate_alias_without_inner_projection():
+    """Ordering by an aggregate alias must not project that alias inside COLUMNS."""
+    cypher = (
+        "MATCH (u:USER)-[:Initiates]->(t:TRANSACTION) "
+        "RETURN u.user_id, SUM(t.amount) AS total_amount "
+        "ORDER BY total_amount DESC"
+    )
+    sql = _translate(cypher)
+
+    assert "SELECT user_id, COALESCE(SUM(amount), 0) AS total_amount" in sql
+    assert "total_amount AS total_amount" not in sql
+    assert "ORDER BY total_amount DESC" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_scalar_function_return_expressions():
+    """A scalar function over an arithmetic expression is projected under its alias."""
+    cypher = (
+        "MATCH (m:Movie) "
+        "RETURN abs(m.revenue - m.budget) AS difference "
+        "ORDER BY difference DESC LIMIT 3"
+    )
+    sql = _translate(cypher)
+
+    assert 'abs(m."revenue" - m."budget") AS difference' in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_tofloat_to_oracle_number_cast():
+    sql = _translate("MATCH (m:Movie) RETURN avg(toFloat(m.budget)) AS average_budget")
+    # The raw property is projected by COLUMNS; toFloat() becomes TO_NUMBER() outside it.
+    assert 'COLUMNS (m."budget" AS m_budget)' in sql
+    assert "SELECT avg(TO_NUMBER(m_budget)) AS average_budget" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_tointeger_to_oracle_cast():
+    sql = _translate(
+        "MATCH (se:SystemEnvironment)<-[:RunsIn]-(pj:ProcessingJob)"
+        "-[:Transforms]->(da:DataAsset) "
+        "WHERE da.sensitivity_level = 'PII' "
+        "RETURN se.name AS environment_name, COUNT(pj) AS pii_processing_job_count, "
+        "AVG(toInteger(pj.sla_requirements)) AS avg_sla_hours"
+    )
+    # toInteger() is expected to become CAST(... AS INTEGER) over the projected alias.
+    assert 'pj."sla_requirements" AS pj_sla_requirements' in sql
+    assert (
+        "SELECT environment_name, COUNT(pj_VALUE) AS pii_processing_job_count, "
+        "AVG(CAST(pj_sla_requirements AS INTEGER)) AS avg_sla_hours"
+    ) in sql
+    assert "toInteger" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_substring_to_oracle_substr():
+    sql = _translate(
+        "MATCH (o:Order) WHERE substring(o.orderDate, 0, 4) = '1997' RETURN count(o)"
+    )
+    # The start index 0 is expected to become position 1 in Oracle's SUBSTR.
+    assert "SUBSTR(o.\"orderDate\", 1, 4) = '1997'" in sql
+    assert "substring" not in sql.lower()
+
+
+def test_cypher2oracle_sqlpgq_translates_modulo_after_to_integer():
+    sql = _translate(
+        "MATCH (p:Product) WHERE toInteger(p.unitPrice) % 10 = 0 RETURN p.productName"
+    )
+    # `%` has to be rewritten as MOD() wrapped around the integer cast.
+    assert 'MOD(CAST(p."unitPrice" AS INTEGER), 10) = 0' in sql
+    assert "%" not in sql
+
+
+def test_cypher2oracle_sqlpgq_redirects_unambiguous_node_property_to_adjacent_edge():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (c:Category {categoryName: 'Condiments'})"
+        "<-[:PART_OF]-(p:Product)<-[:ORDERS]-(o:Order) "
+        "RETURN p.productName, SUM(o.quantity) AS totalQuantity",
+        graph_name="G",
+        node_label_map={
+            "Category": ["Category"],
+            "Product": ["Product"],
+            "Order": ["Order"],
+        },
+        edge_label_map={
+            "PART_OF": ["Product_PART_OF_Category"],
+            "ORDERS": ["Order_ORDERS_Product"],
+        },
+        property_type_map={
+            "Category": {"categoryName": "VARCHAR2(4000)"},
+            "Product": {"productName": "VARCHAR2(4000)"},
+            "Order": {"orderID": "VARCHAR2(4000)"},
+            "Order_ORDERS_Product": {"quantity": "NUMBER"},
+            "Product_PART_OF_Category": {},
+        },
+        strict_property_validation=True,
+    )
+    # "quantity" is typed only on the ORDERS edge label, so o.quantity should read from e1.
+    assert status == "Graph-IL Translatable"
+    assert 'e1."quantity" AS quantity' in sql
+    assert 'o."quantity"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_rejects_unmapped_edge_property_on_node():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (c:Customer)-[:PURCHASED]->(o:Order) "
+        "WITH o.discount AS discounts RETURN avg(toFloat(discounts))",
+        graph_name="G",
+        node_label_map={"Customer": ["Customer"], "Order": ["Order"]},
+        edge_label_map={
+            "PURCHASED": ["Customer_PURCHASED_Order"],
+            "ORDERS": ["Order_ORDERS_Product"],
+        },
+        property_type_map={
+            "Customer": {},
+            "Order": {"orderID": "VARCHAR2(4000)"},
+            "Customer_PURCHASED_Order": {},
+            "Order_ORDERS_Product": {"discount": "VARCHAR2(4000)"},
+        },
+        strict_property_validation=True,
+    )
+    # "discount" is typed only on an edge label that this pattern does not traverse.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_rejects_raw_edge_property_leaking_to_node():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (c:Customer {country: 'Argentina'})-[:PURCHASED]->(o:Order) "
+        "WITH o.discount AS discounts RETURN avg(toFloat(discounts)) AS average_discount",
+        graph_name="G",
+        node_label_map={"Customer": ["Customer"], "Order": ["Order"]},
+        edge_label_map={
+            "PURCHASED": ["Customer_PURCHASED_Order"],
+            "ORDERS": ["Order_ORDERS_Product"],
+        },
+        property_type_map={
+            "Customer": {"country": "VARCHAR2(4000)"},
+            "Order": {"orderID": "VARCHAR2(4000)"},
+            "PURCHASED": {},
+            "Customer_PURCHASED_Order": {},
+            "ORDERS": {"discount": "VARCHAR2(4000)"},
+            "Order_ORDERS_Product": {"discount": "VARCHAR2(4000)"},
+        },
+        strict_property_validation=True,
+    )
+    # Raw "ORDERS" type entries must not make o.discount acceptable on the Order node.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_infers_endpoint_label_for_strict_node_property():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (c:Customer {country: 'Argentina'})-[:PURCHASED]->(o:Order) RETURN o.orderID",
+        graph_name="G",
+        node_label_map={"Customer": ["Customer"], "Order": ["Order"]},
+        edge_label_map={"PURCHASED": ["Customer_PURCHASED_Order"]},
+        property_type_map={
+            "Customer": {"country": "VARCHAR2(4000)"},
+            "Order": {"orderID": "VARCHAR2(4000)"},
+            "PURCHASED": {},
+            "Customer_PURCHASED_Order": {},
+        },
+        strict_property_validation=True,
+    )
+    # orderID is typed for Order, so strict validation should accept the node property.
+    assert status == "Graph-IL Translatable"
+    assert 'o."orderID" AS orderID' in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_size_split_word_count():
+    sql = _translate('MATCH (m:Movie) RETURN size(split(m.overview, " ")) AS word_count')
+    # size(split(x, " ")) is expected to become a REGEXP_COUNT over non-space runs.
+    assert "REGEXP_COUNT(m.\"overview\", '\\S+') AS word_count" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_with_size_split_word_count_aggregate():
+    sql = _translate_sql(
+        'MATCH (q:Question)-[:TAGGED]->(t:Tag {name: "graphql"}) '
+        'WITH size(split(q.text, " ")) AS wordsInQuestion '
+        "RETURN avg(wordsInQuestion) AS averageWordCount"
+    )
+    # The word count is projected under a stage_1 CTE and averaged by a plain SELECT.
+    assert "WITH stage_1 AS" in sql
+    assert "REGEXP_COUNT(q.\"text\", '\\S+') AS wordsInQuestion" in sql
+    assert "SELECT AVG(wordsInQuestion) AS averageWordCount" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_vertex_comparisons():
+    sql = _translate("MATCH (m1:Movie) MATCH (m2:Movie) WHERE m1 <> m2 RETURN m2")
+    # `<>` between two node variables is expected to become NOT VERTEX_EQUAL(...).
+    assert "NOT VERTEX_EQUAL(m1, m2)" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_final_with_projection_aggregate():
+    sql = _translate(
+        "MATCH (u1:User)-[:INTERACTED]->(u2:User) "
+        "WHERE u2.size <> 1.5 "
+        "WITH u1.size AS interactingUserSize "
+        "RETURN sum(interactingUserSize) AS totalSize"
+    )
+    # The WITH alias feeds the outer SUM; the WHERE filter keeps its quoted property.
+    assert sql.startswith("SELECT COALESCE(SUM(interactingUserSize), 0) AS totalSize")
+    assert 'u1."size" AS interactingUserSize' in sql
+    assert 'WHERE u2."size" <> 1.5' in sql
+
+
+def test_cypher2oracle_sqlpgq_uses_oracle_edge_label_map():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (a)-[:book]->(b) RETURN b",
+        graph_name="G",
+        edge_label_map={"book": ["book_language_book_publisher", "book_author_book"]},
+    )
+    # A Cypher label mapped to two graph labels should render as an `IS a | b` alternative.
+    assert status == "Graph-IL Translatable"
+    assert 'IS "book_language_book_publisher" | "book_author_book"' in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_relationship_label_alternatives():
+    sql = _translate(
+        "MATCH (fp:FINANCIAL_PERIOD {status: 'Open'})"
+        "<-[:BelongsTo]-(t:TRANSACTION)<-[:Initiates|Approves]-(u:USER) "
+        "RETURN fp.period_id, COUNT(DISTINCT u) AS unique_users, "
+        "SUM(t.amount) AS total_transaction_amount"
+    )
+    # `Initiates|Approves` should keep both labels; the aggregate columns are checked too.
+    assert 'IS "Initiates" | "Approves"' in sql
+    assert 'fp."period_id" AS period_id' in sql
+    assert "COUNT(DISTINCT u_VALUE) AS unique_users" in sql
+    assert "COALESCE(SUM(amount), 0) AS total_transaction_amount" in sql
+
+
+def test_cypher2oracle_sqlpgq_aggregates_outside_graph_table():
+    sql = _translate(
+        "MATCH (book:book) "
+        "WHERE (book.publisher_id = 1929 AND book.num_pages > 500) "
+        "RETURN count(*)"
+    )
+    # COUNT(*) belongs to the outer SELECT, so COLUMNS only carries a placeholder value.
+    assert sql.startswith("SELECT COUNT(*) AS COUNT_VALUE")
+    assert 'MATCH (book IS "book")' in sql
+    assert 'WHERE (book."publisher_id" = 1929 AND book."num_pages" > 500)' in sql
+    assert "COLUMNS (1 AS dummy_value)" in sql
+    assert "COUNT(*) AS COUNT_VALUE)" not in sql
+
+
+def test_cypher2oracle_sqlpgq_average_aggregates_projected_property_outside_graph_table():
+    sql = _translate(
+        "MATCH (:Person)-[r:REVIEWED]->(m:Movie) "
+        "WHERE r.rating > 80 "
+        "RETURN AVG(m.released) AS averageReleaseYear"
+    )
+    # The property is projected by COLUMNS and averaged by the outer SELECT.
+    assert sql.startswith("SELECT AVG(released) AS averageReleaseYear")
+    assert 'COLUMNS (m."released" AS released)' in sql
+
+
+def test_cypher2oracle_sqlpgq_grouped_aggregate_uses_outer_group_by():
+    sql = _translate(
+        "MATCH (director:director) "
+        "RETURN director, count(director.name) ORDER BY count(director.name) DESC LIMIT 1"
+    )
+    # The node is projected as director_VALUE, which the outer query groups by.
+    assert sql.startswith("SELECT director_VALUE, COUNT(name) AS name")
+    assert 'COLUMNS (VERTEX_ID(director) AS director_VALUE, director."name" AS name)' in sql
+    assert "GROUP BY director_VALUE" in sql
+    assert "ORDER BY name DESC" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_case_expression():
+    sql = _translate(
+        "MATCH (t1:area_code)<-[zip_code:ZIP_CODE]-(t2:zip_data) "
+        "WHERE t1.area_code = 787 "
+        "RETURN (count(CASE WHEN t2.type = 'P.O. Box Only' THEN 1 ELSE NULL END) "
+        "- count(CASE WHEN t2.type = 'Post Office' THEN 1 ELSE NULL END)) AS DIFFERENCE"
+    )
+    # Both CASE operands read the projected t2_type alias, and no GROUP BY is emitted.
+    assert sql.startswith(
+        "SELECT (count(CASE WHEN t2_type = 'P.O. Box Only' THEN 1 ELSE NULL END) "
+        "- count(CASE WHEN t2_type = 'Post Office' THEN 1 ELSE NULL END)) AS DIFFERENCE"
+    )
+    assert 'COLUMNS (t2."type" AS t2_type)' in sql
+    assert "GROUP BY" not in sql
+
+
+def test_cypher2oracle_sqlpgq_projects_multi_property_case_aggregate_outside_graph_table():
+    sql = _translate(
+        "MATCH (t1:person)-->(t2:disabled) "
+        "RETURN count(CASE WHEN t2.name IS NULL THEN t1.name END) AS number"
+    )
+    # Both properties get prefixed aliases; the CASE itself must stay out of COLUMNS.
+    assert sql.startswith("SELECT count(CASE WHEN t2_name IS NULL THEN t1_name END) AS number")
+    assert 't2."name" AS t2_name' in sql
+    assert 't1."name" AS t1_name' in sql
+    assert "COLUMNS (CASE WHEN" not in sql
+
+
+def test_cypher2oracle_sqlpgq_uses_case_insensitive_label_maps():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (t1:state)<-[state:STATE]-(t2:country) RETURN count(t2.county)",
+        graph_name="G",
+        node_label_map={"state": ["state"], "country": ["country"]},
+        edge_label_map={"state": ["country_STATE_state"]},
+    )
+    # The lower-case map key "state" is expected to match the upper-case :STATE edge label.
+    assert status == "Graph-IL Translatable"
+    assert '[state IS "country_STATE_state"]->' in sql
+    assert '[state IS "STATE"]' not in sql
+
+
+def test_cypher2oracle_sqlpgq_maps_sanitized_labels_and_properties_strictly():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (t1:characters)-[hero:HERO]->(t2:`voice-actors`) "
+        "WHERE t2.movie = t1.movie_title "
+        "RETURN t2.`voice-actor`",
+        graph_name="G",
+        node_label_map={
+            "characters": ["characters"],
+            "voice_actors": ["voice_actors"],
+        },
+        edge_label_map={"HERO": ["characters_HERO_voice_actors"]},
+        property_type_map={
+            "characters": {"movie_title": "VARCHAR2(4000)"},
+            "voice_actors": {
+                "movie": "VARCHAR2(4000)",
+                "voice_actor": "VARCHAR2(4000)",
+            },
+            "characters_HERO_voice_actors": {},
+        },
+        strict_property_validation=True,
+    )
+    # Back-quoted hyphenated names are expected to resolve to the underscore names mapped.
+    assert status == "Graph-IL Translatable"
+    assert 't2 IS "voice_actors"' in sql
+    assert '[hero IS "characters_HERO_voice_actors"]->' in sql
+    assert 't2."voice_actor" AS voice_actor' in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_string_predicates():
+    starts_sql = _translate("MATCH (a:author) WHERE a.name STARTS WITH 'George' RETURN a.name")
+    ends_sql = _translate("MATCH (c:customer) WHERE c.email ENDS WITH '@x.test' RETURN c.email")
+    contains_sql = _translate("MATCH (p:publisher) WHERE p.name CONTAINS 'book' RETURN count(*)")
+    # STARTS/ENDS WITH map to LIKE with a concatenated wildcard; CONTAINS maps to INSTR.
+    assert "a.\"name\" LIKE 'George' || '%'" in starts_sql
+    assert "c.\"email\" LIKE '%' || '@x.test'" in ends_sql
+    assert "INSTR(p.\"name\", 'book') > 0" in contains_sql
+
+
+def test_cypher2oracle_sqlpgq_translates_label_predicates():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (entity)-[:ClassifiedUnder]->(d:Domain) "
+        "WHERE entity:Concept OR entity:Assertion "
+        "RETURN d.name, entity:Concept AS is_concept",
+        graph_name="G",
+        node_label_map={"Concept": ["Concept"], "Assertion": ["Assertion"]},
+    )
+    # `var:Label` predicates are expected to become IS LABELED checks.
+    assert status == "Graph-IL Translatable"
+    assert 'entity IS LABELED "Concept"' in sql
+    assert 'entity IS LABELED "Assertion"' in sql
+    assert "entity:Concept" not in sql
+
+
+def test_cypher2oracle_sqlpgq_folds_impossible_label_predicates():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (p:Person)-[:DIRECTED]->(m:Movie) "
+        "WHERE NOT (p:Director) "
+        "RETURN m.title AS MovieTitle",
+        graph_name="G",
+        node_label_map={"Person": ["Person"], "Director": ["Director"], "Movie": ["Movie"]},
+    )
+    # The (p:Director) check on a node matched as Person is expected to fold to (1 = 0).
+    assert status == "Graph-IL Translatable"
+    assert "WHERE NOT ((1 = 0))" in sql
+    assert 'p IS LABELED "Director"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_maps_identity_to_primary_key():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (n:PaymentTransaction) RETURN count(n.identity)",
+        graph_name="G",
+        property_type_map={
+            "PaymentTransaction": {
+                "transaction_id": "VARCHAR2(4000)",
+                "status": "VARCHAR2(4000)",
+            }
+        },
+        node_primary_key_map={"PaymentTransaction": "transaction_id"},
+        strict_property_validation=True,
+    )
+    # The `identity` pseudo-property is expected to resolve to the mapped primary key.
+    assert status == "Graph-IL Translatable"
+    assert 'n."transaction_id" AS identity' in sql
+    assert 'n."identity"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_preserves_real_id_property_over_pseudo_identity():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (q:Question) RETURN q.id, q.title",
+        graph_name="G",
+        property_type_map={
+            "Question": {
+                "vid": "VARCHAR2(4000)",
+                "id": "NUMBER",
+                "title": "VARCHAR2(4000)",
+            }
+        },
+        node_primary_key_map={"Question": "vid"},
+        strict_property_validation=True,
+    )
+    # A declared "id" property must be used as-is rather than replaced by the primary key.
+    assert status == "Graph-IL Translatable"
+    assert 'q."id" AS id' in sql
+    assert 'q."vid" AS id' not in sql
+
+
+def test_cypher2oracle_sqlpgq_rejects_unresolved_identity_pseudo_property():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (n1:USER)-[*1..3]->(n2) RETURN n2.identity",
+        graph_name="G",
+        property_type_map={
+            "USER": {"user_id": "VARCHAR2(4000)"},
+            "TWEET": {"tweet_id": "VARCHAR2(4000)"},
+        },
+        node_primary_key_map={"USER": "user_id", "TWEET": "tweet_id"},
+        strict_property_validation=True,
+    )
+    # n2 has no label here, so `.identity` is expected to be rejected in strict mode.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_maps_edge_identity_to_edge_id_function():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (u:USER)-[r:POSTS]->(t:TWEET) RETURN u.user_id, r.identity, t.tweet_id",
+        graph_name="G",
+        node_label_map={"USER": ["USER"], "TWEET": ["TWEET"]},
+        edge_label_map={"POSTS": ["USER_POSTS_TWEET"]},
+        property_type_map={
+            "USER": {"user_id": "VARCHAR2(4000)"},
+            "TWEET": {"tweet_id": "VARCHAR2(4000)"},
+            "POSTS": {"EDGE_ID": "NUMBER"},
+            "USER_POSTS_TWEET": {"EDGE_ID": "NUMBER"},
+        },
+        edge_primary_key_map={"POSTS": "EDGE_ID", "USER_POSTS_TWEET": "EDGE_ID"},
+        strict_property_validation=True,
+    )
+    # Edge identity is expected to use the EDGE_ID() function, not the "EDGE_ID" column.
+    assert status == "Graph-IL Translatable"
+    assert "EDGE_ID(r) AS identity" in sql
+    assert 'r."EDGE_ID"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_rejects_missing_properties_when_strict():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (n:PaymentTransaction) RETURN n.missing_property",
+        graph_name="G",
+        property_type_map={"PaymentTransaction": {"transaction_id": "VARCHAR2(4000)"}},
+        strict_property_validation=True,
+    )
+    # The property is absent from the type map, so strict validation should reject it.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_rejects_string_numeric_aggregate_without_cast():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (t:TWEET) RETURN AVG(t.tweet_id) AS avg_tweet_id",
+        graph_name="G",
+        property_type_map={"TWEET": {"tweet_id": "VARCHAR2(4000)"}},
+        strict_property_validation=True,
+    )
+    # AVG over a VARCHAR2 property with no explicit cast is expected to be rejected.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_maps_camel_case_property_to_snake_case():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (u:User)-[:ASKED]->(q:Question) "
+        "WITH u, count(q) AS questionCount ORDER BY questionCount DESC LIMIT 3 "
+        "RETURN u.displayName, questionCount",
+        graph_name="G",
+        property_type_map={
+            "User": {"display_name": "VARCHAR2(4000)"},
+            "Question": {},
+        },
+        strict_property_validation=True,
+    )
+    # displayName is expected to resolve to the declared display_name column.
+    assert status == "Graph-IL Translatable"
+    assert 'u."display_name" AS u_displayName' in sql
+    assert 'u."displayName"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_arithmetic_over_elements():
+    sql = _translate(
+        "MATCH (g:Group)<-[:BelongsTo]-(u:User)-[:HasRole]->(r:Role) "
+        "RETURN g.Group_id, g.name, COUNT(DISTINCT u) AS user_count, "
+        "COUNT(r) / COUNT(DISTINCT u) AS avg_roles_per_user"
+    )
+    # Node aggregates count the VERTEX_ID-based *_VALUE columns on both sides of the `/`.
+    assert "COUNT(r_VALUE) / COUNT(DISTINCT u_VALUE) AS avg_roles_per_user" in sql
+    assert "VERTEX_ID(r) AS r_VALUE" in sql
+    assert "VERTEX_ID(u) AS u_VALUE" in sql
+    assert "r) / COUNT(u AS" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_lowercase_distinct_element_aggregate():
+    sql = _translate("MATCH (p:Person) RETURN count(distinct p) AS person_count")
+    # A lower-case `distinct` must still be recognised as the DISTINCT modifier.
+    assert "COUNT(DISTINCT p_VALUE) AS person_count" in sql
+    assert "VERTEX_ID(p) AS p_VALUE" in sql
+    assert "distinct p AS" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_arithmetic_over_properties():
+    sql = _translate(
+        "MATCH (fp:FINANCIAL_PERIOD)<-[:BelongsTo]-(t:TRANSACTION), "
+        "(b:BUDGET)-[:AllocatedTo]->(a:ACCOUNT)-[:BelongsTo]->(fp) "
+        "RETURN fp.period_id, SUM(b.amount) - SUM(t.amount) AS budget_variance"
+    )
+    # Each SUM reads its own prefixed amount alias and is wrapped in COALESCE(..., 0).
+    assert "COALESCE(SUM(b_amount), 0) - COALESCE(SUM(t_amount), 0) AS budget_variance" in sql
+    assert 'b."amount" AS b_amount' in sql
+    assert 't."amount" AS t_amount' in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_with_to_cte():
+    sql = _translate_sql(
+        "MATCH (u:USER)-[:POSTS]->(t:TWEET) "
+        "WITH u.username AS username, COUNT(t) AS tweet_count "
+        "WHERE tweet_count > 10 "
+        "RETURN username, tweet_count ORDER BY tweet_count DESC LIMIT 5"
+    )
+    # The aggregating WITH becomes a stage_1 CTE with GROUP BY, a filter and an ordering.
+    assert sql.startswith("WITH stage_1 AS")
+    assert "SELECT username, COUNT(t_VALUE) AS tweet_count" in sql
+    assert "GROUP BY username" in sql
+    assert "WHERE tweet_count > 10" in sql
+    assert "ORDER BY tweet_count DESC" in sql
+
+
+def test_cypher2oracle_sqlpgq_resolves_duplicate_final_aliases_after_with_stage():
+    sql = _translate_sql(
+        "MATCH (a:ACTOR)-[:ACTED_IN]->(m1:MOVIE)-[s:SIMILAR_TO]->(m2:MOVIE) "
+        "WHERE s.similarity_score > 0.8 "
+        "WITH a, m1, m2, COUNT(s) AS sim_count "
+        "WHERE sim_count >= 3 "
+        "RETURN DISTINCT a.name, m1.title, m2.title"
+    )
+    # Two "title" columns would collide, so the final aliases fall back to m1 and m2.
+    assert "SELECT DISTINCT a_name AS name, m1_title AS m1, m2_title AS m2" in sql
+    assert "m1_title AS title, m2_title AS title" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_correlated_optional_match_to_left_join():
+    sql = _translate_sql(
+        "MATCH (a:Person) "
+        "OPTIONAL MATCH (a)-[:KNOWS]->(b:Person) "
+        "RETURN a.name AS person_name, b.name AS friend_name"
+    )
+    # The OPTIONAL MATCH becomes stage_2, LEFT JOINed to stage_1 on the shared node id.
+    assert sql.startswith("WITH stage_1 AS")
+    assert "stage_2 AS" in sql
+    assert "LEFT JOIN stage_2 ON stage_2.a_VALUE = stage_1.stage_1_a_VALUE" in sql
+    assert "stage_1.person_name AS person_name" in sql
+    assert "stage_2.friend_name AS friend_name" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_optional_match_after_with_to_left_join():
+    sql = _translate_sql(
+        "MATCH (a:Person) WITH a "
+        "OPTIONAL MATCH (a)-[:KNOWS]->(b:Person) "
+        "RETURN a.name AS person_name, b.name AS friend_name"
+    )
+    # The LEFT JOIN shape is expected even when a bare `WITH a` precedes the OPTIONAL MATCH.
+    assert "LEFT JOIN stage_2 ON stage_2.a_VALUE = stage_1.stage_1_a_VALUE" in sql
+    assert "stage_1.person_name AS person_name" in sql
+    assert "stage_2.friend_name AS friend_name" in sql
+
+
+def test_cypher2oracle_sqlpgq_aggregates_optional_match_rows():
+    sql = _translate_sql(
+        "MATCH (a:Person) "
+        "OPTIONAL MATCH (a)-[:KNOWS]->(b:Person) "
+        "WITH a, count(b) AS friend_count "
+        "RETURN a.name AS person_name, friend_count"
+    )
+    # Optional rows are counted through the LEFT JOIN, grouped by the stage_1 node id.
+    assert "LEFT JOIN stage_2 ON stage_2.a_VALUE = stage_1.stage_1_a_VALUE" in sql
+    assert "COUNT(stage_2.b_VALUE) AS friend_count" in sql
+    assert "GROUP BY stage_1.stage_1_a_VALUE" in sql
+    assert "stage_1.a_name AS a_name" in sql
+
+
+def test_cypher2oracle_sqlpgq_keeps_optional_where_inside_optional_stage():
+    sql = _translate_sql(
+        "MATCH (a:Person) "
+        "OPTIONAL MATCH (a)-[:KNOWS]->(b:Person) "
+        "WHERE b.age > 30 "
+        "RETURN a.name AS person_name, b.name AS friend_name"
+    )
+    # The optional filter text must occur before the COLUMNS clause projecting friend_name.
+    assert "FROM stage_1\nLEFT JOIN stage_2 ON stage_2.a_VALUE = stage_1.stage_1_a_VALUE" in sql
+    assert sql.index('WHERE b."age" > 30') < sql.index('COLUMNS (b."name" AS friend_name')
+
+
+def test_cypher2oracle_sqlpgq_rejects_unsupported_optional_match_shapes():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (a:Person) OPTIONAL MATCH (a)-[:KNOWS]->(b:Person) "
+        "OPTIONAL MATCH (b)-[:KNOWS]->(c:Person) RETURN c.name",
+        graph_name="MOVIE_GRAPH",
+    )
+    # An OPTIONAL MATCH chained off a previous OPTIONAL MATCH is reported as unsupported.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (a:Person) OPTIONAL MATCH (b:Person) RETURN b.name",
+        graph_name="MOVIE_GRAPH",
+    )
+    # An OPTIONAL MATCH sharing no variable with the first MATCH is reported as unsupported.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_carries_with_variable_properties_to_cte():
+    sql = _translate_sql(
+        "MATCH (u:USER)-[:Initiates]->(t:TRANSACTION) "
+        "WITH u, COUNT(t) AS transaction_count, SUM(t.amount) AS total_amount "
+        "WHERE transaction_count > 5 "
+        "RETURN u.user_id, total_amount"
+    )
+    # u.user_id is carried as u_user_id and grouped alongside u_VALUE.
+    assert 'u."user_id" AS u_user_id' in sql
+    assert "GROUP BY u_VALUE, u_user_id" in sql
+    assert "SELECT u_user_id AS user_id, total_amount AS total_amount" in sql
+    assert "SELECT u AS user_id" not in sql
+
+
+def test_cypher2oracle_sqlpgq_carries_with_variable_properties_for_final_aggregate():
+    sql = _translate_sql(
+        "MATCH (a:ACCOUNT)-[:GovernedBy]->(cr:COMPLIANCE_RULE) "
+        "WITH a, COUNT(cr) AS rule_count "
+        "WHERE rule_count > 1 "
+        "RETURN AVG(a.balance) AS average_balance"
+    )
+    # a.balance is carried as a_balance so the final AVG reads that column, not the node.
+    assert 'a."balance" AS a_balance' in sql
+    assert "GROUP BY a_VALUE, a_balance" in sql
+    assert "SELECT AVG(a_balance) AS average_balance" in sql
+    assert "AVG(a)" not in sql
+
+
+def test_cypher2oracle_sqlpgq_groups_final_with_stage_aggregate_projection():
+    sql = _translate_sql(
+        "MATCH (u:USER)-[:USES]->(r:RESOURCE) "
+        "WITH u, COUNT(DISTINCT r) AS distinct_resources "
+        "RETURN u.name, AVG(distinct_resources) AS avg_resources"
+    )
+    # The final aggregate over the stage result is grouped by the carried u_name column.
+    assert "SELECT u_name AS name, AVG(distinct_resources) AS avg_resources" in sql
+    assert "GROUP BY u_name" in sql
+    assert 'u."name" AS u_name' in sql
+
+
+def test_cypher2oracle_sqlpgq_aggregates_carried_with_edge_variable():
+    sql = _translate_sql(
+        "MATCH (p:PERSON)-[r:HAS]->(res:RESOURCE) "
+        "WITH p, r, COUNT(res) AS resource_count "
+        "RETURN p.name, COUNT(r) AS relationship_count"
+    )
+    # The carried edge is projected as EDGE_ID(r) so that COUNT can read r_VALUE.
+    assert "COUNT(r_VALUE) AS relationship_count" in sql
+    assert "EDGE_ID(r) AS r_VALUE" in sql
+    assert "GROUP BY p_name" in sql
+    assert "COUNT(r) AS" not in sql
+
+
+def test_cypher2oracle_sqlpgq_filters_with_stage_on_carried_property_alias():
+    sql = _translate_sql(
+        "MATCH (g:GROUP)<-[:MEMBER_OF]-(u:USER) "
+        "WITH g, COUNT(u) AS member_count "
+        "WHERE g.member_count > 10 "
+        "RETURN g.name, member_count"
+    )
+    # g.member_count is a node property here; the filter should use its carried alias.
+    assert 'g."member_count" AS g_member_count' in sql
+    assert "WHERE g_member_count > 10" in sql
+    assert "WHERE g.member_count" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_ordered_with_to_cte():
+    sql = _translate_sql(
+        "MATCH (c:Character) "
+        "WITH c.name AS name, c.rank AS rank ORDER BY rank DESC LIMIT 5 "
+        "RETURN name, rank"
+    )
+    # ORDER BY / FETCH FIRST cannot trail the final SELECT: the text must end at FROM stage_1.
+    assert sql.startswith("WITH stage_1 AS")
+    assert 'c."name" AS name' in sql
+    assert 'c."rank" AS rank' in sql
+    assert "ORDER BY rank DESC" in sql
+    assert "FETCH FIRST 5 ROWS ONLY" in sql
+    assert sql.strip().endswith("FROM stage_1")
+
+
+def test_cypher2oracle_sqlpgq_normalizes_descending_order_keyword():
+    sql = _translate_sql(
+        "MATCH (p:Product) "
+        "WITH p, p.reorderLevel AS reorderLevel ORDER BY reorderLevel DESCENDING "
+        "RETURN p.productName LIMIT 1"
+    )
+    # The DESCENDING keyword is expected to be normalised to DESC.
+    assert "ORDER BY reorderLevel DESC" in sql
+    assert "DESCENDING" not in sql
+
+
+def test_cypher2oracle_sqlpgq_orders_with_stage_by_hidden_property_alias():
+    sql = _translate_sql(
+        "MATCH (m:Movie)<-[:RATED]-(u:User) "
+        "WITH m, count(u) AS userCount "
+        "WHERE userCount > 1000 "
+        "RETURN m ORDER BY m.imdbRating DESC LIMIT 3"
+    )
+    # m.imdbRating is not returned, so a prefixed hidden alias is expected for the ORDER BY.
+    assert 'm."imdbRating" AS m_imdbRating' in sql
+    assert "ORDER BY m_imdbRating DESC" in sql
+    assert "ORDER BY imdbRating DESC" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_chained_numeric_comparison():
+    sql = _translate(
+        "MATCH (q:Question) "
+        "WHERE 100 <= q.view_count <= 500 "
+        "RETURN q.uuid, q.title, q.view_count ORDER BY q.view_count ASC"
+    )
+    # A chained comparison `a <= x <= b` is expected to expand into two ANDed comparisons.
+    assert '(100 <= q."view_count" AND q."view_count" <= 500)' in sql
+    assert '100 <= q."view_count" <= 500' not in sql
+
+
+def test_cypher2oracle_sqlpgq_omits_unknown_strict_edge_label():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (z:ZIP)-[state:STATE]->(s:STATE) RETURN s.name",
+        graph_name="G",
+        node_label_map={"ZIP": ["zip"], "STATE": ["state"]},
+        edge_label_map={"zip_to_state": ["zip_data_country_state"]},
+        strict_property_validation=True,
+    )
+    # STATE is not a key of the edge map, so the edge is expected to be matched unlabeled.
+    assert status == "Graph-IL Translatable"
+    assert 'MATCH (z IS "zip")-[state]->(s IS "state")' in sql
+    assert 'state IS "STATE"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_rejects_property_on_unknown_strict_edge_label():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (a:A)-[bad:UnknownEdge]->(b:B) RETURN bad.some_property",
+        graph_name="G",
+        node_label_map={"A": ["A"], "B": ["B"]},
+        edge_label_map={"KnownEdge": ["A_KnownEdge_B"]},
+        property_type_map={"UnknownEdge": {"some_property": "VARCHAR2(4000)"}},
+        strict_property_validation=True,
+    )
+    # A property on an edge label missing from the edge map is expected to be rejected.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_maps_source_label_alias_to_graph_label():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (s:Source)<-[:SourcedFrom]-(a:Assertion) RETURN COUNT(DISTINCT s)",
+        graph_name="G",
+        node_label_map={"Source": ["InfoSource"], "Assertion": ["Assertion"]},
+        edge_label_map={"SourcedFrom": ["Assertion_SourcedFrom_InfoSource"]},
+        strict_property_validation=True,
+    )
+    # The Cypher label Source is expected to be rendered as the mapped label InfoSource.
+    assert status == "Graph-IL Translatable"
+    assert 's IS "InfoSource"' in sql
+    assert 's IS "Source"' not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_id_function_vertex_comparison():
+    sql = _translate(
+        "MATCH (a:Assertion)<-[:Supports]-(s1:Assertion), "
+        "(a)<-[:Supports]-(s2:Assertion) "
+        "WHERE id(s1) <> id(s2) "
+        "RETURN avg(a.confidence_score) AS average_confidence"
+    )
+    # Comparing id(s1) with id(s2) is expected to become a VERTEX_EQUAL test.
+    assert "WHERE NOT VERTEX_EQUAL(s1, s2)" in sql
+    assert "id(s1)" not in sql
+
+
+def test_cypher2oracle_sqlpgq_uses_resolved_aliases_for_duplicate_aggregate_properties():
+    sql = _translate(
+        "MATCH (p:Policy)-[:Enforces]->(r:Role)-[:GrantsAccessTo]->(res:Resource) "
+        "RETURN r.name, p.name, COUNT(res) AS api_endpoint_count"
+    )
+    # r.name and p.name collide, so the projection and GROUP BY both use aliases r and p.
+    assert "SELECT r, p, COUNT(res_VALUE) AS api_endpoint_count" in sql
+    assert 'COLUMNS (r."name" AS r, p."name" AS p, VERTEX_ID(res) AS res_VALUE)' in sql
+    assert "GROUP BY r, p" in sql
+    assert "SELECT name, name" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_with_stage_tofloat_aggregate_argument():
+    sql = _translate_sql(
+        "MATCH (role:Role {is_compliant: true})-[:GrantsAccessTo]->(res:Resource) "
+        "WITH role, COUNT(res) AS resource_count "
+        "WHERE resource_count >= 1 "
+        "RETURN role.name AS role_name, AVG(toFloat(resource_count)) AS avg_resources "
+        "ORDER BY avg_resources DESC"
+    )
+    # toFloat() over a WITH-stage alias is expected to become TO_NUMBER().
+    assert "AVG(TO_NUMBER(resource_count)) AS avg_resources" in sql
+    assert "toFloat" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_final_with_arithmetic_expression():
+    sql = _translate_sql(
+        "MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent)-[:Targets]->(r:Resource) "
+        "WITH u, COUNT(ae) AS attempt_count, MIN(ae.timestamp) AS first_attempt, "
+        "MAX(ae.timestamp) AS last_attempt "
+        "WHERE attempt_count > 1 "
+        "RETURN u.name, attempt_count, "
+        "(last_attempt - first_attempt) / attempt_count AS avg_time_between_attempts"
+    )
+    # The arithmetic expression should be emitted verbatim rather than as a mangled alias.
+    assert "(last_attempt - first_attempt) / attempt_count AS avg_time_between_attempts" in sql
+    assert "_last_attempt_first_attempt_attempt_count AS avg_time_between_attempts" not in sql
+
+
+def test_cypher2oracle_sqlpgq_casts_date_function_over_properties():
+    sql = _translate(
+        "MATCH (p:Project)<-[:AppliedIn]-(c:Concept)-[:TaggedWith]->(t:Tag) "
+        "RETURN COUNT(p) AS project_count, "
+        "AVG(DATE(p.end_date) - DATE(p.start_date)) AS avg_duration_days"
+    )
+    # DATE(prop) is expected to become CAST(<projected alias> AS DATE) in the aggregate.
+    assert "CAST(p_end_date AS DATE) - CAST(p_start_date AS DATE)" in sql
+    assert 'p."end_date" AS p_end_date' in sql
+    assert 'p."start_date" AS p_start_date' in sql
+    assert 'DATE(p."end_date")' not in sql
+
+
+def test_cypher2oracle_sqlpgq_rejects_numeric_aggregate_over_temporal_property():
+    sql, status = cypher2oracle_sqlpgq(
+        "MATCH (p:Policy) RETURN AVG(p.effective_date) AS avg_effective_date",
+        graph_name="G",
+        property_type_map={"Policy": {"effective_date": "DATE"}},
+    )
+    # AVG over a DATE-typed property is expected to be rejected.
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert status == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_suffixes_resource_reserved_identifier():
+    sql = _translate_sql(
+        "MATCH (role:Role)-[:GrantsAccessTo]->(resource:Resource) "
+        "RETURN role, resource "
+        "UNION "
+        "MATCH (role:Role)-[:GrantsAccessTo]->(:TemporaryAccess)"
+        "-[:AssignedTempAccess]->(resource:Resource) "
+        "RETURN role, resource"
+    )
+    # `resource` is expected to be renamed to resource_VALUE in the pattern and in COLUMNS.
+    assert '(resource_VALUE IS "Resource")' in sql
+    assert "VERTEX_ID(resource_VALUE) AS resource_VALUE" in sql
+    assert '(resource IS "Resource")' not in sql
+
+
+def test_cypher2oracle_sqlpgq_uses_safe_variable_for_user_identifier():
+    sql = _translate_sql(
+        "MATCH (user:User)-[:COMMENTED]->(comment:Comment)"
+        "-[:COMMENTED_ON]->(question:Question) "
+        "WITH question, count(DISTINCT user) AS distinct_commenters "
+        "WHERE distinct_commenters > 1 RETURN question.title"
+    )
+    # `user` becomes user_var and `comment` becomes comment_VALUE in the MATCH pattern.
+    assert 'MATCH (user_var IS "User")' in sql
+    assert '(comment_VALUE IS "Comment")' in sql
+    assert "VERTEX_ID(user_var) AS user_VALUE" in sql
+    assert "COUNT(DISTINCT user_VALUE) AS distinct_commenters" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_join_on_carried_variable():
+    sql = _translate_sql(
+        "MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent)-[:Targets]->(r:Resource) "
+        "WHERE r.sensitivity_level = 'Confidential' "
+        "WITH u "
+        "MATCH (u)-[:AssignedTempAccess]->(ta:TemporaryAccess) "
+        "RETURN DISTINCT u.name, u.department"
+    )
+    # The second MATCH becomes stage_2, joined to stage_1 on the carried node id.
+    assert sql.startswith("WITH stage_1 AS")
+    assert "stage_2 AS" in sql
+    assert "VERTEX_ID(u) AS stage_1_u_VALUE" in sql
+    assert "VERTEX_ID(u) AS u_VALUE" in sql
+    assert "JOIN stage_1 ON stage_2.u_VALUE = stage_1.stage_1_u_VALUE" in sql
+    assert "SELECT DISTINCT stage_1.name AS name, stage_1.department AS department" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_join_on_multiple_variables():
+    sql = _translate_sql(
+        "MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent)-[:Targets]->(r:Resource) "
+        "WHERE r.sensitivity_level = 'Confidential' "
+        "WITH u, r "
+        "MATCH (u)-[:AssignedTempAccess]->(ta:TemporaryAccess)-[:AssignedTempAccess]->(r) "
+        "RETURN DISTINCT u.name, r.name"
+    )
+    # With two carried variables, both node ids are expected to take part in the join.
+    assert "VERTEX_ID(r) AS stage_1_r_VALUE" in sql
+    assert "VERTEX_ID(u) AS stage_1_u_VALUE" in sql
+    assert "VERTEX_ID(r) AS r_VALUE" in sql
+    assert "VERTEX_ID(u) AS u_VALUE" in sql
+    assert "stage_2.r_VALUE = stage_1.stage_1_r_VALUE" in sql
+    assert "stage_2.u_VALUE = stage_1.stage_1_u_VALUE" in sql
+    assert "SELECT DISTINCT stage_1.u AS u, stage_1.r AS r" in sql
+
+
+def test_cypher2oracle_sqlpgq_uses_second_stage_duplicate_property_aliases_after_with_match():
+    sql = _translate_sql(
+        "MATCH (creator:User {department:'IT'})-[:BelongsTo]->(g:Group) "
+        "WITH DISTINCT g "
+        "MATCH (g)<-[:BelongsTo]-(u:User)-[:AttemptsAccess]->"
+        "(ae:AccessEvent {outcome:'FAILURE'})-[:Targets]->(r:Resource) "
+        "RETURN u.name, g.name, r.name"
+    )
+    # u.name and r.name come from stage_2, g.name from stage_1, each under its variable.
+    assert "SELECT stage_2.u AS u, stage_1.g AS g, stage_2.r AS r" in sql
+    assert "stage_2.name AS u" not in sql
+    assert "stage_2.name AS r" not in sql
+
+
+def test_cypher2oracle_sqlpgq_qualifies_second_stage_whole_element_after_with_match():
+ query = _translate_sql(  # RETURN v, c after aggregate WITH: c is qualified with stage_2
+ "MATCH (v:VENDOR)-[:PROVIDES]->(c:CONTRACT) "
+ "WITH v, SUM(c.value) AS total_value "
+ "WHERE total_value > 100000 "
+ "MATCH (v)-[:PROVIDES]->(c:CONTRACT) "
+ "RETURN v, c ORDER BY c.value DESC LIMIT 10"
+ )
+
+ assert "SELECT stage_1.v_VALUE AS v_VALUE, stage_2.c_VALUE AS c_VALUE" in query
+ assert "SELECT stage_1.v_VALUE AS v_VALUE, c_VALUE" not in query
+ assert 'c."value" AS c_value_PROP' in query
+ assert "ORDER BY stage_2.c_value_PROP DESC" in query
+
+
+def test_cypher2oracle_sqlpgq_projects_carried_return_properties_from_first_with_stage():
+ query = _translate_sql(  # RETURN reads only p, carried by WITH but not re-matched
+ "MATCH (p:Policy)-[:Enforces]->(r:Role)-[:GrantsAccessTo]->(res1:Resource) "
+ "WHERE res1.type = 'Database' "
+ "WITH p, r "
+ "MATCH (r)-[:GrantsAccessTo]->(res2:Resource) "
+ "WHERE res2.type = 'File Server' "
+ "RETURN DISTINCT p.name AS PolicyName, p.description AS PolicyDescription"
+ )
+
+ assert 'p."name" AS PolicyName' in query
+ assert 'p."description" AS PolicyDescription' in query
+ assert "COLUMNS (VERTEX_ID(r) AS r_VALUE)" in query
+ assert "p.name AS PolicyName" not in query
+
+
+def test_cypher2oracle_sqlpgq_maps_label_derived_id_for_carried_with_match_property():
+ query, category = cypher2oracle_sqlpgq(  # Source -> InfoSource: source_id maps to key column
+ "MATCH (s:Source)<-[:SourcedFrom]-(a1:Assertion) "
+ "WHERE a1.confidence_score < 0.5 "
+ "WITH s "
+ "MATCH (s)<-[:SourcedFrom]-(a2:Assertion) "
+ "WHERE a2.confidence_score > 0.9 "
+ "RETURN DISTINCT s.source_id, s.title",
+ graph_name="G",
+ node_label_map={"Source": ["InfoSource"], "Assertion": ["Assertion"]},
+ edge_label_map={"SourcedFrom": ["Assertion_SourcedFrom_InfoSource"]},
+ property_type_map={
+ "InfoSource": {
+ "infosource_id": "VARCHAR2(4000)",
+ "title": "VARCHAR2(4000)",
+ },
+ "Assertion": {"confidence_score": "NUMBER"},
+ },
+ node_primary_key_map={"InfoSource": "infosource_id"},
+ strict_property_validation=True,
+ )
+
+ assert category == "Graph-IL Translatable"
+ assert 's."infosource_id" AS source_id' in query
+ assert 's."source_id"' not in query
+ assert "SELECT DISTINCT stage_1.source_id AS source_id, stage_1.title AS title" in query
+
+
+def test_cypher2oracle_sqlpgq_does_not_duplicate_redeclared_carried_vertex_return():
+ query = _translate_sql(  # nodes_1 is carried and re-matched: stage_1 exposes its id once
+ "MATCH (nodes_1:Policy) WHERE nodes_1.name = 'Compliance Policy 9' "
+ "WITH nodes_1 "
+ "MATCH (nodes_1)-[edges_1:GrantsAccessTo]->(nodes_2:Resource) "
+ "RETURN nodes_2, nodes_1, edges_1 LIMIT 10"
+ )
+
+ stage_1 = query.split("stage_2 AS", 1)[0]  # text before the second CTE definition
+ assert "VERTEX_ID(nodes_1) AS stage_1_nodes_1_VALUE" in stage_1
+ assert "VERTEX_ID(nodes_1) AS nodes_1_VALUE" not in stage_1
+ assert (
+ "SELECT stage_2.nodes_2_VALUE AS nodes_2_VALUE, "
+ "stage_1.stage_1_nodes_1_VALUE AS nodes_1_VALUE, "
+ "stage_2.edges_1_VALUE AS edges_1_VALUE"
+ ) in query
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_with_match_stage_alias_aggregate():
+ query = _translate_sql(  # final AVG() runs over the stage_1 aggregate alias
+ "MATCH (u:User)-[:AttemptsAccess]->(:AccessEvent)-[:Targets]->(res:Resource) "
+ "WITH u, COUNT(res) AS resource_count "
+ "WHERE resource_count > 3 "
+ "MATCH (u)-[:HasRole]->(r:Role) "
+ "RETURN r.role_type, AVG(resource_count) AS avg_resources_accessed"
+ )
+
+ assert "COUNT(res_VALUE) AS resource_count" in query
+ assert "WHERE resource_count > 3" in query
+ assert "AVG(stage_1.resource_count) AS avg_resources_accessed" in query
+ assert "JOIN stage_1 ON stage_2.u_VALUE = stage_1.u_VALUE" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_with_match_carried_property_projection():
+ query = _translate_sql(  # carried u.name / r.name plus stage_1 aggregate alias returned
+ "MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent)-[:Targets]->(r:Resource) "
+ "WHERE r.sensitivity_level = 'Confidential' "
+ "WITH u, r, COUNT(ae) AS access_count "
+ "WHERE access_count > 1 "
+ "MATCH (u)-[:HasRole]->(ro:Role)-[:GrantsAccessTo]->(r) "
+ "RETURN DISTINCT u.name, r.name, access_count"
+ )
+
+ assert 'u."name" AS u_name' in query
+ assert 'r."name" AS r_name' in query
+ assert "SELECT DISTINCT stage_1.u_name AS u, stage_1.r_name AS r" in query
+ assert "stage_1.access_count AS access_count" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_with_match_distinct_second_stage_count():
+ query = _translate_sql(  # COUNT(DISTINCT r) on a second-stage vertex uses stage_2.r_VALUE
+ "MATCH (u:User)-[:BelongsTo]->(g:Group) "
+ "WITH u, COUNT(g) AS group_count "
+ "WHERE group_count > 1 "
+ "MATCH (u)-[:AttemptsAccess]->(ae:AccessEvent)-[:Targets]->(r:Resource) "
+ "WHERE r.type = 'API Endpoint' "
+ "RETURN u.name, COUNT(DISTINCT r) AS api_endpoints_accessed"
+ )
+
+ assert "VERTEX_ID(r) AS r_VALUE" in query
+ assert "COUNT(DISTINCT stage_2.r_VALUE) AS api_endpoints_accessed" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_aggregate_with_match_order_by_stage_alias():
+ query = _translate_sql(  # ORDER BY a stage_1 aggregate alias; stage_2 must not re-project it
+ "MATCH (u:User)-[:HasRole]->(r:Role)-[:GrantsAccessTo]->(res:Resource) "
+ "WITH u, r, COUNT(res) AS resource_count "
+ "WHERE resource_count > 5 "
+ "MATCH (u)-[:BelongsTo]->(g:Group) "
+ "WHERE g.member_count > 10 "
+ "RETURN u.name AS UserName, r.name AS RoleName, resource_count "
+ "ORDER BY resource_count DESC"
+ )
+
+ assert "stage_1.resource_count AS resource_count" in query
+ assert "ORDER BY resource_count DESC" in query
+ stage_2 = query.split("stage_2 AS", 1)[1].split(")\nSELECT", 1)[0]  # body of the stage_2 CTE only
+ assert "resource_count AS resource_count" not in stage_2
+
+
+def test_cypher2oracle_sqlpgq_translates_ordered_limited_aggregate_with_match():
+ query = _translate_sql(  # WITH ... ORDER BY/LIMIT before MATCH, then an ordered aggregate
+ "MATCH (q:Question)<-[:COMMENTED_ON]-(c:Comment) "
+ "WITH q, COUNT(c) AS commentCount ORDER BY commentCount DESC LIMIT 5 "
+ "MATCH (q)<-[:ANSWERED]-(a:Answer)<-[:PROVIDED]-(u:User) "
+ "RETURN u.display_name AS user, COUNT(a) AS answerCount "
+ "ORDER BY answerCount DESC LIMIT 5"
+ )
+
+ assert "COUNT(c_VALUE) AS commentCount" in query
+ assert "ORDER BY commentCount DESC" in query
+ assert "FETCH FIRST 5 ROWS ONLY" in query
+ assert "JOIN stage_1 ON stage_2.q_VALUE = stage_1.q_VALUE" in query
+ assert "COUNT(stage_2.a_VALUE) AS answerCount" in query
+ assert "ORDER BY answerCount DESC" in query
+
+
+def test_cypher2oracle_sqlpgq_keeps_final_aggregate_when_alias_matches_stage():
+ query = _translate_sql(  # final COUNT(c) reuses the WITH alias name; stays an aggregate
+ "MATCH (q:Question)<-[:COMMENTED_ON]-(c:Comment) "
+ "WITH q, COUNT(c) AS comment_count ORDER BY comment_count DESC LIMIT 3 "
+ "MATCH (q)<-[:COMMENTED_ON]-(c)<-[:COMMENTED]-(u:User) "
+ "RETURN q.title AS question_title, u.display_name AS commenter_name, "
+ "COUNT(c) AS comment_count ORDER BY comment_count DESC"
+ )
+
+ assert "VERTEX_ID(c) AS c_VALUE" in query
+ assert "COUNT(stage_2.c_VALUE) AS comment_count" in query
+ assert "stage_1.comment_count AS comment_count" not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_ordered_limited_with_match():
+ query = _translate_sql(  # non-aggregate WITH g ORDER BY ... LIMIT 3, then MATCH on g
+ "MATCH (g:Group) "
+ "WITH g ORDER BY g.member_count DESC LIMIT 3 "
+ "MATCH (u:User)-[:BelongsTo]->(g) "
+ "RETURN g.name AS group_name, u.name AS user_name, "
+ "u.last_login AS last_login_date"
+ )
+
+ assert 'g."member_count" AS member_count' in query
+ assert "ORDER BY member_count DESC" in query
+ assert "FETCH FIRST 3 ROWS ONLY" in query
+ assert "JOIN stage_1 ON stage_2.g_VALUE = stage_1.stage_1_g_VALUE" in query
+ assert (
+ "SELECT stage_1.group_name AS group_name, "
+ "stage_2.user_name AS user_name, "
+ "stage_2.last_login_date AS last_login_date" in query
+ )
+
+
+def test_cypher2oracle_sqlpgq_projects_scalar_alias_in_ordered_with_match_stage():
+ query = _translate_sql(  # scalar alias viewCount is projected once and ordered in stage_1
+ "MATCH (q:Question) "
+ "WITH q, q.view_count AS viewCount ORDER BY viewCount DESC LIMIT 3 "
+ "MATCH (u:User)-[:ASKED]->(q) "
+ "RETURN u.display_name AS userDisplayName"
+ )
+
+ stage_1 = query.split("stage_2 AS", 1)[0]  # text before the second CTE definition
+ assert 'q."view_count" AS viewCount' in stage_1
+ assert "viewCount AS viewCount" not in stage_1
+ assert "ORDER BY viewCount DESC" in stage_1
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_property_correlation():
+ query = _translate_sql(  # WHERE g.created_by = u.User_id becomes the stage join condition
+ "MATCH (u:User)-[:HasRole]->(r:Role)-[:GrantsAccessTo]->(res:Resource) "
+ "WITH u, COUNT(res) AS resource_count "
+ "WHERE resource_count > 5 "
+ "MATCH (g:Group) "
+ "WHERE g.created_by = u.User_id "
+ "RETURN AVG(g.member_count) AS AverageMembersInGroups"
+ )
+
+ assert 'u."User_id" AS u_User_id' in query
+ assert 'g."created_by" AS g_created_by' in query
+ assert "JOIN stage_1 ON stage_2.g_created_by = stage_1.u_User_id" in query
+ assert 'WHERE g."created_by" = u."User_id"' not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_correlation_and_stage_alias_return():
+ query = _translate_sql(  # property-correlation join plus a stage_1 alias in RETURN
+ "MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent)-[:Targets]->(r:Resource) "
+ "WITH u, COUNT(DISTINCT r) AS resource_count "
+ "WHERE resource_count > 3 "
+ "MATCH (al:AuditLog) "
+ "WHERE al.performed_by = u.User_id "
+ "RETURN al.timestamp, resource_count"
+ )
+
+ assert 'al."performed_by" AS al_performed_by' in query
+ assert "stage_1.resource_count AS resource_count" in query
+ assert "JOIN stage_1 ON stage_2.al_performed_by = stage_1.u_User_id" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_element_correlation():
+ query = _translate_sql(  # WHERE c1 = c2 across stages becomes a join on vertex ids
+ "MATCH (p1:Product {productName: 'Aniseed Syrup'})"
+ "<-[:ORDERS]-(:Order)<-[:PURCHASED]-(c1:Customer) "
+ "WITH c1 "
+ "MATCH (p2:Product {productName: 'Ipoh Coffee'})"
+ "<-[:ORDERS]-(:Order)<-[:PURCHASED]-(c2:Customer) "
+ "WHERE c1 = c2 "
+ "RETURN c1.companyName"
+ )
+
+ assert "VERTEX_ID(c1) AS stage_1_c1_VALUE" in query
+ assert 'c1."companyName" AS companyName' in query
+ assert "VERTEX_ID(c2) AS c2_VALUE" in query
+ assert "JOIN stage_1 ON stage_2.c2_VALUE = stage_1.stage_1_c1_VALUE" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_scalar_with_match_comparison():
+ query = _translate_sql(  # scalar carried by WITH, compared in next WHERE -> non-equi join
+ "MATCH (p:Policy {name: 'Compliance Policy 9'}) "
+ "WITH p.effective_date AS target_date "
+ "MATCH (other:Policy) "
+ "WHERE other.effective_date > target_date "
+ "RETURN other"
+ )
+
+ assert 'p."effective_date" AS target_date' in query
+ assert 'other."effective_date" AS other_effective_date' in query
+ assert "JOIN stage_1 ON stage_2.other_effective_date > stage_1.target_date" in query
+ assert "SELECT stage_2.other_VALUE AS other_VALUE" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_scalar_with_match_comparison_with_path():
+ query = _translate_sql(  # carried scalar compared in WHERE; second MATCH is a path pattern
+ "MATCH (r:Resource {name: 'meeting Document'}) "
+ "WITH r.created_date AS target_date "
+ "MATCH (role:Role)-[:GrantsAccessTo]->(res:Resource) "
+ "WHERE res.created_date > target_date "
+ "RETURN DISTINCT role.name"
+ )
+
+ assert 'r."created_date" AS target_date' in query
+ assert 'res."created_date" AS res_created_date' in query
+ assert "JOIN stage_1 ON stage_2.res_created_date > stage_1.target_date" in query
+ assert "SELECT DISTINCT stage_2.name AS name" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_cast_scalar_with_match_comparison():
+ query, category = cypher2oracle_sqlpgq(  # toFloat(prop) vs carried MAX alias, strict mode
+ "MATCH (o:Order)-[oi:ORDERS]->(p:Product) "
+ "WITH MAX(toFloat(oi.unitPrice)) AS maxUnitPrice "
+ "MATCH (o:Order)-[oi:ORDERS]->(p:Product) "
+ "WHERE toFloat(oi.unitPrice) = maxUnitPrice "
+ "MATCH (c:Customer)-[:PURCHASED]->(o) "
+ "RETURN c.companyName",
+ edge_label_map={
+ "ORDERS": ["Order_ORDERS_Product"],
+ "PURCHASED": ["Customer_PURCHASED_Order"],
+ },
+ property_type_map={
+ "Order": {},
+ "Product": {},
+ "Customer": {"companyName": "VARCHAR2(4000)"},
+ "ORDERS": {"unitPrice": "VARCHAR2(4000)"},
+ "Order_ORDERS_Product": {"unitPrice": "VARCHAR2(4000)"},
+ "PURCHASED": {},
+ "Customer_PURCHASED_Order": {},
+ },
+ strict_property_validation=True,
+ )
+
+ assert category == "Graph-IL Translatable"
+ assert 'oi."unitPrice" AS oi_unitPrice' in query
+ assert 'TO_NUMBER(oi."unitPrice") AS oi_unitPrice_toFloat' in query
+ assert "JOIN stage_1 ON stage_2.oi_unitPrice_toFloat = stage_1.maxUnitPrice" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_size_scalar_with_match_comparison():
+ query = _translate_sql(  # size(prop) compared to a carried avg(size(...)) alias
+ "MATCH (a:Answer) "
+ "WITH avg(size(a.body_markdown)) AS average_length "
+ "MATCH (a:Answer) "
+ "WHERE size(a.body_markdown) > average_length "
+ "RETURN count(a) AS answer_count"
+ )
+
+ assert 'a."body_markdown" AS a_body_markdown' in query
+ assert 'LENGTH(a."body_markdown") AS a_body_markdown_size' in query
+ assert "JOIN stage_1 ON stage_2.a_body_markdown_size > stage_1.average_length" in query
+ assert "SELECT COUNT(stage_2.a_VALUE) AS answer_count" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_scalar_with_match_property_map_correlation():
+ query = _translate_sql(  # carried scalar used inside a property map {year: releaseYear}
+ 'MATCH (m:Movie {title: "Open Season"}) '
+ "WITH m.year AS releaseYear "
+ "MATCH (movies:Movie {year: releaseYear}) "
+ "RETURN avg(movies.imdbRating) AS averageRating"
+ )
+
+ assert 'm."year" AS releaseYear' in query
+ assert 'movies."year" AS movies_year' in query
+ assert "JOIN stage_1 ON stage_2.movies_year = stage_1.releaseYear" in query
+ assert "year = releaseYear" not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_scalar_expression_correlation():
+ query = _translate_sql(  # carried alias inside an arithmetic expression: 2 * avg_amount
+ "MATCH (t:TRANSACTION)-[:TRIGGERED_ALERT]->(a:ALERT) "
+ "WITH avg(t.amount) AS avg_amount "
+ "MATCH (t2:TRANSACTION)-[:TRIGGERED_ALERT]->(a2:ALERT) "
+ "WHERE t2.amount > 2 * avg_amount "
+ "RETURN a2.alert_id, a2.severity_level, a2.fraud_probability_score"
+ )
+
+ assert "AVG(amount) AS avg_amount" in query
+ assert 't2."amount" AS t2_amount' in query
+ assert "JOIN stage_1 ON stage_2.t2_amount > 2 * stage_1.avg_amount" in query
+ assert "WHERE t2.amount > 2 * avg_amount" not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_aggregate_expression_alias():
+ query = _translate_sql(  # WITH alias defined as an aggregate expression: avg(...) * 2
+ "MATCH (t:TRANSACTION) "
+ "WITH avg(t.amount) * 2 AS threshold "
+ "MATCH (i:INVESTIGATION)-[:INVESTIGATES]->(al:ALERT)"
+ "<-[:TRIGGERED_ALERT]-(t:TRANSACTION) "
+ "WHERE t.amount > threshold "
+ "RETURN i.investigation_id, i.start_time, i.findings_summary"
+ )
+
+ assert "avg(t_amount) * 2 AS threshold" in query
+ assert 't."amount" AS t_amount' in query
+ assert "JOIN stage_1 ON stage_2.t_amount > stage_1.threshold" in query
+ assert "WHERE t.amount > threshold" not in query
+
+
+def test_cypher2oracle_sqlpgq_filters_with_match_on_stage_scalar_alias():
+ query = _translate_sql(  # post-MATCH WHERE references only the stage_1 alias
+ "MATCH (a:ASSET)-[:LOCATED_AT]->(l:LOCATION) "
+ "WITH l, SUM(l.capacity) AS total_capacity "
+ "MATCH (a:ASSET)-[:LOCATED_AT]->(l:LOCATION) "
+ "WHERE total_capacity > 1000 "
+ "RETURN a.asset_id, a.model, l.name"
+ )
+
+ assert "WHERE stage_1.total_capacity > 1000" in query
+ assert "WHERE total_capacity > 1000" not in query
+
+
+def test_cypher2oracle_sqlpgq_correlates_anonymous_property_map_scalar_alias():
+ query = _translate_sql(  # carried scalars in property maps, one on an anonymous node
+ "MATCH (u:User)-[:WROTE]->(r:Review)-[:REVIEWS]->(b:Business) "
+ "WITH u.name AS userName, b.city AS businessCity "
+ "MATCH (user:User {name: userName})-[:WROTE]->(:Review)"
+ "-[:REVIEWS]->(:Business {city: businessCity}) "
+ "RETURN DISTINCT userName"
+ )
+
+ assert 'with_corr_n1."city" AS with_corr_n1_city' in query
+ assert "stage_2.with_corr_n1_city = stage_1.businessCity" in query
+ assert 'city" = businessCity' not in query
+
+
+def test_cypher2oracle_sqlpgq_orders_with_aggregate_alias_expression():
+ query = _translate_sql(  # ORDER BY an expression over two WITH aggregate aliases
+ "MATCH (m:Movie)-[:IN_COLLECTION]->(c:Collection) "
+ "WITH c, max(m.release_date) AS maxDate, min(m.release_date) AS minDate "
+ "ORDER BY maxDate - minDate DESC LIMIT 5 "
+ "RETURN c.name AS CollectionName, minDate, maxDate, maxDate - minDate AS DateRange"
+ )
+
+ assert "ORDER BY maxDate - minDate DESC" in query
+ assert "maxDate - minDate AS maxDate_minDate" not in query
+ assert "GROUP BY c_VALUE, c_name, maxDate_minDate" not in query
+
+
+def test_cypher2oracle_sqlpgq_groups_by_hidden_aggregate_sort_property():
+ query = _translate_sql(  # ORDER BY property absent from RETURN: projected and grouped
+ "MATCH (fp:FINANCIAL_PERIOD)<-[:BelongsTo]-(t:TRANSACTION {status: 'Completed'}) "
+ "WHERE fp.start_date > date('2020-01-01') "
+ "RETURN fp.period_id, sum(t.amount) AS total_amount "
+ "ORDER BY fp.start_date"
+ )
+
+ assert 'fp."start_date" AS start_date' in query
+ assert "GROUP BY period_id, start_date" in query
+ assert "ORDER BY start_date" in query
+
+
+def test_cypher2oracle_sqlpgq_counts_carried_with_vertex_after_match():
+ query = _translate_sql(  # COUNT(t) where t is only carried by WITH, not re-matched
+ "MATCH (u:USER)-[:Initiates]->(t:TRANSACTION) "
+ "WITH u, t "
+ "MATCH (u)-[:Approves]->(:REPORT) "
+ "RETURN COUNT(t) AS total_transactions"
+ )
+
+ assert "VERTEX_ID(t) AS stage_1_t_VALUE" in query
+ assert "COUNT(stage_1.stage_1_t_VALUE) AS total_transactions" in query
+ assert "COUNT(t)" not in query
+
+
+def test_cypher2oracle_sqlpgq_aggregates_carried_with_vertex_property_after_match():
+ query = _translate_sql(  # AVG over a property of a carried, not re-matched vertex
+ "MATCH (c:CUSTOMER)-[:BELONGS_TO]->(a:ACCOUNT)-[:INITIATED_BY]->(t:TRANSACTION) "
+ "WITH c, a, t "
+ "MATCH (a)-[:USED_DEVICE]->(d:DEVICE) "
+ "WHERE d.risk_score > 0.8 "
+ "RETURN AVG(t.amount) AS average_transaction_amount"
+ )
+
+ assert 't."amount" AS t_amount' in query
+ assert "AVG(stage_1.t_amount) AS average_transaction_amount" in query
+ assert "AVG(amount)" not in query
+
+
+def test_cypher2oracle_sqlpgq_aggregates_carried_optional_vertex_property():
+ query = _translate_sql(  # OPTIONAL MATCH stage; AVG(d.risk_score) reads carried d
+ "MATCH (a:ACCOUNT)-[:USED_DEVICE]->(d:DEVICE) "
+ "WITH a, d "
+ "OPTIONAL MATCH (t:TRANSACTION)-[:INITIATED_BY]->(a) "
+ "RETURN SUM(t.amount) AS total_transaction_amount, "
+ "AVG(d.risk_score) AS average_device_risk_score"
+ )
+
+ assert 'd."risk_score" AS d_risk_score' in query
+ assert "AVG(stage_1.d_risk_score) AS average_device_risk_score" in query
+ assert "AVG(risk_score)" not in query
+
+
+def test_cypher2oracle_sqlpgq_cross_joins_uncorrelated_second_match_after_with():
+ query = _translate_sql(  # second MATCH uses u, which WITH did not carry: CROSS JOIN
+ "MATCH (ou:ORGANIZATION_UNIT)<-[:AssignedTo]-(u:USER)-[:Approves]->(b:BUDGET) "
+ "WHERE b.variance_threshold > 0.1 "
+ "WITH ou, AVG(b.amount) AS avg_budget_amount "
+ "MATCH (u)-[:Initiates]->(t:TRANSACTION)-[:GovernedBy]->(cr:COMPLIANCE_RULE) "
+ "WHERE cr.regulation_standard = 'SOX' "
+ "RETURN ou.name, avg_budget_amount"
+ )
+
+ assert "CROSS JOIN stage_1" in query
+ assert "1 AS dummy_value" in query
+ assert "stage_1.ou_name AS name" in query
+ assert "stage_1.avg_budget_amount AS avg_budget_amount" in query
+
+
+def test_cypher2oracle_sqlpgq_groups_by_resolved_alias_for_duplicate_property_names():
+ query = _translate_sql(  # two carried vertices return the same property (display_name)
+ "MATCH (follower:USER)-[:FOLLOWS]->(followee:USER)"
+ "-[:POSTS]->(t:TWEET {is_sensitive: true}) "
+ "WITH follower, followee, t "
+ "MATCH (follower)-[e:ENGAGES_WITH]->(t) "
+ "RETURN follower.display_name, followee.display_name, COUNT(e) AS engagement_count"
+ )
+
+ assert 'follower."display_name" AS follower' in query
+ assert 'followee."display_name" AS followee' in query
+ assert "SELECT stage_1.follower AS follower, stage_1.followee AS followee" in query
+ assert "GROUP BY stage_1.follower, stage_1.followee" in query
+ assert "GROUP BY display_name" not in query
+ assert 'follower."stage_1.display_name"' not in query
+ assert "stage_1.followee.display_name" not in query
+
+
+def test_cypher2oracle_sqlpgq_selects_second_stage_properties_after_with_match():
+ query = _translate_sql(  # RETURN/ORDER BY only second-stage properties after WITH
+ "MATCH (u:USER)-[:FOLLOWS]->(followed:USER) "
+ "WITH u, COUNT(followed) AS follows_count "
+ "WHERE follows_count >= 5 "
+ "MATCH (u)-[:POSTS]->(t:TWEET)-[:ATTACHES_MEDIA]->(:MEDIA_ATTACHMENT) "
+ "RETURN t.content, t.view_count ORDER BY t.view_count DESC LIMIT 20"
+ )
+
+ assert "SELECT stage_2.content AS content, stage_2.view_count AS view_count" in query
+ assert "ORDER BY stage_2.t_view_count DESC" in query
+ assert 't."content"' not in query.split("SELECT stage_2.content AS content", 1)[1]  # outer SELECT tail only
+ assert 't."view_count"' not in query.split("SELECT stage_2.content AS content", 1)[1]
+
+
+def test_cypher2oracle_sqlpgq_groups_second_stage_property_after_distinct_with():
+ query = _translate_sql(  # group key and aggregate both come from the second stage
+ "MATCH (u:USER {verified_status: true})-[:MEMBER_OF_LIST]->(l:LIST {is_private: true}) "
+ "WITH DISTINCT u "
+ "MATCH (u)-[:POSTS]->(t:TWEET) "
+ "RETURN t.language, AVG(t.view_count) AS avg_view_count"
+ )
+
+ assert (
+ "SELECT stage_2.language AS language, AVG(stage_2.t_view_count) AS avg_view_count"
+ in query
+ )
+ assert "GROUP BY stage_2.language" in query
+ assert 't."language"' not in query.split("SELECT stage_2.language AS language", 1)[1]  # outer SELECT tail only
+
+
+def test_cypher2oracle_sqlpgq_computes_with_scalar_expression_in_outer_select():
+ query = _translate_sql(  # carried property + stage_1 aggregate alias in outer SELECT
+ "MATCH (u:USER)-[:POSTS]->(t:TWEET) "
+ "WITH u, SUM(t.like_count + t.retweet_count + t.reply_count) AS engagement_sum "
+ "MATCH (u) "
+ "RETURN u.display_name, u.followers_count + engagement_sum AS influence_score "
+ "ORDER BY influence_score DESC LIMIT 10"
+ )
+
+ assert 'u."followers_count" + engagement_sum' not in query
+ assert "stage_1.u_followers_count + stage_1.engagement_sum AS influence_score" in query
+ assert "ORDER BY influence_score DESC" in query
+
+
+def test_cypher2oracle_sqlpgq_rejects_unprojected_aggregate_in_with_where():
+ query, category = cypher2oracle_sqlpgq(  # aggregate used in WITH ... WHERE, never projected
+ "MATCH (u:USER)-[:POSTS]->(t:TWEET) "
+ "WITH u WHERE u.followers_count > 1000 AND AVG(t.view_count) > 100 "
+ "MATCH (u)-[:POSTS]->(t:TWEET) "
+ "RETURN AVG(t.view_count) AS avg_view_count"
+ )
+
+ assert query == "Unable to Translate to Oracle SQL/PGQ"
+ assert category == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_rejects_aggregate_in_match_where():
+ query, category = cypher2oracle_sqlpgq(  # aggregate directly inside a MATCH ... WHERE
+ "MATCH (t:TRANSACTION)-[:INITIATED_BY]->(a:ACCOUNT) "
+ "WHERE a.account_id = 'ACC000000' AND t.amount > avg(t.amount) "
+ "RETURN t"
+ )
+
+ assert query == "Unable to Translate to Oracle SQL/PGQ"
+ assert category == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_rejects_duration_between_aggregate():
+ query, category = cypher2oracle_sqlpgq(  # AVG over duration.between(...) must be rejected
+ "MATCH (a:ACCOUNT)-[ud:USED_DEVICE]->(d:DEVICE) "
+ "RETURN AVG(duration.between(ud.session_start, ud.session_end)) "
+ "AS avg_session_duration"
+ )
+
+ assert query == "Unable to Translate to Oracle SQL/PGQ"
+ assert category == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_rejects_temporal_arithmetic_aggregate_without_cast():
+ query, category = cypher2oracle_sqlpgq(  # AVG over TIMESTAMP subtraction without a cast
+ "MATCH (a:ACCOUNT)-[u:USED_DEVICE]->(d:DEVICE) "
+ "RETURN AVG(u.session_end - u.session_start) AS avg_session_duration",
+ property_type_map={
+ "USED_DEVICE": {
+ "session_start": "TIMESTAMP",
+ "session_end": "TIMESTAMP",
+ }
+ },
+ )
+
+ assert query == "Unable to Translate to Oracle SQL/PGQ"
+ assert category == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_orders_with_match_by_staged_property_alias():
+ query = _translate_sql(  # ORDER BY p2.unitPrice must use the staged alias p2_unitPrice
+ "MATCH (p:Product) WITH avg(p.unitPrice) AS avgPrice "
+ "MATCH (p2:Product) WHERE p2.unitPrice < avgPrice "
+ "RETURN p2 ORDER BY p2.unitPrice LIMIT 5"
+ )
+
+ assert 'p2."unitPrice" AS p2_unitPrice' in query
+ assert "ORDER BY stage_2.p2_unitPrice" in query
+ assert "ORDER BY unitPrice" not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_cross_stage_vertex_inequality():
+ query = _translate_sql(  # toyStory <> otherMovie spans stages: compared via vertex ids
+ 'MATCH (toyStory:Movie {title: "Toy Story"})-[:IN_GENRE]->(genre:Genre) '
+ "WITH toyStory, genre "
+ "MATCH (genre)<-[:IN_GENRE]-(otherMovie:Movie) "
+ "WHERE toyStory <> otherMovie "
+ "RETURN DISTINCT otherMovie.title"
+ )
+
+ assert "VERTEX_ID(otherMovie) AS otherMovie_VALUE" in query
+ assert "stage_2.otherMovie_VALUE <> stage_1.stage_1_toyStory_VALUE" in query
+ assert "WHERE toyStory <> otherMovie" not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_scalar_with_split_count_projection():
+ query = _translate_sql(  # SIZE(SPLIT(prop, ',')) -> REGEXP_COUNT + 1 inside a NULL CASE
+ "MATCH (p:Policy {review_status: 'Approved'})"
+ "-[:Enforces]->(r:Role {is_compliant: true}) "
+ "WITH r, SIZE(SPLIT(r.permissions_list, ',')) AS permission_count "
+ "RETURN r.name AS role_name, r.description AS role_description, permission_count"
+ )
+
+ assert 'CASE WHEN r."permissions_list" IS NULL' in query
+ assert "REGEXP_COUNT(r.\"permissions_list\", ',') + 1 END AS permission_count" in query
+ assert "SELECT r_name AS role_name, r_description AS role_description" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_scalar_function_with_final_aggregate():
+ query = _translate_sql(  # scalar size() alias in WITH, aggregated with sum() in RETURN
+ "MATCH (m:Movie)-[:IN_GENRE]->(g:Genre) "
+ "WITH g.name AS genre, size(m.languages) AS languageCount "
+ "RETURN genre, sum(languageCount) AS totalLanguages "
+ "ORDER BY totalLanguages DESC LIMIT 5"
+ )
+
+ assert 'LENGTH(m."languages") AS languageCount' in query
+ assert "SELECT genre AS genre, COALESCE(SUM(languageCount), 0) AS totalLanguages" in query
+ assert "GROUP BY genre" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_string_size_to_length():
+ query = _translate_sql("MATCH (q:Question) RETURN size(q.text) AS textLength")
+
+ assert 'LENGTH(q."text") AS textLength' in query  # size(prop) is expected as LENGTH(prop)
+
+
+def test_cypher2oracle_sqlpgq_translates_id_order_comparison():
+ query = _translate_sql(  # id(a) < id(b) in WHERE -> VERTEX_ID(a) < VERTEX_ID(b)
+ "MATCH (m1:Movie)<-[:ACTED_IN]-(a:Actor)-[:ACTED_IN]->(m2:Movie) "
+ "WHERE id(m1) < id(m2) "
+ "RETURN m1.title, m2.title"
+ )
+
+ assert "WHERE VERTEX_ID(m1) < VERTEX_ID(m2)" in query
+ assert "where id(" not in query.lower()  # lowercase needle: an upper-case one can never match lower()
+
+
+def test_cypher2oracle_sqlpgq_rewrites_with_filter_element_equality_to_projected_ids():
+ query = _translate_sql(  # WITH DISTINCT ... WHERE a = b compares the projected ids
+ "MATCH (neo4j:User {screen_name: 'neo4j'}) "
+ "MATCH (neo4j)-[:FOLLOWS]->(followed:User) "
+ "MATCH (mentioned:User)-[:POSTS]->(:Tweet)-[:MENTIONS]->(neo4j) "
+ "WITH DISTINCT followed, mentioned "
+ "WHERE followed = mentioned "
+ "RETURN followed.screen_name AS users"
+ )
+
+ assert "WHERE followed_VALUE = mentioned_VALUE" in query
+ assert "VERTEX_EQUAL(followed, mentioned)" not in query
+
+
+def test_cypher2oracle_sqlpgq_uses_second_stage_property_alias_when_final_alias_differs():
+ query = _translate_sql(  # poi.name vs aa1.name: stage_2 column is name, final alias poi
+ "MATCH (aa1:AdministrativeArea)-[:Borders]->(aa2:AdministrativeArea) "
+ "WITH aa1, COUNT(aa2) AS border_count "
+ "WHERE border_count >= 3 "
+ "MATCH (aa1)-[:ContainsPoi]->(poi:PointOfInterest) "
+ "RETURN poi.name, poi.category, aa1.name"
+ )
+
+ assert 'poi."name" AS name' in query
+ assert "SELECT stage_2.name AS poi" in query
+ assert "stage_2.poi" not in query
+
+
+def test_cypher2oracle_sqlpgq_uses_second_stage_aliases_for_renamed_properties():
+ query = _translate_sql(  # cross-stage department <> comparison with renamed RETURN aliases
+ "MATCH (u:USER)-[:Approves]->(r1:REPORT)-[:ConsolidatesInto]->(r2:REPORT) "
+ "WITH u, r1, r2 "
+ "MATCH (u2:USER)-[:Approves]->(r2) "
+ "WHERE u.department <> u2.department "
+ "RETURN u.user_id, u.department, u2.user_id AS consolidated_approver_id, "
+ "u2.department AS consolidated_department"
+ )
+
+ assert 'u2."user_id" AS u2_user_id' in query
+ assert 'u2."department" AS u2_department' in query
+ assert "AS consolidated_approver_id" in query
+ assert "AS consolidated_department" in query
+ assert "stage_2.u2_department <> stage_1.u_department" in query
+ assert "stage_2.u2_department = stage_1.u_department" not in query
+
+
+def test_cypher2oracle_sqlpgq_uses_distinct_second_stage_aliases_for_with_match_aggregates():
+ query = _translate_sql(  # b.amount and t.amount need distinct stage_2 aliases for the SUMs
+ "MATCH (ou:ORGANIZATION_UNIT)<-[:AssignedTo]-(u:USER)-[:Initiates]->"
+ "(t:TRANSACTION)-[:BelongsTo]->(fp:FINANCIAL_PERIOD) "
+ "WITH ou, fp, MAX(fp.start_date) AS latest_start_date "
+ "MATCH (ou)<-[:AssignedTo]-(u:USER)-[:Initiates]->(t:TRANSACTION)-[:BelongsTo]->"
+ "(fp:FINANCIAL_PERIOD {start_date: latest_start_date})<-[:BelongsTo]-"
+ "(acc:ACCOUNT)<-[:AllocatedTo]-(b:BUDGET) "
+ "RETURN ou.name, SUM(b.amount) AS total_budget, "
+ "SUM(t.amount) AS total_transactions"
+ )
+
+ assert 'b."amount" AS b_amount' in query
+ assert 't."amount" AS t_amount' in query
+ assert "COALESCE(SUM(stage_2.b_amount), 0) AS total_budget" in query
+ assert "COALESCE(SUM(stage_2.t_amount), 0) AS total_transactions" in query
+ assert "SUM(amount)" not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_left_tostring_and_safe_numeric_division():
+ left_query = _translate_sql(  # case 1: left(x, 1) -> SUBSTR(x, 1, 1)
+ "MATCH (p1:Person)-[:FOLLOWS]->(p2:Person) "
+ "WHERE left(p1.name, 1) = left(p2.name, 1) "
+ "RETURN p1.name LIMIT 3"
+ )
+ assert 'SUBSTR(p1."name", 1, 1) = SUBSTR(p2."name", 1, 1)' in left_query
+ assert "left(" not in left_query.lower()
+
+ tostring_query = _translate_sql(  # case 2: 'v' + TOSTRING(x) -> 'v' || TO_CHAR(x)
+ "MATCH (dd:DataDomain)<-[:BelongsTo]-(da:DataAsset) "
+ "WITH dd, AVG(TOFLOAT(REPLACE(da.schema_version, 'v', ''))) AS avg_schema_version "
+ "WHERE avg_schema_version > 2.0 "
+ "RETURN dd.name AS DomainName, 'v' + TOSTRING(avg_schema_version) AS AverageSchemaVersion"
+ )
+ assert "'v' || TO_CHAR(avg_schema_version) AS AverageSchemaVersion" in tostring_query
+ assert "TOSTRING" not in tostring_query
+
+ division_query = _translate_sql(  # case 3: divisor toFloat(x) is wrapped in NULLIF(..., 0)
+ "MATCH (u:User) "
+ "WITH u, u.following / toFloat(u.followers) AS ratio "
+ "WHERE ratio > 2 "
+ "RETURN u.screen_name, ratio"
+ )
+ assert 'u."following" / NULLIF(TO_NUMBER(u."followers"), 0) AS ratio' in division_query
+
+
+def test_cypher2oracle_sqlpgq_translates_map_comparison_property_literals():
+ query = _translate_sql("MATCH ()-[:INTERACTS45 {weight: {lt: 10}}]->() RETURN count(*)")
+
+ assert 'e1."weight" < 10' in query  # {lt: 10} becomes a < predicate on the edge named e1
+ assert "{lt:" not in query
+
+
+def test_cypher2oracle_sqlpgq_translates_apoc_toset_collect_size_as_count_distinct():
+ query = _translate_sql(  # size(apoc.coll.toSet(collect(x))) -> COUNT(DISTINCT x)
+ "MATCH (p:Person)-[r:ACTED_IN]->(m:Movie) "
+ "WITH p, size(apoc.coll.toSet(collect(r.roles))) AS roleDiversity "
+ "RETURN p.name AS actor, roleDiversity ORDER BY roleDiversity DESC LIMIT 3"
+ )
+
+ assert "COUNT(DISTINCT r_roles) AS roleDiversity" in query
+ assert "apoc.coll.toSet" not in query
+
+
+def test_cypher2oracle_sqlpgq_auto_edge_names_avoid_future_node_names():
+ query = _translate_sql(  # anonymous edges get e3/e4 because e1/e2 name nodes here
+ "MATCH (e1:Employee)-[:REQUESTS]->(a:Asset)<-[:REQUESTS]-(e2:Employee) "
+ "WHERE e1.name = 'Donald Schultz' AND e2.name <> 'Donald Schultz' "
+ "RETURN DISTINCT e2.name"
+ )
+
+ assert (
+ '(e1 IS "Employee")-[e3 IS "REQUESTS"]->'
+ '(a IS "Asset")<-[e4 IS "REQUESTS"]-(e2 IS "Employee")' in query
+ )
+
+
+def test_cypher2oracle_sqlpgq_qualifies_carried_with_match_return_variables():
+ query = _translate_sql(  # RETURN DISTINCT c where c is carried but not re-matched
+ "MATCH (c:Customer)-[:Initiates]->(pt:PaymentTransaction)-[:ProcessedFor]->(m:Merchant) "
+ "WHERE m.category_code = 'Retail' "
+ "WITH c, pt "
+ "MATCH (pt)-[:HasRiskAssessment]->(ra:RiskAssessment) "
+ "WHERE ra.score > 70 "
+ "RETURN DISTINCT c LIMIT 10"
+ )
+
+ assert "SELECT DISTINCT stage_1.c_VALUE AS c" in query
+ assert "\nSELECT DISTINCT c\n" not in query
+
+
+def test_cypher2oracle_sqlpgq_orders_with_stage_by_expression_aliases():
+ query = _translate_sql(  # ORDER BY movie_count + video_count must stay an expression
+ "MATCH (p:Person)-[:CAST_FOR]->(m:Movie) "
+ "MATCH (p)-[:CAST_FOR]->(v:Video) "
+ "WITH p, COUNT(DISTINCT m) AS movie_count, COUNT(DISTINCT v) AS video_count "
+ "WHERE movie_count > 0 AND video_count > 0 "
+ "RETURN p.name AS actor, movie_count, video_count "
+ "ORDER BY movie_count + video_count DESC LIMIT 3"
+ )
+
+ assert "ORDER BY movie_count + video_count DESC" in query
+ assert "movie_count_video_count" not in query
+
+
+def test_cypher2oracle_sqlpgq_groups_direct_with_projection_final_aggregate():
+ query = _translate_sql(  # plain WITH projections, then an aggregating RETURN grouped by city
+ "MATCH (u:User)-[:WROTE]->(r:Review)-[:REVIEWS]->(b:Business) "
+ "WITH b.city AS city, r.stars AS stars "
+ "RETURN city, avg(stars) AS averageRating"
+ )
+
+ assert "SELECT city AS city, AVG(stars) AS averageRating" in query
+ assert "GROUP BY city" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_complex_optional_aggregate_expression():
+ query = _translate_sql(  # property + count() in one RETURN item after OPTIONAL MATCH
+ "MATCH (t:Tweet) "
+ "OPTIONAL MATCH (t)-[:RETWEETS]->(r:Tweet) "
+ "RETURN t, t.favorites + count(r) AS score ORDER BY score DESC LIMIT 3"
+ )
+
+ assert "stage_1.t_favorites + count(stage_2.r_VALUE) AS score" in query
+ assert 't."favorites" + count(r)' not in query
+ assert "GROUP BY stage_1.stage_1_t_VALUE, stage_1.t_favorites" in query
+
+
+def test_cypher2oracle_sqlpgq_translates_complex_with_match_count_expression():
+ query = _translate_sql(  # COUNT(DISTINCT m1) + COUNT(DISTINCT m2) spans both stages
+ "MATCH (pc:ProductionCompany)<-[:PRODUCED_BY]-(m1:Movie)"
+ "-[:DIRECTED_BY]->(d:Director {name: 'Director 15'}) "
+ "WITH pc, m1 "
+ "MATCH (pc)<-[:PRODUCED_BY]-(m2:Movie)-[:BELONGS_TO]->(g:Genre {name: 'Horror'}) "
+ "RETURN pc.name AS ProductionCompany, COUNT(DISTINCT m1) + "
+ "COUNT(DISTINCT m2) AS TotalDistinctMovies"
+ )
+
+ assert "VERTEX_ID(m2) AS m2_VALUE" in query
+ assert (
+ "COUNT(DISTINCT stage_1.stage_1_m1_VALUE) + "
+ "COUNT(DISTINCT stage_2.m2_VALUE) AS TotalDistinctMovies" in query
+ )
+
+
+def test_cypher2oracle_sqlpgq_projects_properties_from_aliased_with_vertex():
+    """Properties of a vertex renamed in WITH are projected from the original variable."""
+    sql = _translate_sql(
+        "MATCH (c1:Character) WHERE c1.community = 735 "
+        "MATCH (c1)--(c2:Character) "
+        "WITH DISTINCT c2 AS character, c2.book1PageRank AS pageRank "
+        "ORDER BY pageRank DESC LIMIT 10 "
+        "RETURN character.name AS characterName, pageRank"
+    )
+    assert 'c2."name" AS character_name' in sql
+    assert "SELECT character_name AS characterName, pageRank AS pageRank" in sql
+    assert "character.name" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_passthrough_with_final_aggregate():
+    """A pass-through WITH followed by an aggregating RETURN groups by the carried property."""
+    sql = _translate_sql(
+        "MATCH (g:Group)<-[:BelongsTo]-(u:User {department: 'Finance'})"
+        "-[:AttemptsAccess]->(:AccessEvent)-[:Targets]->(:Resource {type: 'Database'}) "
+        "WITH g, u "
+        "RETURN g.name AS GroupName, COUNT(DISTINCT u) AS DistinctUsersCount"
+    )
+    assert "VERTEX_ID(g) AS g_VALUE" in sql
+    assert "VERTEX_ID(u) AS u_VALUE" in sql
+    assert "SELECT g_name AS GroupName, COUNT(DISTINCT u_VALUE) AS DistinctUsersCount" in sql
+    assert "GROUP BY g_name" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_two_stage_aggregate_with():
+    """Two chained aggregating WITH clauses become stage_1 and stage_2, each with GROUP BY."""
+    sql = _translate_sql(
+        "MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent)-[:Targets]->(r:Resource) "
+        "WITH u.department AS dept, u.User_id AS userId, "
+        "COUNT(DISTINCT r.Resource_id) AS resourceCount "
+        "WITH dept, AVG(resourceCount) AS avgResourcesPerUser "
+        "ORDER BY avgResourcesPerUser DESC "
+        "RETURN dept, avgResourcesPerUser"
+    )
+    assert "WITH stage_1 AS" in sql
+    assert 'u."department" AS dept' in sql
+    assert 'r."Resource_id" AS r_Resource_id' in sql
+    assert "COUNT(DISTINCT r_Resource_id) AS resourceCount" in sql
+    assert "GROUP BY dept, userId" in sql
+    assert "stage_2 AS" in sql
+    assert "AVG(resourceCount) AS avgResourcesPerUser" in sql
+    assert "GROUP BY dept" in sql
+    assert "ORDER BY avgResourcesPerUser DESC" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_two_stage_with_filter_before_second_aggregate():
+    """A WHERE between two aggregating WITH clauses survives as a filter on the alias."""
+    sql = _translate_sql(
+        "MATCH (u:User)-[:AssignedTempAccess]->(ta:TemporaryAccess) "
+        "WITH u.department AS Department, u.User_id AS UserID, "
+        "COUNT(ta) AS TempAccessCount "
+        "WHERE Department IS NOT NULL "
+        "WITH Department, AVG(TempAccessCount) AS AvgTempAccesses "
+        "RETURN Department, AvgTempAccesses"
+    )
+    assert "WHERE Department IS NOT NULL" in sql
+    assert "COUNT(ta_VALUE) AS TempAccessCount" in sql
+    assert "GROUP BY Department, UserID" in sql
+    assert "AVG(TempAccessCount) AS AvgTempAccesses" in sql
+    assert "GROUP BY Department" in sql
+    assert "SELECT Department AS Department, AvgTempAccesses AS AvgTempAccesses" in sql
+
+
+def test_cypher2oracle_sqlpgq_groups_two_stage_final_aggregate_return():
+    """A RETURN aggregate after two non-aggregating WITH clauses groups by the property."""
+    sql = _translate_sql(
+        "MATCH (m:Movie)-[:IN_GENRE]->(g:Genre) "
+        "WITH g, m.languages AS languages "
+        "WITH g, size(languages) AS languageCount "
+        "RETURN g.name AS genre, avg(languageCount) AS avgLanguages "
+        "ORDER BY avgLanguages DESC LIMIT 5"
+    )
+    assert "SELECT g_name AS genre, AVG(languageCount) AS avgLanguages" in sql
+    assert "GROUP BY g_name" in sql
+    assert "ORDER BY avgLanguages DESC" in sql
+
+
+def test_cypher2oracle_sqlpgq_keeps_two_stage_aggregate_function_argument_expression():
+    """avg(toFloat(x)) keeps TO_NUMBER(x) as its argument and emits exactly one GROUP BY."""
+    sql = _translate_sql(
+        "MATCH (s:Supplier)-[:SUPPLIES]->(p:Product)<-[:ORDERS]-(o:Order) "
+        "WITH s, o.freight AS freight WHERE freight IS NOT NULL "
+        "WITH s, avg(toFloat(freight)) AS avgFreight "
+        "ORDER BY avgFreight DESC LIMIT 3 "
+        "RETURN s.companyName AS supplierName, avgFreight"
+    )
+    assert "avg(TO_NUMBER(freight)) AS avgFreight" in sql
+    assert sql.count("GROUP BY s_VALUE, s_companyName") == 1
+    assert "TO_NUMBER_freight_" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_size_collect_distinct_to_count():
+    """size(collect(DISTINCT prop)) is translated to COUNT(DISTINCT ...), not LENGTH."""
+    sql = _translate_sql(
+        "MATCH (d:Director)-[:DIRECTED]->(m:Movie) "
+        "WITH d, size(collect(distinct m.countries)) AS numCountries "
+        "WHERE numCountries > 3 "
+        "RETURN d.name AS director, numCountries "
+        "ORDER BY numCountries DESC LIMIT 5"
+    )
+    assert "COUNT(DISTINCT m_countries) AS numCountries" in sql
+    assert 'm."countries" AS m_countries' in sql
+    assert "LENGTH(collect" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_size_collect_element_to_count():
+    """size(collect(vertex)) is translated to COUNT over the vertex id column."""
+    sql = _translate_sql(
+        "MATCH (d:Director)-[:DIRECTED]->(m:Movie) "
+        "WHERE m.imdbRating > 8.0 "
+        "WITH d, size(collect(m)) AS moviesDirected "
+        "ORDER BY moviesDirected DESC "
+        "RETURN d.name, moviesDirected LIMIT 1"
+    )
+    assert "COUNT(m_VALUE) AS moviesDirected" in sql
+    assert "VERTEX_ID(m) AS m_VALUE" in sql
+    assert "LENGTH(collect" not in sql
+
+
+def test_cypher2oracle_sqlpgq_carries_two_stage_with_vertex_properties_to_final_return():
+    """A vertex property needed by the final RETURN is carried through both WITH stages."""
+    sql = _translate_sql(
+        "MATCH (p:Policy)-[:Enforces]->(r:Role)-[:GrantsAccessTo]->(res:Resource) "
+        "WITH p, r, COUNT(res) AS resource_count "
+        "WHERE resource_count > 3 "
+        "WITH p, COUNT(r) AS role_count "
+        "WHERE role_count >= 2 "
+        "RETURN DISTINCT p.name"
+    )
+    assert 'p."name" AS p_name' in sql
+    assert "COUNT(r_VALUE) AS role_count" in sql
+    assert "GROUP BY p_VALUE, p_name" in sql
+    assert "SELECT DISTINCT p_name AS name" in sql
+    assert 'p."name" AS name' not in sql
+
+
+def test_cypher2oracle_sqlpgq_carries_two_stage_with_aliased_final_properties():
+    """An aliased final property is read from the carried g_name column."""
+    sql = _translate_sql(
+        "MATCH (g:Group)<-[:BelongsTo]-(u:User)-[:AttemptsAccess]->(ae:AccessEvent) "
+        "WITH g, u, COUNT(ae) AS user_access_count "
+        "WITH g, AVG(user_access_count) AS avg_access_per_user "
+        "RETURN g.name AS GroupName, avg_access_per_user AS AverageAccessPerUser"
+    )
+    assert 'g."name" AS g_name' in sql
+    assert "AVG(user_access_count) AS avg_access_per_user" in sql
+    assert "GROUP BY g_VALUE, g_name" in sql
+    assert "SELECT g_name AS GroupName" in sql
+    assert 'g."name" AS GroupName' not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_not_exists_pattern_predicate():
+    """NOT EXISTS over a pattern becomes a NOT EXISTS check tied to base on u_VALUE."""
+    sql = _translate_sql(
+        "MATCH (u:User) "
+        "WHERE NOT EXISTS((u)-[:AttemptsAccess]->(:AccessEvent)) "
+        "RETURN u.name, u.email"
+    )
+    assert "WITH base AS" in sql
+    assert 'MATCH (u IS "User")' in sql
+    assert "SELECT name, email" in sql
+    assert "WHERE NOT EXISTS" in sql
+    assert 'MATCH (u)-[e1 IS "AttemptsAccess"]->(n1 IS "AccessEvent")' in sql
+    assert "pp.u_VALUE = base.u_VALUE" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_pattern_predicate_inline_map_as_filter():
+    """An inline property map inside a pattern predicate becomes an equality filter."""
+    sql = _translate_sql(
+        "MATCH (d:DIRECTOR) "
+        "WHERE NOT EXISTS((d)<-[:DIRECTED_BY]-(m:MOVIE)"
+        "-[:BELONGS_TO]->(:GENRE {name: 'Horror'})) "
+        "RETURN d.name"
+    )
+    assert 'IS "GENRE"' in sql
+    assert "n1.\"name\" = 'Horror'" in sql
+    assert "GENRE_name_Horror" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_exists_pattern_predicate_before_aggregate():
+    """EXISTS over a two-ended pattern is tied to base on both endpoints before COUNT(*)."""
+    sql = _translate_sql(
+        "MATCH (a1:Assertion)-[:Contradicts]->(a2:Assertion) "
+        "WHERE EXISTS((a1)-[:ClassifiedUnder]->(:Domain)<-[:ClassifiedUnder]-(a2)) "
+        "RETURN COUNT(*) AS contradiction_count"
+    )
+    assert "SELECT COUNT(*) AS contradiction_count" in sql
+    assert "WHERE EXISTS" in sql
+    assert 'MATCH (a1)-[e1 IS "ClassifiedUnder"]->(n1 IS "Domain")' in sql
+    assert '<-[e2 IS "ClassifiedUnder"]-(a2)' in sql
+    assert "pp.a1_VALUE = base.a1_VALUE" in sql
+    assert "pp.a2_VALUE = base.a2_VALUE" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_incoming_exists_pattern_predicate_in_and():
+    """ANDed EXISTS predicates on incoming edges are emitted as patterns ending at q."""
+    sql = _translate_sql(
+        "MATCH (q:Question)-[:TAGGED]->(t:Tag {name: 'neo4j'}) "
+        "WHERE EXISTS((q)<-[:ANSWERED]-(:Answer)) "
+        "AND EXISTS((q)<-[:COMMENTED_ON]-(:Comment)) "
+        "RETURN q.title"
+    )
+    assert "WITH base AS" in sql
+    assert "WHERE EXISTS" in sql
+    assert 'MATCH (q IS "Question")-[e1 IS "TAGGED"]->(t IS "Tag")' in sql
+    assert 'MATCH (n1 IS "Answer")-[e1 IS "ANSWERED"]->(q)' in sql
+    assert 'MATCH (n1 IS "Comment")-[e1 IS "COMMENTED_ON"]->(q)' in sql
+    assert "pp.q_VALUE = base.q_VALUE" in sql
+    assert "EXISTS((" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_lowercase_exists_pattern_predicate_in_and():
+    """Lowercase exists(...) is recognised alongside an ordinary property comparison."""
+    sql = _translate_sql(
+        "MATCH (m:Movie) "
+        "WHERE m.released < '2000-01-01' AND exists((m)-[:IN_GENRE]->(:Genre)) "
+        "RETURN m.title, m.imdbRating ORDER BY m.imdbRating DESC LIMIT 3"
+    )
+    assert "WITH base AS" in sql
+    assert "m.\"released\" < '2000-01-01'" in sql
+    assert "WHERE EXISTS" in sql
+    assert 'MATCH (m)-[e1 IS "IN_GENRE"]->(n1 IS "Genre")' in sql
+    assert "exists((" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_raw_path_predicates_before_with_aggregate():
+    """Bare path predicates before an aggregating WITH become joined predicate_N stages."""
+    sql = _translate_sql(
+        "MATCH (p)--(m:Movie) "
+        "WHERE (p)-[:DIRECTED]->(m) AND (p)-[:ACTED_IN]->(m) "
+        "WITH avg(m.budget) AS average_budget "
+        "RETURN average_budget"
+    )
+    assert "WITH base AS" in sql
+    assert "stage_1 AS" in sql
+    assert "predicate_1 AS" in sql
+    assert "predicate_2 AS" in sql
+    assert "JOIN predicate_1 ON predicate_1.m_VALUE = base.m_VALUE" in sql
+    assert "JOIN predicate_2 ON predicate_2.m_VALUE = base.m_VALUE" in sql
+    assert 'MATCH (p)-[e1 IS "DIRECTED"]->(m)' in sql
+    assert 'MATCH (p)-[e1 IS "ACTED_IN"]->(m)' in sql
+    assert "WHERE EXISTS" not in sql
+    assert "[:DIRECTED]" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_return_exists_pattern_predicate():
+    """EXISTS(...) in RETURN is rendered as CASE WHEN EXISTS ... THEN 1 ELSE 0 END."""
+    sql = _translate_sql(
+        "MATCH (p:Policy)-[:Enforces]->(r:Role)-[:GrantsAccessTo]->(res:Resource) "
+        "WHERE res.sensitivity_level = 'Confidential' "
+        "RETURN EXISTS((p)-[:Enforces]->(r)-[:GrantsAccessTo]->(res)) AS policyExists"
+    )
+    assert "SELECT CASE WHEN EXISTS" in sql
+    assert "THEN 1 ELSE 0 END AS policyExists" in sql
+    assert "pp.p_VALUE = base.p_VALUE" in sql
+    assert "pp.r_VALUE = base.r_VALUE" in sql
+    assert "pp.res_VALUE = base.res_VALUE" in sql
+
+
+def test_cypher2oracle_sqlpgq_avoids_auto_edge_variable_collision():
+    """Generated edge names must not clash with a user-declared e1 in a path pattern."""
+    sql = _translate_sql(
+        "MATCH p = (n1:USER)-[e1:POSTS]-(x)-[]-(n2:USER) "
+        "WHERE n1.username = 'robert00' "
+        "RETURN p LIMIT 1"
+    )
+    assert "EDGE_ID(e1) AS p_e1_ID" in sql
+    assert "EDGE_ID(e2) AS p_e2_ID" in sql
+    assert sql.count("AS p_e1_ID") == 1
+
+
+def test_cypher2oracle_sqlpgq_translates_with_match_then_aggregate_with():
+    """WITH, MATCH, then an aggregating WITH yields a stage_3 and joins stage_2 to stage_1."""
+    sql = _translate_sql(
+        "MATCH (u:User)-[:BelongsTo]->(g:Group) "
+        "WHERE g.member_count > 10 "
+        "WITH u, g "
+        "MATCH (u)-[:HasRole]->(r:Role)-[:GrantsAccessTo]->(res:Resource) "
+        "WHERE res.sensitivity_level = 'Confidential' "
+        "WITH u, g, COUNT(DISTINCT res) AS confidential_resource_count "
+        "WHERE confidential_resource_count >= 1 "
+        "RETURN u.name, g.name, confidential_resource_count"
+    )
+    assert "stage_3 AS" in sql
+    assert "JOIN stage_1 ON stage_2.u_VALUE = stage_1.stage_1_u_VALUE" in sql
+    assert "COUNT(DISTINCT stage_2.res_VALUE) AS confidential_resource_count" in sql
+    assert "WHERE confidential_resource_count >= 1" in sql
+    assert "SELECT u_name AS name, g_name AS name" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_uncorrelated_with_match_aggregate_cross_join():
+    """A second MATCH that reuses no carried variable is combined via CROSS JOIN stage_1."""
+    sql = _translate_sql(
+        "MATCH (u:User)-[:BelongsTo]->(g:Group) "
+        "WITH g, COUNT(u) AS member_count "
+        "MATCH (u:User)-[:AttemptsAccess]->(:AccessEvent)-[:Targets]->(r:Resource) "
+        "WHERE r.sensitivity_level = 'high' "
+        "RETURN AVG(member_count) AS average_member_count"
+    )
+    assert "CROSS JOIN stage_1" in sql
+    assert "COLUMNS (1 AS dummy_value)" in sql
+    assert "AVG(stage_1.member_count) AS average_member_count" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_standalone_optional_match_null_fallback():
+    """A leading OPTIONAL MATCH falls back to one all-NULL row when nothing matches."""
+    sql = _translate_sql(
+        "OPTIONAL MATCH (p:Policy)-[:Enforces]->(r:Role) "
+        "RETURN p.name AS policy_name, r.name AS role_name"
+    )
+    assert sql.startswith("WITH optional_rows AS")
+    assert 'MATCH (p IS "Policy")-[e1 IS "Enforces"]->(r IS "Role")' in sql
+    assert "UNION ALL" in sql
+    assert "SELECT NULL AS policy_name, NULL AS role_name" in sql
+    assert "WHERE NOT EXISTS (SELECT 1 FROM optional_rows)" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_standalone_optional_match_count_fallback():
+    """With COUNT, the standalone OPTIONAL MATCH fallback row reports 0 instead of NULL."""
+    sql = _translate_sql(
+        "OPTIONAL MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent) "
+        "RETURN u.name AS UserName, COUNT(ae) AS AccessEventCount "
+        "ORDER BY AccessEventCount DESC"
+    )
+    assert sql.startswith("WITH optional_rows AS")
+    assert "COUNT(ae_VALUE) AS AccessEventCount" in sql
+    assert "GROUP BY UserName" in sql
+    assert "SELECT NULL AS UserName, 0 AS AccessEventCount" in sql
+    assert sql.strip().endswith("ORDER BY AccessEventCount DESC")
+
+
+def test_cypher2oracle_sqlpgq_rejects_standalone_optional_match_count_star():
+    """count(*) over a standalone OPTIONAL MATCH is reported as untranslatable."""
+    sql, reason = cypher2oracle_sqlpgq(
+        "OPTIONAL MATCH (u:User)-[:AttemptsAccess]->(ae:AccessEvent) RETURN count(*)",
+        graph_name="G",
+    )
+    assert sql == "Unable to Translate to Oracle SQL/PGQ"
+    assert reason == "Graph-IL Not Support"
+
+
+def test_cypher2oracle_sqlpgq_translates_with_optional_match_as_left_join():
+    """An OPTIONAL MATCH after WITH becomes a LEFT JOIN from stage_1 to stage_2."""
+    sql = _translate_sql(
+        "MATCH (g:Group) "
+        "WITH g ORDER BY g.created_date DESC LIMIT 1 "
+        "OPTIONAL MATCH (g)<-[:BelongsTo]-(u:User) "
+        "RETURN g.name AS GroupName, g.created_by AS CreatorName, "
+        "COUNT(u) AS MemberCount"
+    )
+    assert "WITH stage_1 AS" in sql
+    assert "stage_2 AS" in sql
+    assert "FROM stage_1\nLEFT JOIN stage_2 ON stage_2.g_VALUE = stage_1.stage_1_g_VALUE" in sql
+    assert "COUNT(stage_2.u_VALUE) AS MemberCount" in sql
+    assert "JOIN stage_1" not in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_match_optional_with_count_left_join():
+    """MATCH, OPTIONAL MATCH, then WITH count() is a LEFT JOIN grouped by stage_1 columns."""
+    sql = _translate_sql(
+        "MATCH (q:Question)-[:TAGGED]->(t:Tag {name: 'neo4j'}) "
+        "OPTIONAL MATCH (q)<-[:COMMENTED_ON]-(c:Comment) "
+        "WITH q, count(c) AS commentCount "
+        "RETURN q.title, commentCount"
+    )
+    assert "LEFT JOIN stage_2 ON stage_2.q_VALUE = stage_1.stage_1_q_VALUE" in sql
+    assert "stage_1.stage_1_q_VALUE AS q_VALUE" in sql
+    assert "COUNT(stage_2.c_VALUE) AS commentCount" in sql
+    assert "GROUP BY stage_1.stage_1_q_VALUE, stage_1.q_title" in sql
+
+
+def test_cypher2oracle_sqlpgq_carries_base_only_variable_in_match_optional_with():
+    """A variable bound only by the first MATCH stays available after the OPTIONAL MATCH."""
+    sql = _translate_sql(
+        "MATCH (p:Product)-[:PART_OF]->(category:Category) "
+        "OPTIONAL MATCH (p)<-[o:ORDERS]-(:Order) "
+        "WITH category.categoryName AS categoryName, "
+        "SUM(toInteger(o.quantity)) AS totalQuantityOrdered "
+        "RETURN categoryName, totalQuantityOrdered"
+    )
+    assert 'category."categoryName" AS categoryName' in sql
+    assert "LEFT JOIN stage_2 ON stage_2.p_VALUE = stage_1.stage_1_p_VALUE" in sql
+    assert "stage_1.categoryName AS categoryName" in sql
+    assert "GROUP BY stage_1.categoryName" in sql
+
+
+def test_cypher2oracle_sqlpgq_translates_optional_null_where_to_antijoin():
+    """OPTIONAL MATCH filtered by `c IS NULL` is rewritten as a NOT EXISTS anti-join."""
+    sql = _translate_sql(
+        "MATCH (q:Question) "
+        "OPTIONAL MATCH (q)<-[:COMMENTED_ON]-(c:Comment) "
+        "WHERE c IS NULL "
+        "RETURN q.title"
+    )
+    assert sql.startswith("WITH base AS")
+    assert "WHERE NOT EXISTS" in sql
+    assert 'MATCH (c IS "Comment")-[e1 IS "COMMENTED_ON"]->(q)' in sql
+    assert "pp.q_VALUE = base.q_VALUE" in sql
+    assert "WHERE c IS NULL" not in sql
+    assert "OPTIONAL MATCH" not in sql
diff --git a/test/test_dataset_prep.py b/test/test_dataset_prep.py
new file mode 100644
index 0000000..8c420e0
--- /dev/null
+++ b/test/test_dataset_prep.py
@@ -0,0 +1,1593 @@
+import json
+from pathlib import Path
+from types import SimpleNamespace
+
+from app.core.validator.db_client import QueryResult, QueryStatus
+from dataset_prep.analyze_failures import failure_signature, unsupported_query_signature
+from dataset_prep.compare_oracle_neo4j_results import (
+ DatasetNeo4jLoader,
+ compare_record,
+ comparison_result,
+ is_nondeterministic_limit_without_order,
+ is_nondeterministic_with_limit_without_order,
+ is_order_by_limit_query,
+ is_with_order_by_limit_query,
+ normalize_rows,
+ result_diagnostics,
+ select_records_for_range,
+ stable_execution_queries,
+)
+from dataset_prep.cypher_schema import CypherSchema
+from dataset_prep.discover import DatabaseUnit, discover_database_units, source_query
+from dataset_prep.export_validated_dataset import (
+ copy_dataset_assets,
+ insert_sql_pgq_field,
+ merge_export_summaries,
+ project_export_record,
+ write_records_like_source,
+)
+from dataset_prep.oracle_loader import DatasetOracleLoader
+from dataset_prep.translate_validate import detect_unsupported_features, graph_name_for
+
+
+def _schema(config: dict) -> CypherSchema:
+    return CypherSchema(config)  # shorthand for building a schema from an inline config dict
+
+
+def test_discover_train_and_dev_layouts(tmp_path: Path):
+    """Both the train and dev directory layouts are discovered, ordered dev before train."""
+    empty_config = json.dumps({"schema": [], "files": []})
+
+    # Train layout: <split>/<db>/cypher/<instance>/import_config.json
+    train_root = tmp_path / "train" / "movies"
+    train_config_dir = train_root / "cypher" / "movies_tugraph_new"
+    train_config_dir.mkdir(parents=True)
+    (train_root / "4_level_results_ek_results.json").write_text("[]", encoding="utf-8")
+    (train_config_dir / "import_config.json").write_text(empty_config, encoding="utf-8")
+
+    # Dev layout: <split>/<db>/Cypher/<instance>/import_config.json with two JSON files beside it
+    dev_root = tmp_path / "dev" / "Disney" / "Cypher"
+    dev_config_dir = dev_root / "disney__tugraph2"
+    dev_config_dir.mkdir(parents=True)
+    (dev_root / "disney_cypher.json").write_text("[]", encoding="utf-8")
+    (dev_root / "4_level_results_ek_results_refined.json").write_text("[]", encoding="utf-8")
+    (dev_config_dir / "import_config.json").write_text(empty_config, encoding="utf-8")
+
+    found = discover_database_units(tmp_path, ["train", "dev"])
+
+    discovered = [(unit.split, unit.database) for unit in found]
+    assert discovered == [
+        ("dev", "Disney"),
+        ("train", "movies"),
+    ]
+    assert found[0].query_path.name == "disney_cypher.json"
+    assert found[1].csv_root == train_config_dir
+
+
+def test_source_query_prefers_cypher_then_gql():
+    """`initial_cypher` wins when both fields are present; `initial_gql` is the fallback."""
+    text = "MATCH (n) RETURN n"
+
+    chosen = source_query({"initial_cypher": text, "initial_gql": "gql"})
+    assert chosen == ("initial_cypher", text)
+
+    chosen = source_query({"initial_gql": text})
+    assert chosen == ("initial_gql", text)
+
+
+def test_graph_name_is_oracle_safe():
+    """Graph names built from awkward database names avoid '-' and '(' and fit 128 chars."""
+    db_unit = DatabaseUnit(
+        split="test",
+        database="Manufacturing_BOM(Bill_Of_Materials)",
+        root=Path("."),
+        query_path=Path("query.json"),
+        import_config_path=Path("import_config.json"),
+        csv_root=Path("."),
+    )
+    graph_name = graph_name_for(db_unit, "T2GQL")
+    assert {"-", "("}.isdisjoint(graph_name)
+    assert len(graph_name) <= 128
+
+
+def test_export_projection_adds_sql_pgq_and_strips_oracle_metadata():
+    """Projection moves oracle_sqlpgq into the configured field and drops other oracle_* keys."""
+    validated = {
+        "id": "q1",
+        "initial_question": "Question?",
+        "initial_cypher": "MATCH (n) RETURN n",
+        "initial_gql": "MATCH (n) RETURN n",
+        "source": "synthetic",
+        "oracle_sqlpgq": "SELECT * FROM GRAPH_TABLE (...)",
+        "oracle_validation_status": "success",
+        "oracle_dataset_meta": {"split": "train"},
+    }
+    options = SimpleNamespace(sql_pgq_field="initial_sql_pgq", include_oracle_metadata=False)
+
+    projected = project_export_record(validated, options)
+    assert projected == {
+        "id": "q1",
+        "initial_question": "Question?",
+        "initial_cypher": "MATCH (n) RETURN n",
+        "initial_gql": "MATCH (n) RETURN n",
+        "initial_sql_pgq": "SELECT * FROM GRAPH_TABLE (...)",
+        "source": "synthetic",
+    }
+    assert list(projected) == [
+        "id",
+        "initial_question",
+        "initial_cypher",
+        "initial_gql",
+        "initial_sql_pgq",
+        "source",
+    ]
+
+
+def test_export_projection_can_keep_oracle_metadata_with_custom_field():
+    """With include_oracle_metadata=True the oracle_* status survives next to a custom field."""
+    validated = {
+        "id": "q1",
+        "oracle_sqlpgq": "SELECT 1",
+        "oracle_validation_status": "success",
+    }
+    options = SimpleNamespace(sql_pgq_field="sql_pgq_query_oracle", include_oracle_metadata=True)
+
+    projected = project_export_record(validated, options)
+    assert projected["oracle_validation_status"] == "success"
+    assert projected["sql_pgq_query_oracle"] == "SELECT 1"
+
+
+def test_insert_sql_pgq_field_places_query_next_to_existing_query_fields():
+    """The new field lands right after the existing query fields, ahead of other keys."""
+    original = {
+        "id": "q1",
+        "initial_cypher": "MATCH (n) RETURN n",
+        "initial_gql": "MATCH (n) RETURN n",
+        "difficulty": "easy",
+    }
+
+    updated = insert_sql_pgq_field(original, "initial_sql_pgq", "SELECT 1")
+
+    # Only key order is checked here, so compare the keys as a list.
+    assert list(updated) == [
+        "id",
+        "initial_cypher",
+        "initial_gql",
+        "initial_sql_pgq",
+        "difficulty",
+    ]
+
+
+def test_write_records_like_source_preserves_list_shape(tmp_path: Path):
+    """A list-shaped source file produces a list-shaped output file."""
+    src = tmp_path / "source.json"
+    dst = tmp_path / "out" / "source.json"
+    src.write_text(json.dumps([{"id": "a"}, {"id": "b"}]), encoding="utf-8")
+
+    write_records_like_source(src, dst, [{"id": "a", "initial_sql_pgq": "SELECT 1"}])
+
+    written = json.loads(dst.read_text(encoding="utf-8"))
+    assert written == [{"id": "a", "initial_sql_pgq": "SELECT 1"}]
+
+
+def test_write_records_like_source_preserves_dict_shape(tmp_path: Path):
+    """A dict-shaped source file produces a dict-shaped output file."""
+    src = tmp_path / "source.json"
+    dst = tmp_path / "out" / "source.json"
+    source_payload = {"q1": {"initial_cypher": "MATCH (n) RETURN n"}}
+    src.write_text(json.dumps(source_payload), encoding="utf-8")
+
+    write_records_like_source(src, dst, [{"id": "q1", "initial_sql_pgq": "SELECT 1"}])
+
+    written = json.loads(dst.read_text(encoding="utf-8"))
+    assert written == {
+        "q1": {"id": "q1", "initial_sql_pgq": "SELECT 1"}
+    }
+
+
+def test_copy_dataset_assets_handles_resolved_dataset_root(tmp_path: Path, monkeypatch):
+    """Assets are copied when the dataset root is absolute but the unit paths are relative."""
+    db_root = tmp_path / "dataset" / "train" / "movies"
+    config_root = db_root / "cypher" / "movies_tugraph"
+    config_root.mkdir(parents=True)
+    query_path = db_root / "4_level_results_ek_results.json"
+    query_path.write_text("[]", encoding="utf-8")
+    (config_root / "import_config.json").write_text("{}", encoding="utf-8")
+    (config_root / "Movie.csv").write_text("id,title\n1,Heat\n", encoding="utf-8")
+    output_root = tmp_path / "exported"
+
+    # The unit paths are relative, so they only point at the files above after the chdir.
+    unit = DatabaseUnit(
+        split="train",
+        database="movies",
+        root=Path("dataset/train/movies"),
+        query_path=Path("dataset/train/movies/4_level_results_ek_results.json"),
+        import_config_path=Path("dataset/train/movies/cypher/movies_tugraph/import_config.json"),
+        csv_root=Path("dataset/train/movies/cypher/movies_tugraph"),
+    )
+    monkeypatch.chdir(tmp_path)
+
+    copy_dataset_assets(
+        [unit],
+        Path("dataset").resolve(),
+        output_root,
+        {query_path.resolve()},
+    )
+
+    # The query file passed in the last argument must be absent; the CSV is copied verbatim.
+    assert not (output_root / "train" / "movies" / "4_level_results_ek_results.json").exists()
+    copied_csv = output_root / "train" / "movies" / "cypher" / "movies_tugraph" / "Movie.csv"
+    assert copied_csv.read_text(encoding="utf-8") == "id,title\n1,Heat\n"
+
+
+def test_merge_export_summaries_accumulates_counts():
+    """Per-database summaries are summed and their reason counters are merged."""
+    mixed_outcome = {
+        "total_records": 3,
+        "selected_records": 3,
+        "considered": 2,
+        "exported": 1,
+        "failed": 1,
+        "skipped": 1,
+        "skip_reasons": {"not_translatable": 1},
+        "failure_reasons": {"result_mismatch": 1},
+    }
+    all_exported = {
+        "total_records": 2,
+        "selected_records": 2,
+        "considered": 2,
+        "exported": 2,
+        "failed": 0,
+        "skipped": 0,
+        "skip_reasons": {},
+        "failure_reasons": {},
+    }
+
+    merged = merge_export_summaries([mixed_outcome, all_exported])
+
+    # Neither input has a "databases" key; the expected 2 matches the number of summaries.
+    assert merged["databases"] == 2
+    assert merged["total_records"] == 5
+    assert merged["exported"] == 3
+    assert merged["skip_reasons"] == {"not_translatable": 1}
+    assert merged["failure_reasons"] == {"result_mismatch": 1}
+
+
+def test_detect_unsupported_oracle_sqlpgq_features():
+    """Spot-check which Cypher shapes the detector flags and which it lets through."""
+    detect = detect_unsupported_features
+
+    # Path variables, label alternation and CASE expressions.
+    assert not detect("MATCH p = (a)-[e]->(b) RETURN p")
+    assert not detect("MATCH (a)-[:A|B]->(b) RETURN b")
+    assert not detect("MATCH (a) RETURN count(CASE WHEN a.type = 'x' THEN 1 ELSE NULL END)")
+    assert "case_label_predicate" in detect(
+        "MATCH (a) RETURN CASE WHEN a:ACCOUNT THEN 1 ELSE 0 END"
+    )
+
+    # OPTIONAL MATCH and multi-WITH shapes: accepted ones first, then those still flagged.
+    assert not detect("OPTIONAL MATCH (a)-->(b) RETURN a.name, count(b)")
+    assert not detect(
+        "MATCH (g:Group) WITH g ORDER BY g.created_date DESC LIMIT 1 "
+        "OPTIONAL MATCH (g)<-[:BelongsTo]-(u:User) RETURN g.name, COUNT(u)"
+    )
+    assert not detect(
+        "MATCH (q:Question) OPTIONAL MATCH (q)<-[:COMMENTED_ON]-(c:Comment) "
+        "WITH q, count(c) AS commentCount RETURN q.title, commentCount"
+    )
+    assert not detect(
+        "MATCH (q:Question) OPTIONAL MATCH (q)<-[:COMMENTED_ON]-(c:Comment) "
+        "WHERE c IS NULL RETURN q.title"
+    )
+    assert not detect("MATCH (a) OPTIONAL MATCH (a)--(b) RETURN b")
+    assert "optional_match" in detect("OPTIONAL MATCH (a)-->(b) OPTIONAL MATCH (b)-->(c) RETURN c")
+    assert "optional_match" in detect("MATCH (a) OPTIONAL MATCH (b) RETURN b")
+    assert "optional_match" in detect("OPTIONAL MATCH (a)-->(b) RETURN count(*)")
+    assert "multiple_with" in detect("MATCH (a) WITH a MATCH (a)-->(b) WITH b RETURN b")
+
+    # Keywords that only appear inside string literals must not trigger a feature.
+    assert "unwind" not in detect(
+        "MATCH (q:Question {title: 'use UNWIND and FOREACH safely'}) RETURN q.title"
+    )
+    assert "multiple_with" not in detect(
+        'MATCH (q:Question {title: "WITH examples in a title"}) RETURN q.title'
+    )
+
+    # Variable-length paths and CHEAPEST path search.
+    assert "open_ended_variable_length_path" in detect(
+        "MATCH (a:ACCOUNT)-[*..10]-(t:TRANSACTION) RETURN t LIMIT 1"
+    )
+    assert "expensive_variable_length_path" not in detect(
+        "MATCH (person:PERSON)-[:KNOWS*..3]->(friend:PERSON) RETURN friend"
+    )
+    assert "expensive_variable_length_path" in detect(
+        'MATCH (target:Character {name: "Stevron-Frey"}) '
+        "MATCH (target)-[:INTERACTS*1..5]-(other) RETURN other.name LIMIT 10"
+    )
+    assert "expensive_variable_length_path" not in detect(
+        'MATCH (target:Character {name: "Stevron-Frey"}) '
+        "MATCH (target)-[:INTERACTS*1..3]-(other) RETURN other.name LIMIT 10"
+    )
+    assert "cost" not in detect("MATCH (p:Product) RETURN p.cost")
+    assert "cost" in detect("MATCH p = ANY CHEAPEST (a)-[:ROUTE]->(b) RETURN p")
+    assert "open_ended_variable_length_path" in detect(
+        "MATCH (person:PERSON)-[:KNOWS*1..]->(friend:PERSON) RETURN friend"
+    )
+    assert "open_ended_variable_length_path" in detect(
+        "MATCH (person:PERSON)-[*..]->(friend:PERSON) RETURN friend"
+    )
+    assert "open_ended_variable_length_path" in detect(
+        "MATCH (m:Material {Material_id: 'M000123'})-[*..10]-(p:Product) RETURN p"
+    )
+    assert "open_ended_variable_length_path" in detect(
+        "MATCH (person:PERSON)-[*]->(friend:PERSON) RETURN friend"
+    )
+    assert "open_ended_variable_length_path" not in detect(
+        "MATCH (person:PERSON)-[:KNOWS*1..3]->(friend:PERSON) RETURN friend"
+    )
+    assert "expensive_variable_length_path" not in detect(
+        'MATCH (u:USER)-[*2..5]->(n) WHERE u.user_id = "U000001" RETURN n.user_id'
+    )
+    assert "expensive_variable_length_path" not in detect(
+        'MATCH (u:USER)-[*1..2]->(n) WHERE u.user_id = "U000001" RETURN n.user_id'
+    )
+    assert "quantified_relationship_property_map" in detect(
+        "MATCH (d:DEVICE)-[:CONNECTS_TO*1..2 {connectionType: 'WiFi'}]->(:GATEWAY) RETURN d"
+    )
+    assert "relative_duration" in detect(
+        "MATCH (m:Movie) WHERE m.release_date >= date() - duration('P5Y') RETURN m"
+    )
+
+
+def test_detect_allows_broad_bounded_variable_length_paths_with_schema():
+    """`[*..10]` and `[*]` hops are flagged as open-ended even when a dataset schema is given."""
+    import pytest  # local import: pytest is not among this module's top-level imports
+
+    config_path = Path("dataset/dev/FInancial_Financial_Management/Cypher/TuGraph-DB_Instance")
+    config_path /= "import_config.json"
+    if not config_path.is_file():  # dataset/ is gitignored, so clean checkouts lack this file
+        pytest.skip(f"dataset fixture not available: {config_path}")
+    schema = CypherSchema(json.loads(config_path.read_text(encoding="utf-8")))
+    queries = (
+        "MATCH (a:ACCOUNT {account_id: 'A000000'})-[*..10]-(t:TRANSACTION)"
+        "-[:GovernedBy]->(c:COMPLIANCE_RULE {regulation_standard: 'GDPR'}) "
+        "RETURN a.account_id, t.transaction_id, c.rule_id LIMIT 1",
+        "MATCH (a:ACCOUNT {account_id: 'A000000'})-[*]-(t:TRANSACTION) RETURN t",
+    )
+    for query in queries:
+        assert "open_ended_variable_length_path" in detect_unsupported_features(query, schema)
+
+
+def test_detect_schema_direction_and_numeric_source_issues():
+    """Schema-aware checks: edge direction, numeric/temporal misuse and unknown properties."""
+    config = {
+        "schema": [
+            {
+                "label": "DataConsumer",
+                "type": "VERTEX",
+                "primary": "DataConsumer_id",
+                "properties": [{"name": "DataConsumer_id", "type": "STRING"}],
+            },
+            {
+                "label": "DataAsset",
+                "type": "VERTEX",
+                "primary": "DataAsset_id",
+                "properties": [{"name": "DataAsset_id", "type": "STRING"}],
+            },
+            {
+                "label": "ProcessingJob",
+                "type": "VERTEX",
+                "primary": "ProcessingJob_id",
+                "properties": [
+                    {"name": "success_rate", "type": "DOUBLE"},
+                    {"name": "sla_requirements", "type": "STRING"},
+                    {"name": "last_review_date", "type": "DATE"},
+                ],
+            },
+            {
+                "label": "Character",
+                "type": "VERTEX",
+                "primary": "Character_id",
+                "properties": [
+                    {"name": "Character_id", "type": "STRING"},
+                    {"name": "fastrf_embedding", "type": "STRING"},
+                ],
+            },
+            {
+                "label": "Review",
+                "type": "VERTEX",
+                "primary": "Review_id",
+                "properties": [{"name": "Review_id", "type": "STRING"}],
+            },
+            {
+                "label": "Transforms",
+                "type": "EDGE",
+                "constraints": [["ProcessingJob", "DataAsset"]],
+                "properties": [],
+            },
+            {
+                "label": "USED_DEVICE",
+                "type": "EDGE",
+                "constraints": [["ProcessingJob", "DataAsset"]],
+                "properties": [
+                    {"name": "session_start", "type": "TIMESTAMP"},
+                    {"name": "session_end", "type": "TIMESTAMP"},
+                ],
+            },
+        ]
+    }
+    schema = _schema(config)
+
+    def features(cypher: str):
+        return detect_unsupported_features(cypher, source_schema=schema)
+
+    assert "invalid_schema_direction" in features(
+        "MATCH (da:DataAsset)-[:Transforms]->(pj:ProcessingJob) RETURN da"
+    )
+    assert "invalid_schema_direction" not in features(
+        "MATCH (pj:ProcessingJob)-[:Transforms]->(da:DataAsset) RETURN da"
+    )
+
+    # Numeric operations on STRING-typed properties.
+    assert "unsafe_numeric_conversion" in features(
+        "MATCH (pj:ProcessingJob) WHERE toInteger(pj.sla_requirements) > 24 RETURN pj"
+    )
+    assert "unsafe_numeric_conversion" in features(
+        "MATCH (c:Character) WHERE c.fastrf_embedding IS NOT NULL "
+        "RETURN max(c.fastrf_embedding) - min(c.fastrf_embedding) AS diversity"
+    )
+
+    # Subtracting DATE aggregates, directly or through WITH aliases.
+    assert "unsafe_temporal_arithmetic" in features(
+        "MATCH (pj:ProcessingJob) "
+        "RETURN max(pj.last_review_date) - min(pj.last_review_date) AS dateRange"
+    )
+    assert "unsafe_temporal_arithmetic" in features(
+        "MATCH (pj:ProcessingJob) "
+        "WITH max(pj.last_review_date) AS maxDate, min(pj.last_review_date) AS minDate "
+        "ORDER BY maxDate - minDate DESC "
+        "RETURN minDate, maxDate, maxDate - minDate AS DateRange"
+    )
+
+    # Undeclared properties are flagged; declared edge properties in duration.between are not.
+    assert "invalid_schema_property" in features(
+        "MATCH (r:Review) RETURN size(collect(DISTINCT r.summary)) AS summaryCount"
+    )
+    assert "invalid_schema_property" in features(
+        "MATCH (pj:ProcessingJob)<-[:REVIEWED]-(r) "
+        "RETURN size(collect(DISTINCT r.summary)) AS summaryCount"
+    )
+    duration_features = features(
+        "MATCH (pj:ProcessingJob)-[ud:USED_DEVICE]->(da:DataAsset) "
+        "RETURN avg(duration.between(ud.session_start, ud.session_end))"
+    )
+    assert "invalid_schema_property" not in duration_features
+
+
+def test_failure_analysis_groups_unsupported_query_shapes():
+    """Unsupported source queries map to the expected failure signatures."""
+
+    def record(query: str) -> dict:
+        # Record flagged as unsupported, with no unsupported features attached.
+        return dict(
+            oracle_validation_status="unsupported",
+            oracle_translation_category="Graph-IL Not Support",
+            oracle_source_query=query,
+            oracle_unsupported_features=[],
+        )
+
+    def optional_record(query: str) -> dict:
+        # Same record, but carrying the "optional_match" unsupported feature.
+        tagged = record(query)
+        tagged["oracle_unsupported_features"] = ["optional_match"]
+        return tagged
+
+    signature_cases = [
+        (record("MATCH (a) WITH a MATCH (a)-->(b) WITH a RETURN a"), "multiple_with_skipped"),
+        (
+            optional_record("MATCH (a) WITH a OPTIONAL MATCH (a)-->(b) WITH a RETURN a"),
+            "multiple_with_skipped",
+        ),
+        (optional_record("OPTIONAL MATCH (a)-->(b) RETURN a"), "standalone_optional_match"),
+        (
+            optional_record("MATCH (a) OPTIONAL MATCH (a)-->(b) RETURN a"),
+            "optional_match_left_join_required",
+        ),
+        (record("MATCH p=(a)-[:KNOWS*1..3]->(b) RETURN p"), "path_variable_return"),
+        (
+            record("MATCH (p:Policy) RETURN AVG(p.effective_date) AS value"),
+            "temporal_numeric_aggregate",
+        ),
+        (
+            record("MATCH (a)-[e]->(b) RETURN count(DISTINCT e.bad_alias)"),
+            "invalid_schema_property",
+        ),
+        (
+            record("MATCH (a:Assertion) WHERE a.contradiction_severity > 1 RETURN a"),
+            "invalid_schema_property",
+        ),
+        (
+            record(
+                "MATCH (g:Group) WITH g ORDER BY g.created_date DESC LIMIT 1 "
+                "OPTIONAL MATCH (g)<-[:BelongsTo]-(u:User) RETURN g.name, COUNT(u)"
+            ),
+            "invalid_schema_property",
+        ),
+        (record("MATCH (s:Source) RETURN s.name"), "invalid_schema_property"),
+        (
+            record(
+                "MATCH p=(n1:Resource)-[e]-(n2:Policy) "
+                "WHERE n2.sensitivity_level <> 'Internal' RETURN p"
+            ),
+            "invalid_schema_property",
+        ),
+    ]
+    for source_record, expected_signature in signature_cases:
+        assert failure_signature(source_record) == expected_signature
+
+    # The raw-query classifier is checked directly, without a wrapping record.
+    aggregate_over_chain = (
+        "MATCH (s:Supplier)-[:SUPPLIES]->(p:Product)-[:ORDERS]->(o:Order) "
+        "WITH s.supplierID AS supplierID, COUNT(DISTINCT o.shipCity) AS cityCount "
+        "WHERE cityCount > 3 RETURN supplierID"
+    )
+    assert unsupported_query_signature(aggregate_over_chain) != "multi_pattern_match"
+
+    open_ended_path = "MATCH (a)-[:KNOWS*1..]->(b) RETURN b"
+    assert unsupported_query_signature(open_ended_path) == "open_ended_variable_length_path"
+
+
+def test_failure_analysis_uses_manifest_for_invalid_schema(tmp_path: Path):
+    """Records that reference an import config get schema-specific failure signatures."""
+    schema_entries = [
+        {
+            "label": "Product",
+            "type": "VERTEX",
+            "properties": [{"name": "productName"}, {"name": "productID"}],
+        },
+        {
+            "label": "Order",
+            "type": "VERTEX",
+            "properties": [{"name": "freight"}, {"name": "orderDate"}],
+        },
+        {
+            "label": "ORDERS",
+            "type": "EDGE",
+            "properties": [{"name": "quantity"}],
+            # The edge is declared Order -> Product; the first query below uses the reverse.
+            "constraints": [["Order", "Product"]],
+        },
+    ]
+    import_config = tmp_path / "import_config.json"
+    import_config.write_text(json.dumps({"schema": schema_entries}), encoding="utf-8")
+
+    def record(query: str) -> dict:
+        return dict(
+            oracle_validation_status="unsupported",
+            oracle_translation_category="Graph-IL Not Support",
+            oracle_source_query=query,
+            oracle_unsupported_features=[],
+            oracle_dataset_meta={"import_config": str(import_config)},
+        )
+
+    expected_by_query = {
+        "MATCH (p:Product)-[:ORDERS]->(o:Order) RETURN o.freight": "invalid_schema_direction",
+        "MATCH (p:Product) RETURN p.missingProperty": "invalid_schema_property",
+    }
+    for query, expected_signature in expected_by_query.items():
+        assert failure_signature(record(query)) == expected_signature
+
+
+def test_loader_converts_oracle_date_values():
+    """Exercise ``_convert_value`` for DATE, TIMESTAMP, NUMBER(1) and VARCHAR2 targets."""
+    # __new__ bypasses __init__, so the loader is created without running its setup.
+    convert = DatasetOracleLoader.__new__(DatasetOracleLoader)._convert_value
+
+    parsed_date = convert("2015-12-29", "DATE")
+    assert parsed_date.isoformat() == "2015-12-29"
+    parsed_timestamp = convert("2025-05-18 12:14:48", "TIMESTAMP")
+    assert parsed_timestamp.isoformat() == "2025-05-18T12:14:48"
+    assert convert("", "DATE") is None
+
+    for raw_flag, expected_number in (("True", 1), ("False", 0)):
+        assert convert(raw_flag, "NUMBER(1)") == expected_number
+
+    assert convert("abcde", "VARCHAR2(3)") == "abc"
+    # Three accented characters against a limit of 4 are expected to come back as two.
+    assert convert("ééé", "VARCHAR2(4)") == "éé"
+
+
+def test_compare_normalizes_temporal_strings_and_numeric_precision():
+    """Rows that differ only in temporal formatting or float noise normalize to equal values."""
+
+    class FakeNeo4jDateTime:
+        def iso_format(self):
+            return "2025-01-01T12:00:00.000000000"
+
+    # Each pair is (Oracle-side rows, Neo4j-side rows) expected to normalize identically.
+    equivalent_pairs = [
+        (
+            [{"created_at": "2025-01-01T12:00:00", "score": 1}],
+            [{"created_at": "2025-01-01T12:00:00.000000000", "score": 1.0}],
+        ),
+        ([{"created_at": "2025-01-01T12:00:00"}], [{"created_at": FakeNeo4jDateTime()}]),
+        (
+            [{"created_at": "2025-01-01T12:00:00.1"}],
+            [{"created_at": "2025-01-01T12:00:00.100000000"}],
+        ),
+        ([{"allocation": 0.764800012112}], [{"allocation": 0.7648}]),
+        ([{"epoch": 1613446786}], [{"epoch": 1613446786.0000002}]),
+        ([{"date_value": "2025-01-01T00:00:00"}], [{"date_value": "2025-01-01"}]),
+    ]
+    for oracle_rows, neo4j_rows in equivalent_pairs:
+        assert normalize_rows(oracle_rows) == normalize_rows(neo4j_rows)
+
+
+def test_compare_normalizes_oracle_and_neo4j_node_identity():
+    """An Oracle vertex descriptor and a Neo4j-style node with the same key normalize alike."""
+
+    class FakeNode:
+        labels = {"director"}
+
+        def items(self):
+            properties = {
+                "_id": 57,
+                "name": "Pinocchio",
+                "director": "Ben Sharpsteen",
+            }
+            return properties.items()
+
+    oracle_element = {
+        "ELEM_TABLE": "director",
+        "GRAPH_NAME": "G",
+        "GRAPH_OWNER": "SYSTEM",
+        "KEY_VALUE": {"_id": 57},
+    }
+
+    # Each call gets its own primary-key mapping, as two separate dict objects.
+    from_oracle = normalize_rows([{"director": oracle_element}], {"director": "_id"})
+    from_neo4j = normalize_rows([{"director": FakeNode()}], {"director": "_id"})
+
+    assert from_oracle == from_neo4j
+
+
+def test_compare_normalizes_single_neo4j_path_to_flat_element_sequence():
+    """A single Neo4j-style path value normalizes like Oracle's per-element ID columns."""
+
+    class FakeNode:
+        def __init__(self, label: str, key: str, value: str):
+            self.labels = {label}
+            self._props = {key: value}
+
+        def items(self):
+            return self._props.items()
+
+    class FakeRelationship:
+        type = "POSTS"
+
+        def items(self):
+            return {}.items()
+
+    class FakePath:
+        nodes = [FakeNode("USER", "user_id", "U000001"), FakeNode("POST", "post_id", "P000001")]
+        relationships = [FakeRelationship()]
+
+    # Oracle side: one column per path element (start node, edge, end node).
+    oracle_row = {
+        "p_n1_ID": {"ELEM_TABLE": "USER", "KEY_VALUE": {"user_id": "U000001"}},
+        "p_e1_ID": {"ELEM_TABLE": "POSTS", "KEY_VALUE": {}},
+        "p_x_ID": {"ELEM_TABLE": "POST", "KEY_VALUE": {"post_id": "P000001"}},
+    }
+
+    from_oracle = normalize_rows([oracle_row], {"USER": "user_id", "POST": "post_id"})
+    from_neo4j = normalize_rows([{"p": FakePath()}], {"USER": "user_id", "POST": "post_id"})
+
+    assert from_oracle == from_neo4j
+
+
+def test_compare_normalizes_oracle_and_neo4j_edge_identity():
+    """An Oracle edge-table descriptor matches a Neo4j-style relationship via label aliases."""
+
+    class FakeRelationship:
+        type = "AllocatedTo"
+
+        def items(self):
+            return {"EDGE_ID": 6, "priority": 1}.items()
+
+    oracle_edge = {
+        "ELEM_TABLE": "BUDGET_AllocatedTo_ACCOUNT",
+        "KEY_VALUE": {"EDGE_ID": 6},
+    }
+    # Maps the Oracle edge table name onto the relationship type used on the Neo4j side.
+    aliases = {"BUDGET_AllocatedTo_ACCOUNT": "AllocatedTo"}
+
+    from_oracle = normalize_rows([{"r": oracle_edge}], element_label_aliases=aliases)
+    from_neo4j = normalize_rows([{"r": FakeRelationship()}], element_label_aliases=aliases)
+
+    assert from_oracle == from_neo4j
+
+
+def test_compare_detects_nondeterministic_limit_without_order_by():
+    """Cover the LIMIT-without-ORDER-BY detector, including keywords inside a string literal."""
+    expectations = [
+        ("MATCH (n) RETURN n LIMIT 10", True),
+        ("MATCH (n) RETURN n ORDER BY n.name LIMIT 10", False),
+        ("MATCH (n) WITH n ORDER BY n.created LIMIT 1 RETURN n LIMIT 10", True),
+        ("MATCH (n {text: 'ORDER BY words LIMIT examples'}) RETURN n", False),
+    ]
+    for cypher, should_flag in expectations:
+        assert bool(is_nondeterministic_limit_without_order(cypher)) is should_flag
+
+    assert is_order_by_limit_query("MATCH (n) RETURN n ORDER BY n.name LIMIT 10")
+
+
+def test_compare_builds_stable_execution_queries_for_unordered_scalar_limit():
+    """An unordered scalar LIMIT query gets an explicit ORDER BY in both dialects."""
+    oracle_sql = "SELECT name, score FROM graph_table(...) FETCH FIRST 10 ROWS ONLY"
+    cypher = "MATCH (n) RETURN n.name AS name, n.score AS score LIMIT 10"
+
+    stable = stable_execution_queries(oracle_sql, cypher)
+
+    expected_cypher = (
+        "MATCH (n) RETURN n.name AS name, n.score AS score ORDER BY name, score LIMIT 10"
+    )
+    expected_sql = (
+        "SELECT name, score FROM graph_table(...)\nORDER BY 1, 2\nFETCH FIRST 10 ROWS ONLY"
+    )
+    assert stable.applied
+    assert stable.reason == "unordered_paging"
+    assert stable.cypher == expected_cypher
+    assert stable.oracle_sqlpgq == expected_sql
+
+
+def test_compare_adds_stable_tiebreakers_to_ordered_scalar_limit():
+    """An already ordered scalar LIMIT query gains extra tiebreaker sort keys."""
+    oracle_sql = "SELECT name, score FROM graph_table(...) ORDER BY score FETCH FIRST 1 ROWS ONLY"
+    cypher = "MATCH (n) RETURN n.name AS name, n.score AS score ORDER BY score LIMIT 1"
+
+    stable = stable_execution_queries(oracle_sql, cypher)
+
+    expected_cypher = (
+        "MATCH (n) RETURN n.name AS name, n.score AS score ORDER BY score, name LIMIT 1"
+    )
+    expected_sql = (
+        "SELECT name, score FROM graph_table(...) ORDER BY score, 1, 2\n"
+        "FETCH FIRST 1 ROWS ONLY"
+    )
+    assert stable.applied
+    assert stable.reason == "ordered_paging_tiebreaker"
+    assert stable.cypher == expected_cypher
+    assert stable.oracle_sqlpgq == expected_sql
+
+
+def test_compare_adds_stable_tiebreakers_to_ordered_with_limit():
+    """An ordered WITH ... LIMIT stage gains tiebreaker sort keys in both dialects."""
+    oracle_sql = (
+        "WITH stage_1 AS (\n"
+        " SELECT fingerprint, COUNT(transaction_id) AS transaction_count\n"
+        " FROM graph_table(...)\n"
+        " GROUP BY fingerprint\n"
+        " ORDER BY transaction_count DESC\n"
+        " FETCH FIRST 3 ROWS ONLY\n"
+        ")\n"
+        "SELECT fingerprint, transaction_count FROM stage_1"
+    )
+    cypher = (
+        "MATCH (d:DEVICE)<-[:USED_DEVICE]-(t:TRANSACTION) "
+        "WITH d.device_fingerprint AS fingerprint, COUNT(t) AS transaction_count "
+        "ORDER BY transaction_count DESC LIMIT 3 "
+        "RETURN fingerprint, transaction_count"
+    )
+
+    stable = stable_execution_queries(oracle_sql, cypher)
+
+    expected_with_stage = (
+        "WITH d.device_fingerprint AS fingerprint, COUNT(t) AS transaction_count "
+        "ORDER BY transaction_count DESC, fingerprint LIMIT 3"
+    )
+    assert stable.applied
+    assert stable.reason == "with_ordered_paging_tiebreaker"
+    assert expected_with_stage in stable.cypher
+    assert "ORDER BY transaction_count DESC, 1, 2" in stable.oracle_sqlpgq
+    # Presumably guards against the sort list being glued to FETCH without a separator.
+    assert "2FETCH" not in stable.oracle_sqlpgq
+
+
+def test_compare_adds_stable_order_to_unordered_with_limit():
+    """An unordered WITH ... LIMIT stage gets an explicit ORDER BY in both dialects."""
+    oracle_sql = (
+        "WITH stage_1 AS (\n"
+        " SELECT name, score\n"
+        " FROM graph_table(...)\n"
+        " FETCH FIRST 2 ROWS ONLY\n"
+        ")\n"
+        "SELECT name, score FROM stage_1"
+    )
+    cypher = "MATCH (n) WITH n.name AS name, n.score AS score LIMIT 2 RETURN name, score"
+
+    stable = stable_execution_queries(oracle_sql, cypher)
+
+    expected_with_stage = "WITH n.name AS name, n.score AS score ORDER BY name, score LIMIT 2"
+    assert stable.applied
+    assert stable.reason == "with_unordered_paging"
+    assert expected_with_stage in stable.cypher
+    assert "ORDER BY 1, 2\nFETCH FIRST 2 ROWS ONLY" in stable.oracle_sqlpgq
+
+
+def test_compare_uses_primary_key_tiebreaker_for_with_graph_variable():
+    """A node variable carried through WITH is tie-broken by its label's primary key."""
+    oracle_sql = (
+        "WITH stage_1 AS (\n"
+        " SELECT dc_VALUE, asset_count\n"
+        " FROM graph_table(...)\n"
+        " ORDER BY asset_count DESC\n"
+        " FETCH FIRST 5 ROWS ONLY\n"
+        ")\n"
+        "SELECT dc_VALUE, asset_count FROM stage_1"
+    )
+    cypher = (
+        "MATCH (dc:DataConsumer)-[:Consumes]->(da:DataAsset) "
+        "WITH dc, COUNT(da) AS asset_count ORDER BY asset_count DESC LIMIT 5 "
+        "RETURN dc.name, asset_count"
+    )
+    primary_keys = {"DataConsumer": "DataConsumer_id"}
+
+    stable = stable_execution_queries(oracle_sql, cypher, primary_keys)
+
+    assert stable.applied
+    assert stable.reason == "with_ordered_paging_tiebreaker"
+    assert "ORDER BY asset_count DESC, dc.DataConsumer_id LIMIT 5" in stable.cypher
+    assert "ORDER BY asset_count DESC, 1, 2" in stable.oracle_sqlpgq
+    assert "2FETCH" not in stable.oracle_sqlpgq
+
+
+def test_compare_does_not_stabilize_bare_entity_or_path_limit():
+    """LIMIT queries returning a bare node or a path are left without stabilization."""
+    unstabilized_inputs = [
+        (
+            "SELECT n_VALUE FROM graph_table(...) FETCH FIRST 10 ROWS ONLY",
+            "MATCH (n:User) RETURN n LIMIT 10",
+        ),
+        (
+            "SELECT p_n1_ID FROM graph_table(...) FETCH FIRST 1 ROWS ONLY",
+            "MATCH p = (n)-[r]->(m) RETURN p LIMIT 1",
+        ),
+    ]
+    # Run both rewrites first, then assert, mirroring the call-then-check ordering.
+    outcomes = [stable_execution_queries(sql, cypher) for sql, cypher in unstabilized_inputs]
+
+    for outcome in outcomes:
+        assert not outcome.applied
+
+
+def test_compare_detects_stage_limit_queries():
+    """Detect LIMIT on a WITH stage, both without and with an ORDER BY."""
+    unordered_stage = "MATCH (n) WITH n.name AS name LIMIT 2 RETURN name"
+    ordered_stage = "MATCH (n) WITH n.name AS name ORDER BY name LIMIT 2 RETURN name"
+
+    assert is_nondeterministic_with_limit_without_order(unordered_stage)
+    assert is_with_order_by_limit_query(ordered_stage)
+
+
+def test_compare_executes_matching_nondeterministic_limit_query():
+    """Both fakes receive an ORDER BY-augmented query and the identical rows match."""
+
+    class FakeOracle:
+        def __init__(self):
+            self.query = ""
+
+        def execute_query(self, query: str, **kwargs):
+            # Remember the statement so the test can inspect what was actually run.
+            self.query = query
+            return QueryResult(QueryStatus.SUCCESS, data=[{"name": "A"}])
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def __init__(self):
+            self.query = ""
+
+        def execute(self, query: str, timeout_s=None):
+            self.query = query
+            return "success", [{"name": "A"}], ""
+
+    cli_args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    oracle_client = FakeOracle()
+    neo4j_client = FakeNeo4j()
+    source_record = {
+        "oracle_sqlpgq": "SELECT 'A' AS name FROM dual FETCH FIRST 1 ROWS ONLY",
+        "oracle_source_query": "MATCH (n) RETURN n.name LIMIT 1",
+    }
+
+    outcome = compare_record(source_record, oracle_client, neo4j_client, cli_args)
+
+    assert outcome["matched"]
+    assert "ORDER BY 1" in oracle_client.query
+    assert "ORDER BY n.name LIMIT 1" in neo4j_client.query
+
+
+def test_compare_fails_stabilized_mismatched_limit_query():
+    """Differing scalar rows on an unordered LIMIT query are reported as result_mismatch."""
+
+    class FakeOracle:
+        def execute_query(self, query: str, **kwargs):
+            return QueryResult(QueryStatus.SUCCESS, data=[{"name": "A"}])
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def execute(self, query: str, timeout_s=None):
+            return "success", [{"name": "B"}], ""
+
+    cli_args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    source_record = {
+        "oracle_sqlpgq": "SELECT 'A' AS name FROM dual FETCH FIRST 1 ROWS ONLY",
+        "oracle_source_query": "MATCH (n) RETURN n.name LIMIT 1",
+    }
+
+    outcome = compare_record(source_record, FakeOracle(), FakeNeo4j(), cli_args)
+
+    assert not outcome["matched"]
+    assert outcome["reason"] == "result_mismatch"
+    assert outcome["deterministic_ordering"]["reason"] == "unordered_paging"
+
+
+def test_compare_skips_unsafe_mismatched_nondeterministic_limit_query():
+    """A bare-node LIMIT mismatch is reported as nondeterministic_limit_without_order."""
+
+    class FakeOracle:
+        def execute_query(self, query: str, **kwargs):
+            return QueryResult(QueryStatus.SUCCESS, data=[{"n": {"KEY_VALUE": {"id": 1}}}])
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def execute(self, query: str, timeout_s=None):
+            return "success", [{"n": {"KEY_VALUE": {"id": 2}}}], ""
+
+    cli_args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    source_record = {
+        "oracle_sqlpgq": "SELECT n_VALUE AS n FROM graph_table(...) FETCH FIRST 1 ROWS ONLY",
+        "oracle_source_query": "MATCH (n) RETURN n LIMIT 1",
+    }
+
+    outcome = compare_record(source_record, FakeOracle(), FakeNeo4j(), cli_args)
+
+    assert not outcome["matched"]
+    assert outcome["reason"] == "nondeterministic_limit_without_order"
+
+
+def test_compare_skips_schema_invalid_source_before_execution():
+    """A source query that violates the schema is rejected before either engine runs."""
+
+    class FakeOracle:
+        def execute_query(self, query: str, **kwargs):
+            raise AssertionError("Oracle should not execute schema-invalid source queries")
+
+    # Consumes is declared DataConsumer -> DataAsset; the source query reverses it.
+    schema_config = {
+        "schema": [
+            {
+                "label": "DataConsumer",
+                "type": "VERTEX",
+                "properties": [{"name": "DataConsumer_id"}],
+            },
+            {
+                "label": "DataAsset",
+                "type": "VERTEX",
+                "properties": [{"name": "DataAsset_id"}],
+            },
+            {
+                "label": "Consumes",
+                "type": "EDGE",
+                "constraints": [["DataConsumer", "DataAsset"]],
+                "properties": [],
+            },
+        ]
+    }
+
+    class FakeNeo4j:
+        primary_by_label = {}
+        cypher_schema = _schema(schema_config)
+
+        def source_validation_issues(self, query: str):
+            return self.cypher_schema.validation_issues(query)
+
+        def execute(self, query: str, timeout_s=None):
+            raise AssertionError("Neo4j should not execute schema-invalid source queries")
+
+    cli_args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    source_record = {
+        "oracle_sqlpgq": "SELECT 1 AS value FROM dual",
+        "oracle_source_query": "MATCH (da:DataAsset)-[:Consumes]->(dc:DataConsumer) RETURN da",
+    }
+
+    outcome = compare_record(source_record, FakeOracle(), FakeNeo4j(), cli_args)
+
+    assert outcome["reason"] == "source_invalid"
+    assert outcome["oracle_status"] == "not_executed"
+    assert "invalid_schema_direction" in outcome["neo4j_error"]
+
+
+def test_compare_classifies_failed_neo4j_query_as_source_invalid():
+    """A Neo4j-side syntax error is classified as source_invalid, with no diagnostics."""
+
+    class FakeOracle:
+        def execute_query(self, query: str, **kwargs):
+            return QueryResult(QueryStatus.SUCCESS, data=[{"name": "A"}])
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def execute(self, query: str, timeout_s=None):
+            return "syntax_error", [], "Invalid input"
+
+    cli_args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    source_record = {
+        "oracle_sqlpgq": "SELECT 'A' AS name FROM dual",
+        "oracle_source_query": "MATCH (n) RETURN n.invalid.source",
+    }
+
+    outcome = compare_record(source_record, FakeOracle(), FakeNeo4j(), cli_args)
+
+    assert not outcome["matched"]
+    assert outcome["reason"] == "source_invalid"
+    assert outcome["neo4j_status"] == "syntax_error"
+    assert "result_diagnostics" not in outcome
+
+
+def test_compare_result_mismatch_reports_full_result_diagnostics():
+    """Diagnostics hold row counts and missing/extra samples, and are attached to mismatches."""
+    oracle_rows = [{"name": value} for value in ("A", "A", "B")]
+    neo4j_rows = [{"name": value} for value in ("A", "C")]
+
+    diagnostics = result_diagnostics(oracle_rows, neo4j_rows)
+
+    expected_fields = {
+        "oracle_row_count": 3,
+        "neo4j_row_count": 2,
+        "missing_from_neo4j_count": 2,
+        "extra_in_neo4j_count": 1,
+        "missing_from_neo4j_sample": [["A"], ["B"]],
+        "extra_in_neo4j_sample": [["C"]],
+    }
+    for field, expected_value in expected_fields.items():
+        assert diagnostics[field] == expected_value
+
+    mismatch = comparison_result(
+        False,
+        "result_mismatch",
+        "MATCH (n) RETURN n.name",
+        "SELECT name FROM graph_table(...)",
+        "success",
+        "success",
+        "",
+        "",
+        oracle_rows,
+        neo4j_rows,
+    )
+
+    assert mismatch["result_diagnostics"] == diagnostics
+
+
+def test_compare_matches_tiny_numeric_aggregate_deltas():
+    """Aggregates that differ only in trailing decimals, in any row order, count as a match."""
+    oracle_data = [
+        {"account": "A", "amount": 4037.349344},
+        {"account": "B", "amount": 1250.577610},
+    ]
+    # Same accounts in the opposite order, with small differences in the amounts.
+    neo4j_data = [
+        {"account": "B", "amount": 1250.577578},
+        {"account": "A", "amount": 4037.349339},
+    ]
+
+    class FakeOracle:
+        def execute_query(self, query: str, **kwargs):
+            return QueryResult(QueryStatus.SUCCESS, data=oracle_data)
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def execute(self, query: str, timeout_s=None):
+            return "success", neo4j_data, ""
+
+    cli_args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    source_record = {
+        "oracle_sqlpgq": "SELECT account, amount FROM graph_table(...)",
+        "oracle_source_query": "MATCH (n) RETURN n.account, sum(n.amount)",
+    }
+
+    outcome = compare_record(source_record, FakeOracle(), FakeNeo4j(), cli_args)
+
+    assert outcome["matched"]
+    assert outcome["reason"] == ""
+    assert "result_diagnostics" not in outcome
+
+
+def test_compare_does_not_mask_semantic_numeric_deltas():
+    """A 10.0 versus 10.5 aggregate difference is still reported as result_mismatch."""
+
+    class FakeOracle:
+        def execute_query(self, query: str, **kwargs):
+            return QueryResult(QueryStatus.SUCCESS, data=[{"account": "A", "amount": 10.0}])
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def execute(self, query: str, timeout_s=None):
+            return "success", [{"account": "A", "amount": 10.5}], ""
+
+    cli_args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    source_record = {
+        "oracle_sqlpgq": "SELECT account, amount FROM graph_table(...)",
+        "oracle_source_query": "MATCH (n) RETURN n.account, sum(n.amount)",
+    }
+
+    outcome = compare_record(source_record, FakeOracle(), FakeNeo4j(), cli_args)
+
+    assert not outcome["matched"]
+    assert outcome["reason"] == "result_mismatch"
+
+
+def test_compare_fails_ordered_limit_mismatch_without_boundary_tie():
+    """An ORDER BY ... LIMIT mismatch with distinct sort values is a plain result_mismatch.
+
+    Both fakes answer their first call with one row and every later call with two
+    rows whose ``score`` values differ from each other.
+    NOTE(review): presumably compare_record re-runs the queries to probe for ties at
+    the LIMIT boundary -- confirm against compare_record.
+    """
+
+    class FakeOracle:
+        def __init__(self):
+            # Number of execute_query calls seen so far; selects the canned response.
+            self.calls = 0
+
+        def execute_query(self, query: str, **kwargs):
+            self.calls += 1
+            if self.calls == 1:
+                return QueryResult(QueryStatus.SUCCESS, data=[{"name": "A", "score": 1}])
+            # Later calls: two rows with different scores (1 and 2).
+            return QueryResult(
+                QueryStatus.SUCCESS,
+                data=[{"name": "A", "score": 1}, {"name": "C", "score": 2}],
+            )
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def __init__(self):
+            # Number of execute calls seen so far; selects the canned response.
+            self.calls = 0
+
+        def execute(self, query: str, timeout_s=None):
+            self.calls += 1
+            if self.calls == 1:
+                return "success", [{"name": "B", "score": 2}], ""
+            # Later calls: two rows with different scores (2 and 3).
+            return "success", [{"name": "B", "score": 2}, {"name": "D", "score": 3}], ""
+
+    # Ad-hoc stand-in for the parsed arguments; only the two timeouts are provided.
+    args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    comparison = compare_record(
+        {
+            "oracle_sqlpgq": (
+                "SELECT 'A' AS name, 1 AS score FROM dual ORDER BY score FETCH FIRST 1 ROWS ONLY"
+            ),
+            "oracle_source_query": (
+                "MATCH (n) RETURN n.name AS name, n.score AS score ORDER BY score LIMIT 1"
+            ),
+        },
+        FakeOracle(),
+        FakeNeo4j(),
+        args,
+    )
+
+    assert not comparison["matched"]
+    assert comparison["reason"] == "result_mismatch"
+
+
+def test_compare_skips_ordered_limit_mismatch_with_boundary_tie():
+    """An ORDER BY ... LIMIT mismatch whose rows share one sort value is flagged as a tie.
+
+    Every row returned by either fake has ``score`` 1, and the expected reason is
+    ``suspected_order_by_limit_tie`` rather than ``result_mismatch``.
+    NOTE(review): presumably compare_record re-runs the queries to detect the tie --
+    confirm against compare_record.
+    """
+
+    class FakeOracle:
+        def __init__(self):
+            # Number of execute_query calls seen so far; selects the canned response.
+            self.calls = 0
+
+        def execute_query(self, query: str, **kwargs):
+            self.calls += 1
+            if self.calls == 1:
+                return QueryResult(QueryStatus.SUCCESS, data=[{"n": "A", "score": 1}])
+            # Later calls: two rows that share the same score.
+            return QueryResult(
+                QueryStatus.SUCCESS,
+                data=[{"n": "A", "score": 1}, {"n": "C", "score": 1}],
+            )
+
+    class FakeNeo4j:
+        primary_by_label = {}
+
+        def execute(self, query: str, timeout_s=None):
+            # Stateless: always two rows with the same score, whatever the query.
+            return "success", [{"n": "B", "score": 1}, {"n": "D", "score": 1}], ""
+
+    # Ad-hoc stand-in for the parsed arguments; only the two timeouts are provided.
+    args = type("Args", (), {"oracle_timeout_ms": 0, "neo4j_timeout_s": 0})()
+    comparison = compare_record(
+        {
+            "oracle_sqlpgq": (
+                "SELECT 'A' AS n, 1 AS score FROM dual ORDER BY score FETCH FIRST 1 ROWS ONLY"
+            ),
+            "oracle_source_query": ("MATCH (n) RETURN n ORDER BY score LIMIT 1"),
+        },
+        FakeOracle(),
+        FakeNeo4j(),
+        args,
+    )
+
+    assert not comparison["matched"]
+    assert comparison["reason"] == "suspected_order_by_limit_tie"
+
+
+def test_compare_selects_offset_query_ranges():
+    """Select records by offset and optional limit, including a negative offset."""
+    records = [{"id": index} for index in range(5)]
+
+    windowed = select_records_for_range(records, query_offset=2, limit_queries=2)
+    assert windowed == [{"id": 2}, {"id": 3}]
+
+    to_the_end = select_records_for_range(records, query_offset=3)
+    assert to_the_end == [{"id": 3}, {"id": 4}]
+
+    # A negative offset is expected to yield the first record.
+    from_negative = select_records_for_range(records, query_offset=-10, limit_queries=1)
+    assert from_negative == [{"id": 0}]
+
+
+def test_neo4j_compare_prepares_string_backed_boolean_literals():
+    """Boolean literals compared with STRING properties are quoted; BOOL ones are untouched."""
+    loader = DatasetNeo4jLoader.__new__(DatasetNeo4jLoader)
+    loader.property_types_by_label = {
+        "Question": {"answered": "STRING"},
+        "Answer": {"is_accepted": "STRING"},
+        "Product": {"discontinued": "STRING"},
+        "Role": {"is_compliant": "BOOL"},
+    }
+
+    bool_typed_query = "MATCH (r:Role) WHERE r.is_compliant = true RETURN r"
+    rewrites = [
+        (
+            "MATCH (q:Question {answered: true}) RETURN q",
+            "MATCH (q:Question {answered: 'true'}) RETURN q",
+        ),
+        (
+            "MATCH (p:Product) WHERE p.discontinued = false RETURN p",
+            "MATCH (p:Product) WHERE p.discontinued = 'false' RETURN p",
+        ),
+        (
+            "MATCH (q:Question) WHERE NOT q.answered RETURN q",
+            "MATCH (q:Question) WHERE q.answered = 'false' RETURN q",
+        ),
+        (bool_typed_query, bool_typed_query),
+    ]
+    for source_query, prepared_query in rewrites:
+        assert loader.prepare_query(source_query) == prepared_query
+
+
+def test_neo4j_compare_prepares_string_backed_date_comparisons():
+    """STRING properties used as dates get wrapped; DATE-typed properties are untouched."""
+    loader = DatasetNeo4jLoader.__new__(DatasetNeo4jLoader)
+    loader.property_types_by_label = {
+        "Director": {"died": "STRING"},
+        "Movie": {"release_date": "STRING"},
+        "Question": {"createdAt": "STRING"},
+        "Event": {"event_date": "DATE"},
+    }
+
+    date_typed_query = "MATCH (e:Event) WHERE e.event_date > date('2000-01-01') RETURN e"
+    rewrites = [
+        (
+            "MATCH (d:Director) WHERE d.died > date('2000-01-01') RETURN d",
+            "MATCH (d:Director) WHERE date(d.died) > date('2000-01-01') RETURN d",
+        ),
+        (
+            "MATCH (m:Movie) WHERE date('2000-01-01') > m.release_date RETURN m",
+            "MATCH (m:Movie) WHERE date('2000-01-01') > date(m.release_date) RETURN m",
+        ),
+        (date_typed_query, date_typed_query),
+        (
+            "MATCH (q:Question) WHERE date(q.createdAt).day = 1 RETURN q",
+            "MATCH (q:Question) WHERE datetime(q.createdAt).day = 1 RETURN q",
+        ),
+        (
+            "MATCH (q:Question) WHERE q.createdAt.day = 1 RETURN q",
+            "MATCH (q:Question) WHERE datetime(q.createdAt).day = 1 RETURN q",
+        ),
+    ]
+    for source_query, prepared_query in rewrites:
+        assert loader.prepare_query(source_query) == prepared_query
+
+
+def test_neo4j_compare_prepares_string_backed_numeric_comparisons():
+    """Numeric literals compared with STRING properties are quoted; INT64 ones are untouched."""
+    loader = DatasetNeo4jLoader.__new__(DatasetNeo4jLoader)
+    loader.property_types_by_label = {
+        "Answer": {"uuid": "STRING"},
+        "Product": {"unitsOnOrder": "STRING", "reorderLevel": "INT64"},
+        "Order": {"freight": "STRING"},
+    }
+
+    int_typed_query = "MATCH (p:Product) WHERE p.reorderLevel > 50 RETURN p"
+    rewrites = [
+        (
+            "MATCH (p:Product) WHERE p.unitsOnOrder > 50 RETURN p",
+            "MATCH (p:Product) WHERE p.unitsOnOrder > '50' RETURN p",
+        ),
+        (
+            "MATCH (o:Order) WHERE 100 < o.freight RETURN o",
+            "MATCH (o:Order) WHERE '100' < o.freight RETURN o",
+        ),
+        (int_typed_query, int_typed_query),
+        (
+            "MATCH (a:Answer {uuid: 69273049}) RETURN a",
+            "MATCH (a:Answer {uuid: '69273049'}) RETURN a",
+        ),
+    ]
+    for source_query, prepared_query in rewrites:
+        assert loader.prepare_query(source_query) == prepared_query
+
+
+def test_neo4j_compare_rewrites_sanitized_schema_aliases():
+    """Backtick-quoted hyphenated names are rewritten to their underscore schema names.
+
+    The string literal 'voice-actor' in the WHERE clause is expected to stay unchanged.
+    """
+    # __new__ bypasses __init__, so every attribute prepare_query needs is set by hand.
+    loader = DatasetNeo4jLoader.__new__(DatasetNeo4jLoader)
+    loader.vertex_labels = {"characters", "voice_actors"}
+    loader.edge_labels = {"HERO"}
+    loader.property_types_by_label = {
+        "characters": {"movie_title": "STRING"},
+        "voice_actors": {"voice_actor": "STRING", "movie": "STRING"},
+    }
+    # Alias tables are derived from the label/property sets assigned above.
+    loader.node_label_aliases = loader._schema_name_aliases(loader.vertex_labels)
+    loader.edge_type_aliases = loader._schema_name_aliases(loader.edge_labels)
+    loader.property_aliases_by_label = {
+        label: loader._schema_name_aliases(properties)
+        for label, properties in loader.property_types_by_label.items()
+    }
+    # NOTE(review): assigned last -- presumably it reads the per-label aliases; confirm.
+    loader.global_property_aliases = loader._global_property_aliases()
+
+    assert loader.prepare_query(
+        "MATCH (t1:characters)-[hero:HERO]->(t2:`voice-actors`) "
+        "WHERE t2.movie = t1.movie_title AND t2.movie <> 'voice-actor' "
+        "RETURN t2.`voice-actor`"
+    ) == (
+        "MATCH (t1:characters)-[hero:HERO]->(t2:voice_actors) "
+        "WHERE t2.movie = t1.movie_title AND t2.movie <> 'voice-actor' "
+        "RETURN t2.voice_actor"
+    )
+
+
+def test_neo4j_compare_rewrites_identity_and_adjacent_edge_properties():
+    """Check three prepare_query rewrites against a small USER/REPORT/Approves schema.
+
+    * ``n.identity`` and ``n.id`` on a vertex become the vertex's primary key.
+    * ``r.identity`` on an edge becomes ``r.EDGE_ID``.
+    * ``approver.approval_date``, a property declared only on the Approves edge, is
+      redirected to the edge variable ``r``.
+    """
+    # __new__ bypasses __init__, so every attribute prepare_query needs is set by hand.
+    loader = DatasetNeo4jLoader.__new__(DatasetNeo4jLoader)
+    config = {
+        "schema": [
+            {
+                "label": "PaymentTransaction",
+                "type": "VERTEX",
+                "primary": "transaction_id",
+                "properties": [{"name": "transaction_id", "type": "STRING"}],
+            },
+            {
+                "label": "USER",
+                "type": "VERTEX",
+                "primary": "user_id",
+                "properties": [{"name": "user_id", "type": "STRING"}],
+            },
+            {
+                "label": "REPORT",
+                "type": "VERTEX",
+                "primary": "report_id",
+                "properties": [{"name": "report_id", "type": "STRING"}],
+            },
+            {
+                "label": "Approves",
+                "type": "EDGE",
+                "constraints": [["USER", "REPORT"]],
+                "properties": [
+                    {"name": "EDGE_ID", "type": "INT64"},
+                    {"name": "approval_date", "type": "TIMESTAMP"},
+                ],
+            },
+        ]
+    }
+    loader.cypher_schema = CypherSchema(config)
+    loader.vertex_labels = {"PaymentTransaction", "USER", "REPORT"}
+    loader.edge_labels = {"Approves"}
+    # REPORT has no entry here even though the schema declares a primary key for it.
+    loader.primary_by_label = {"PaymentTransaction": "transaction_id", "USER": "user_id"}
+    loader.property_types_by_label = loader.cypher_schema.property_types_by_label
+    # Alias tables are derived from the label/property sets assigned above.
+    loader.node_label_aliases = loader._schema_name_aliases(loader.vertex_labels)
+    loader.edge_type_aliases = loader._schema_name_aliases(loader.edge_labels)
+    loader.property_aliases_by_label = {
+        label: loader._schema_name_aliases(properties)
+        for label, properties in loader.property_types_by_label.items()
+    }
+    # NOTE(review): assigned last -- presumably it reads the per-label aliases; confirm.
+    loader.global_property_aliases = loader._global_property_aliases()
+
+    assert (
+        loader.prepare_query("MATCH (n:PaymentTransaction) RETURN count(n.identity), count(n.id)")
+        == "MATCH (n:PaymentTransaction) RETURN count(n.transaction_id), "
+        "count(n.transaction_id)"
+    )
+    assert (
+        loader.prepare_query("MATCH (u:USER)-[r:Approves]->(report:REPORT) RETURN r.identity")
+        == "MATCH (u:USER)-[r:Approves]->(report:REPORT) RETURN r.EDGE_ID"
+    )
+    assert (
+        loader.prepare_query(
+            "MATCH (approver:USER)-[r:Approves]->(report:REPORT) RETURN approver.approval_date"
+        )
+        == "MATCH (approver:USER)-[r:Approves]->(report:REPORT) RETURN r.approval_date"
+    )
+
+
+def test_neo4j_compare_preserves_real_id_property_over_pseudo_identity():
+    """A declared ``id`` property is kept as-is; an undeclared one maps to the primary key.
+
+    Question declares a real ``id`` column, so ``q.id`` is expected to stay unchanged.
+    PaymentTransaction declares none, so ``n.id`` and ``n.identity`` are expected to
+    become ``n.transaction_id``.
+    """
+    # __new__ bypasses __init__, so every attribute prepare_query needs is set by hand.
+    loader = DatasetNeo4jLoader.__new__(DatasetNeo4jLoader)
+    config = {
+        "schema": [
+            {
+                "label": "Question",
+                "type": "VERTEX",
+                "primary": "vid",
+                "properties": [
+                    {"name": "vid", "type": "STRING"},
+                    {"name": "id", "type": "INT64"},
+                    {"name": "title", "type": "STRING"},
+                ],
+            },
+            {
+                "label": "PaymentTransaction",
+                "type": "VERTEX",
+                "primary": "transaction_id",
+                "properties": [{"name": "transaction_id", "type": "STRING"}],
+            },
+        ]
+    }
+    loader.cypher_schema = CypherSchema(config)
+    loader.vertex_labels = {"Question", "PaymentTransaction"}
+    loader.edge_labels = set()
+    loader.primary_by_label = {
+        "Question": "vid",
+        "PaymentTransaction": "transaction_id",
+    }
+    loader.property_types_by_label = loader.cypher_schema.property_types_by_label
+    # Alias tables are derived from the label/property sets assigned above.
+    loader.node_label_aliases = loader._schema_name_aliases(loader.vertex_labels)
+    loader.edge_type_aliases = {}
+    loader.property_aliases_by_label = {
+        label: loader._schema_name_aliases(properties)
+        for label, properties in loader.property_types_by_label.items()
+    }
+    # NOTE(review): assigned last -- presumably it reads the per-label aliases; confirm.
+    loader.global_property_aliases = loader._global_property_aliases()
+
+    assert (
+        loader.prepare_query("MATCH (q:Question) RETURN q.id, q.title")
+        == "MATCH (q:Question) RETURN q.id, q.title"
+    )
+    assert (
+        loader.prepare_query("MATCH (n:PaymentTransaction) RETURN n.id, n.identity")
+        == "MATCH (n:PaymentTransaction) RETURN n.transaction_id, n.transaction_id"
+    )
+
+
+def test_loader_uses_vertex_file_when_edge_label_collides():
+    """With one label shared by a vertex file and an edge file, only the vertex file loads."""
+    vertex_file = {"label": "zip_data", "path": "zip_data.csv", "columns": ["_id"]}
+    edge_file = {
+        "label": "zip_data",
+        "path": "statezip_dataCBSA.csv",
+        "SRC_ID": "state",
+        "DST_ID": "CBSA",
+        "columns": ["SRC_ID", "DST_ID"],
+    }
+
+    loader = DatasetOracleLoader.__new__(DatasetOracleLoader)
+    loader.config = {"files": [vertex_file, edge_file]}
+    # The manifest declares the label as a vertex only; there are no edges.
+    loader.manifest = {
+        "vertices": [{"label": "zip_data", "table": "zip_data", "columns": []}],
+        "edges": [],
+    }
+    load_calls = []
+
+    def fake_load_file(item, file_item, is_edge=False):
+        # Record what would have been loaded instead of touching a database.
+        load_calls.append((item["label"], file_item["path"], is_edge))
+        return 1
+
+    loader._load_file = fake_load_file
+
+    loaded_counts = loader._load_csv_files()
+
+    assert loaded_counts == {"zip_data": 1}
+    assert load_calls == [("zip_data", "zip_data.csv", False)]
+
+
+def test_loader_does_not_reuse_edge_file_for_different_constraint():
+    """An edge file for DEVICE -> ALERT is not loaded into the SENSOR -> ALERT edge table."""
+    device_edge_file = {
+        "label": "GENERATES",
+        "path": "GENERATES_Device.csv",
+        "SRC_ID": "DEVICE",
+        "DST_ID": "ALERT",
+        "columns": ["SRC_ID", "DST_ID"],
+    }
+    # Two edge tables share the GENERATES label and differ only in their source vertex.
+    edge_tables = [
+        {
+            "label": "GENERATES",
+            "src": "DEVICE",
+            "dst": "ALERT",
+            "table": "DEVICE_GENERATES_ALERT",
+            "columns": [],
+        },
+        {
+            "label": "GENERATES",
+            "src": "SENSOR",
+            "dst": "ALERT",
+            "table": "SENSOR_GENERATES_ALERT",
+            "columns": [],
+        },
+    ]
+
+    loader = DatasetOracleLoader.__new__(DatasetOracleLoader)
+    loader.config = {"files": [device_edge_file]}
+    loader.manifest = {"vertices": [], "edges": edge_tables}
+    load_calls = []
+
+    def fake_load_file(item, file_item, is_edge=False):
+        # Record what would have been loaded instead of touching a database.
+        load_calls.append((item["table"], file_item["path"], is_edge))
+        return 1
+
+    loader._load_file = fake_load_file
+
+    loaded_counts = loader._load_csv_files()
+
+    assert loaded_counts == {"DEVICE_GENERATES_ALERT": 1}
+    assert load_calls == [("DEVICE_GENERATES_ALERT", "GENERATES_Device.csv", True)]
+
+
+def test_loader_exposes_oracle_graph_label_map_for_collisions():
+    """A label used by one vertex and two edges maps to each element's own graph label."""
+    vertices = [{"label": "book", "graph_label": "book"}]
+    edges = [
+        {"label": "book", "graph_label": "book_language_book_publisher"},
+        {"label": "book", "graph_label": "book_author_book"},
+    ]
+    loader = DatasetOracleLoader.__new__(DatasetOracleLoader)
+    loader.manifest = {"vertices": vertices, "edges": edges}
+
+    node_labels = loader.node_label_map()
+    assert node_labels == {"book": ["book"]}
+
+    edge_labels = loader.edge_label_map()
+    assert edge_labels == {"book": ["book_language_book_publisher", "book_author_book"]}
+
+
+def test_loader_exposes_file_stem_label_aliases():
+    """The CSV file stem ("Source") is exposed as an extra alias for the vertex label."""
+    source_file = {"label": "InfoSource", "path": "Source.csv", "columns": ["infosource_id"]}
+    loader = DatasetOracleLoader.__new__(DatasetOracleLoader)
+    loader.config = {"files": [source_file]}
+    loader.manifest = {
+        "vertices": [{"label": "InfoSource", "graph_label": "InfoSource"}],
+        "edges": [],
+    }
+
+    node_labels = loader.node_label_map()
+
+    expected_labels = {
+        "InfoSource": ["InfoSource"],
+        "Source": ["InfoSource"],
+    }
+    assert node_labels == expected_labels
+
+
+def test_dataset_loader_manifest_is_tolerant_for_benchmark_data(tmp_path: Path):
+    """Check the manifest built from a config whose ``score`` type differs per label.
+
+    The generated DDL must not contain FOREIGN KEY or ENFORCED MODE, and both
+    ``score`` columns are expected to be typed VARCHAR2(4000).
+    """
+    vertex_a = {
+        "label": "A",
+        "type": "VERTEX",
+        "primary": "id",
+        "properties": [
+            {"name": "id", "type": "INT64"},
+            {"name": "score", "type": "INT64"},
+        ],
+    }
+    vertex_b = {
+        "label": "B",
+        "type": "VERTEX",
+        "primary": "id",
+        "properties": [
+            {"name": "id", "type": "INT64"},
+            {"name": "score", "type": "STRING"},
+        ],
+    }
+    edge_rel = {
+        "label": "REL",
+        "type": "EDGE",
+        "constraints": [["A", "B"]],
+        "properties": [],
+    }
+    config = {"schema": [vertex_a, vertex_b, edge_rel], "files": []}
+    config_path = tmp_path / "import_config.json"
+    config_path.write_text(json.dumps(config), encoding="utf-8")
+
+    loader = DatasetOracleLoader.__new__(DatasetOracleLoader)
+    loader.import_config_path = config_path
+    loader.graph_name = "G"
+    manifest = loader._build_manifest()
+
+    assert "FOREIGN KEY" not in manifest["table_ddl"]
+    assert "ENFORCED MODE" not in manifest["property_graph_ddl"]
+
+    score_types = []
+    for vertex in manifest["vertices"]:
+        for column in vertex["columns"]:
+            if column["name"] == "score":
+                score_types.append(column["type"])
+    assert score_types == ["VARCHAR2(4000)", "VARCHAR2(4000)"]
diff --git a/test/test_oracle_sqlpgq.py b/test/test_oracle_sqlpgq.py
new file mode 100644
index 0000000..1076c2d
--- /dev/null
+++ b/test/test_oracle_sqlpgq.py
@@ -0,0 +1,153 @@
+from pathlib import Path
+
+from app.core.clauses.match_clause import EdgePattern, MatchClause, NodePattern, PathPattern
+from app.core.clauses.return_clause import ReturnBody, ReturnClause, ReturnItem, SortItem
+from app.core.clauses.where_clause import CompareExpression, WhereClause
+from app.core.schema.edge import Edge
+from app.core.schema.node import Node
+from app.core.schema.schema_graph import SchemaGraph
+from app.core.validator.db_client import DB_Client, QueryResult, QueryStatus
+from app.core.validator.validator import CorpusValidator
+from app.impl.oracle_sqlpgq.ast_visitor.oracle_sqlpgq_ast_visitor import OracleSqlPgqAstVisitor
+from app.impl.oracle_sqlpgq.schema.schema_parser import OracleSqlPgqSchemaParser
+from app.impl.oracle_sqlpgq.translator.oracle_sqlpgq_query_translator import (
+ OracleSqlPgqQueryTranslator,
+)
+
+
+def _schema_graph() -> SchemaGraph:
+    """Build the movie schema fixture: PERSON and MOVIE nodes joined by ACTED_IN."""
+    schema = SchemaGraph("movie_graph")
+    person = Node(
+        label="PERSON",
+        primary="PERSON_id",
+        properties=[
+            {"name": "PERSON_id", "type": "INT64"},
+            {"name": "name", "type": "STRING"},
+        ],
+    )
+    movie = Node(
+        label="MOVIE",
+        primary="MOVIE_id",
+        properties=[
+            {"name": "MOVIE_id", "type": "INT64"},
+            {"name": "title", "type": "STRING"},
+        ],
+    )
+    acted_in = Edge(
+        label="ACTED_IN",
+        properties=[{"name": "role", "type": "STRING"}],
+        src_dst_list=[["PERSON", "MOVIE"]],
+    )
+
+    # Register both nodes first, then the edge that connects their labels.
+    schema.add_node(person)
+    schema.add_node(movie)
+    schema.add_edge(acted_in)
+    return schema
+
+
+def _query_pattern():
+    """Return MATCH / WHERE / RETURN clauses for a PERSON -[ACTED_IN]-> MOVIE query."""
+    person = NodePattern("p", "PERSON", [])
+    movie = NodePattern("m", "MOVIE", [])
+    acted_in = EdgePattern("a", "ACTED_IN", [], "right")
+    path = PathPattern(node_pattern_list=[person, movie], edge_pattern_list=[acted_in])
+    name_filter = CompareExpression("p", "name", "equal", "'Tom Hanks'")
+    return_body = ReturnBody(
+        return_item_list=[ReturnItem("m", "title", "movie_title")],
+        sort_item_list=[SortItem("m", "title", "ASC")],
+        skip=5,
+        limit=10,
+    )
+
+    # Clause order mirrors the query text: MATCH, then WHERE, then RETURN.
+    return [
+        MatchClause(path),
+        WhereClause(name_filter),
+        ReturnClause(return_body),
+    ]
+
+
+def test_oracle_schema_parser_generates_table_and_graph_ddl(tmp_path):
+    """Saving a schema writes a manifest plus table and property-graph DDL files."""
+    parser = OracleSqlPgqSchemaParser(db_id="movie_graph")
+    manifest_path = parser.save_schema_to_file(tmp_path, _schema_graph(), "movie", "oracle")
+    # Decode explicitly so the result does not depend on the platform locale encoding.
+    table_ddl = (tmp_path / "movie_oracle_oracle_tables.sql").read_text(encoding="utf-8")
+    graph_ddl = (tmp_path / "movie_oracle_oracle_property_graph.sql").read_text(encoding="utf-8")
+
+    assert Path(manifest_path).exists()
+    assert 'CREATE TABLE "PERSON"' in table_ddl
+    assert 'CREATE TABLE "PERSON_ACTED_IN_MOVIE"' in table_ddl
+    assert 'CREATE OR REPLACE PROPERTY GRAPH "movie_oracle"' in graph_ddl
+    assert "VERTEX TABLES" in graph_ddl
+    assert "EDGE TABLES" in graph_ddl
+    assert "ENFORCED MODE" in graph_ddl
+
+
+def test_oracle_sqlpgq_translator_renders_graph_table_query():
+    """The clause fixture renders to a GRAPH_TABLE query that passes grammar_check."""
+    renderer = OracleSqlPgqQueryTranslator(graph_name="MOVIE_GRAPH")
+    sql = renderer.translate(_query_pattern())
+
+    assert sql.startswith("SELECT *\nFROM GRAPH_TABLE")
+    assert '"MOVIE_GRAPH"' in sql
+    assert 'MATCH (p IS "PERSON")-[a IS "ACTED_IN"]->(m IS "MOVIE")' in sql
+    assert 'WHERE p."name" = \'Tom Hanks\'' in sql
+    assert 'COLUMNS (m."title" AS movie_title)' in sql
+    assert "ORDER BY movie_title ASC" in sql
+    assert "OFFSET 5 ROWS" in sql
+    assert "FETCH FIRST 10 ROWS ONLY" in sql
+    assert renderer.grammar_check(sql)
+
+
+def test_oracle_sqlpgq_translator_projects_element_ids_when_returning_whole_nodes():
+    """Returning variable ``p`` alone is rendered as ``VERTEX_ID(p) AS p_VALUE``."""
+    match_clause = MatchClause(PathPattern([NodePattern("p", "PERSON", [])], []))
+    # ReturnItem("p", "", "") names the variable only; its 2nd and 3rd arguments are blank.
+    return_clause = ReturnClause(ReturnBody([ReturnItem("p", "", "")], []))
+    renderer = OracleSqlPgqQueryTranslator(graph_name="MOVIE_GRAPH")
+
+    sql = renderer.translate([match_clause, return_clause])
+    assert "VERTEX_ID(p) AS p_VALUE" in sql
+
+
+def test_oracle_sqlpgq_ast_visitor_parses_translator_subset():
+    """The visitor parses translator output into MATCH, WHERE and RETURN clauses."""
+    sql = OracleSqlPgqQueryTranslator(graph_name="MOVIE_GRAPH").translate(_query_pattern())
+    ok, parsed = OracleSqlPgqAstVisitor().get_query_pattern(sql)
+
+    assert ok
+    assert isinstance(parsed[0], MatchClause)
+    assert isinstance(parsed[1], WhereClause)
+    assert isinstance(parsed[2], ReturnClause)
+    body = parsed[2].return_body
+    first_sort = body.sort_item_list[0]
+    assert (first_sort.symbolic_name, first_sort.order) == ("movie_title", "ASC")
+    assert (body.skip, body.limit) == (5, 10)
+
+
+class FakeDBClient(DB_Client):  # stub: NO_RECORD when the query contains "empty"
+    client = object()
+
+    def create_client(self, db_client_params: dict):
+        return self.client
+
+    def execute_query(self, query: str) -> QueryResult:
+        if "empty" not in query:
+            return QueryResult(QueryStatus.SUCCESS, data=[{"ok": 1}])
+        return QueryResult(QueryStatus.NO_RECORD, data=[])
+
+
+def test_corpus_validator_accepts_oracle_backend_with_injected_client():
+    """Only the pair whose query returns rows is kept, with its result attached."""
+    with_rows = {"question": "ok?", "query": "select ok"}
+    without_rows = {"question": "empty?", "query": "select empty"}
+    validator = CorpusValidator(backend="oracle_sqlpgq", db_client=FakeDBClient())
+
+    kept = validator.execute_with_results([with_rows, without_rows])
+
+    # Built as a separate dict so the expectation is independent of the inputs.
+    expected = {"question": "ok?", "query": "select ok", "result": "[{'ok': 1}]"}
+    assert kept == [expected]
diff --git a/test/test_oracle_sqlpgq_corpus_combiner.py b/test/test_oracle_sqlpgq_corpus_combiner.py
new file mode 100644
index 0000000..f0e883d
--- /dev/null
+++ b/test/test_oracle_sqlpgq_corpus_combiner.py
@@ -0,0 +1,79 @@
+import json
+
+from app.core.validator.db_client import DB_Client, QueryResult, QueryStatus
+from app.impl.oracle_sqlpgq.generator.corpus_combiner import OracleSqlPgqCorpusCombiner
+from app.core.validator.validator import CorpusValidator
+
+
+def test_oracle_corpus_combiner_normalizes_and_deduplicates(tmp_path):
+    """Two input files that share one pair combine into two normalized records."""
+    show_query = (
+        'SELECT * FROM GRAPH_TABLE ("G" MATCH (m IS "MOVIE") '
+        'COLUMNS (m."title" AS title)) gt'
+    )
+    count_query = (
+        'SELECT gt.title, COUNT(*) AS c FROM GRAPH_TABLE ("G" MATCH (m IS "MOVIE") '
+        'COLUMNS (m."title" AS title)) gt GROUP BY gt.title'
+    )
+
+    # The first file carries category/labels metadata; the second repeats the pair without it.
+    template_records = [
+        {
+            "question": "Show movies.",
+            "query": show_query,
+            "category": "one_hop",
+            "labels": ["MOVIE"],
+        }
+    ]
+    raw_records = [
+        {"question": "Show movies.", "query": show_query},
+        {"question": "Count movies.", "query": count_query},
+    ]
+
+    first = tmp_path / "template.json"
+    second = tmp_path / "raw.json"
+    first.write_text(json.dumps(template_records), encoding="utf-8")
+    second.write_text(json.dumps(raw_records), encoding="utf-8")
+
+    records = OracleSqlPgqCorpusCombiner().combine_files([first, second], split=True)
+
+    assert len(records) == 2
+    assert records[0]["id"] == "oracle_sqlpgq_000001"
+    assert records[0]["backend"] == "oracle_sqlpgq"
+    # The COUNT/GROUP BY input had no category; the combined record must have one.
+    assert records[1]["category"] == "aggregation"
+    splits = {record["split"] for record in records}
+    assert splits <= {"train", "dev", "test"}
+
+
+class FakeOracleClient(DB_Client):  # stub: CLIENT_ERROR when the query contains "BROKEN"
+    client = object()
+
+    def create_client(self, db_client_params: dict):
+        return self.client
+
+    def execute_query(self, query: str) -> QueryResult:
+        if "BROKEN" not in query:
+            return QueryResult(QueryStatus.SUCCESS, data=[{"ok": 1}])
+        return QueryResult(QueryStatus.CLIENT_ERROR, error="broken query")
+
+
+def test_oracle_corpus_combiner_can_live_validate_records(tmp_path):
+    """Each combined record carries the validator's pass/fail outcome."""
+    good_query = (
+        'SELECT * FROM GRAPH_TABLE ("G" MATCH (n IS "MOVIE") '
+        'COLUMNS (n."title" AS title)) gt'
+    )
+    pairs = [
+        {"question": "ok?", "query": good_query},
+        {"question": "bad?", "query": "BROKEN"},
+    ]
+    source = tmp_path / "records.json"
+    source.write_text(json.dumps(pairs), encoding="utf-8")
+    validator = CorpusValidator(backend="oracle_sqlpgq", db_client=FakeOracleClient())
+
+    records = OracleSqlPgqCorpusCombiner().combine_files([source], validator=validator)
+
+    passed, failed = records[0], records[1]
+    assert (passed["validation"], passed["result"]) == ("passed", "[{'ok': 1}]")
+    assert (failed["validation"], failed["validation_error"]) == ("failed", "broken query")
diff --git a/test/test_oracle_sqlpgq_live.py b/test/test_oracle_sqlpgq_live.py
new file mode 100644
index 0000000..ae5db4f
--- /dev/null
+++ b/test/test_oracle_sqlpgq_live.py
@@ -0,0 +1,31 @@
+import os
+
+import pytest
+
+from app.core.validator.db_client import QueryStatus
+from app.impl.oracle_sqlpgq.db_client.oracle_db_client import OracleDBClient, oracledb
+
+
+pytestmark = pytest.mark.oracle  # module-level mark: applies "oracle" to every test in this file
+
+
+@pytest.mark.skipif(oracledb is None, reason="oracledb package is not installed")
+@pytest.mark.skipif(
+    any(not os.getenv(name) for name in ("ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD")),
+    reason="ORACLE_DSN, ORACLE_USER, and ORACLE_PASSWORD are required",
+)
+def test_oracle_db_client_live_smoke():
+    """Run ``SELECT 1`` against the Oracle instance named by the ORACLE_* variables."""
+    connection_params = {
+        "dsn": os.environ["ORACLE_DSN"],
+        "user": os.environ["ORACLE_USER"],
+        "password": os.environ["ORACLE_PASSWORD"],
+    }
+    client = OracleDBClient(connection_params)
+    try:
+        outcome = client.execute_query("SELECT 1 AS VALUE FROM dual")
+        assert outcome.status_code == QueryStatus.SUCCESS
+        assert outcome.data == [{"VALUE": 1}]
+    finally:
+        client.close()
+
diff --git a/test/test_oracle_sqlpgq_query_generalizer.py b/test/test_oracle_sqlpgq_query_generalizer.py
new file mode 100644
index 0000000..c59711f
--- /dev/null
+++ b/test/test_oracle_sqlpgq_query_generalizer.py
@@ -0,0 +1,79 @@
+from app.core.clauses.match_clause import EdgePattern, MatchClause, NodePattern, PathPattern
+from app.impl.oracle_sqlpgq.generator.query_generalizer import OracleSqlPgqQueryGeneralizer
+from app.impl.oracle_sqlpgq.utils.sqlpgq import validate_graph_table_query
+
+
+def _manifest():
+    """Return the MOVIE_GRAPH manifest fixture: three vertex labels, three edges."""
+    id_type, text_type = "NUMBER(19)", "VARCHAR2(4000)"
+    person = {
+        "label": "PERSON",
+        "columns": [
+            {"name": "PERSON_id", "type": id_type},
+            {"name": "name", "type": text_type},
+        ],
+    }
+    movie = {
+        "label": "MOVIE",
+        "columns": [
+            {"name": "MOVIE_id", "type": id_type},
+            {"name": "title", "type": text_type},
+        ],
+    }
+    genre = {
+        "label": "GENRE",
+        "columns": [
+            {"name": "GENRE_id", "type": id_type},
+            {"name": "name", "type": text_type},
+        ],
+    }
+
+    edges = [
+        {"label": "ACTED_IN", "src": "PERSON", "dst": "MOVIE"},
+        {"label": "BELONGS_TO", "src": "MOVIE", "dst": "GENRE"},
+        {"label": "SIMILAR_TO", "src": "MOVIE", "dst": "MOVIE"},
+    ]
+
+    return {"graph_name": "MOVIE_GRAPH", "vertices": [person, movie, genre], "edges": edges}
+
+
+def _one_hop_pattern():
+    """Return one MATCH clause for the seed path ``a -[e]-> b`` with blank labels."""
+    # Every element is created with an empty label string.
+    start = NodePattern("a", "", [])
+    end = NodePattern("b", "", [])
+    hop = EdgePattern("e", "", [], "right")
+
+    path = PathPattern([start, end], [hop])
+    return [MatchClause(path)]
+
+
+def _two_hop_pattern():
+    """Return one MATCH clause for ``a -[e1]-> b -[e2]-> c`` with blank labels."""
+    first = NodePattern("a", "", [])
+    middle = NodePattern("b", "", [])
+    last = NodePattern("c", "", [])
+    hops = [
+        EdgePattern("e1", "", [], "right"),
+        EdgePattern("e2", "", [], "right"),
+    ]
+
+    path = PathPattern([first, middle, last], hops)
+    return [MatchClause(path)]
+
+
+def test_oracle_query_generalizer_emits_oracle_queries_for_seed_path_length():
+    """A one-hop seed yields three valid queries, one matching PERSON-ACTED_IN->MOVIE."""
+    generated = OracleSqlPgqQueryGeneralizer(_manifest()).generalize(_one_hop_pattern())
+    expected_match = 'MATCH (n1 IS "PERSON")-[e1 IS "ACTED_IN"]->(n2 IS "MOVIE")'
+    assert len(generated) == 3 and all(validate_graph_table_query(g.query) for g in generated)
+    assert any(expected_match in g.query for g in generated)
+
+
+def test_oracle_query_generalizer_supports_two_hop_shapes():
+    """A two-hop seed yields length-2 results, including PERSON -> MOVIE -> GENRE."""
+    generated = OracleSqlPgqQueryGeneralizer(_manifest()).generalize(_two_hop_pattern())
+    person_to_genre = ["PERSON", "ACTED_IN", "MOVIE", "BELONGS_TO", "GENRE"]
+    assert generated and all(g.source_pattern_length == 2 for g in generated)
+    assert any(g.labels == person_to_genre for g in generated)
+
diff --git a/test/test_oracle_sqlpgq_template_instantiator.py b/test/test_oracle_sqlpgq_template_instantiator.py
new file mode 100644
index 0000000..421984a
--- /dev/null
+++ b/test/test_oracle_sqlpgq_template_instantiator.py
@@ -0,0 +1,78 @@
+import json
+
+from app.impl.oracle_sqlpgq.generator.template_instantiator import (
+ OracleSqlPgqTemplateInstantiator,
+)
+from app.impl.oracle_sqlpgq.utils.sqlpgq import validate_graph_table_query
+
+
+def _manifest():
+    """Return the MOVIE_GRAPH manifest fixture: two vertex labels and two edges."""
+    person = {
+        "label": "PERSON",
+        "columns": [
+            {"name": "PERSON_id", "type": "NUMBER(19)"},
+            {"name": "name", "type": "VARCHAR2(4000)"},
+        ],
+    }
+    movie = {
+        "label": "MOVIE",
+        "columns": [
+            {"name": "MOVIE_id", "type": "NUMBER(19)"},
+            {"name": "title", "type": "VARCHAR2(4000)"},
+        ],
+    }
+
+    def edge_columns(property_name, property_type):
+        # Fresh key-column dicts per call, followed by the edge's single property column.
+        return [
+            {"name": "EDGE_ID", "type": "NUMBER"},
+            {"name": "SRC_ID", "type": "NUMBER"},
+            {"name": "DST_ID", "type": "NUMBER"},
+            {"name": property_name, "type": property_type},
+        ]
+
+    acted_in = {
+        "label": "ACTED_IN",
+        "src": "PERSON",
+        "dst": "MOVIE",
+        "columns": edge_columns("role", "VARCHAR2(4000)"),
+    }
+    similar_to = {
+        "label": "SIMILAR_TO",
+        "src": "MOVIE",
+        "dst": "MOVIE",
+        "columns": edge_columns("score", "BINARY_FLOAT"),
+    }
+
+    return {
+        "backend": "oracle_sqlpgq",
+        "graph_name": "MOVIE_GRAPH",
+        "vertices": [person, movie],
+        "edges": [acted_in, similar_to],
+    }
+
+
+def test_template_instantiator_generates_valid_graph_table_queries():
+    """Generated pairs are valid and include the expected categories and aliases."""
+    pairs = OracleSqlPgqTemplateInstantiator(_manifest()).generate()
+
+    assert pairs
+    assert all(validate_graph_table_query(pair.query) for pair in pairs)
+    for category in ("one_hop_traversal", "aggregation", "bounded_path"):
+        assert any(pair.category == category for pair in pairs)
+    for alias in ("source_MOVIE_title", "target_MOVIE_title"):
+        assert any(alias in pair.query for pair in pairs)
+
+
+def test_template_instantiator_can_load_manifest_file(tmp_path):
+    """A manifest loaded from disk yields bare question/query dicts on request."""
+    manifest_path = tmp_path / "oracle_schema.json"
+    manifest_path.write_text(json.dumps(_manifest()), encoding="utf-8")
+    instantiator = OracleSqlPgqTemplateInstantiator.from_file(manifest_path)
+
+    pairs = instantiator.generate_dicts(target_size=2, include_metadata=False)
+
+    # Exactly the requested number of pairs, each holding only question and query.
+    assert len(pairs) == 2
+    assert set(pairs[0]) == {"question", "query"}