Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions src/crate/client/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,27 @@
import typing as t
import warnings
from datetime import datetime, timedelta, timezone
from itertools import count

from .converter import Converter, DataType
from .exceptions import ProgrammingError

_NAMED_PARAM_RE = re.compile(r"%\(([^)]+)\)s")


def _rewrite_pyformat_sql(sql: str) -> str:
"""Replace %(name)s placeholders with $N positional markers (1-indexed)."""
counter = count(1)
return _NAMED_PARAM_RE.sub(lambda _: f"${next(counter)}", sql)


def _convert_named_to_positional(
sql: str, params: t.Dict[str, t.Any]
) -> t.Tuple[str, t.List[t.Any]]:
"""Convert pyformat-style named parameters to positional qmark parameters.
"""Convert pyformat-style named parameters to positional parameters.

Converts ``%(name)s`` placeholders to ``?`` and returns an ordered list
of corresponding values extracted from ``params``.
Converts ``%(name)s`` placeholders to ``$N`` (1-indexed) and returns an
ordered list of corresponding values extracted from ``params``.

The same name may appear multiple times; each occurrence appends the
value to the positional list independently.
Expand All @@ -47,7 +54,7 @@ def _convert_named_to_positional(

sql = "SELECT * FROM t WHERE a = %(a)s AND b = %(b)s"
params = {"a": 1, "b": 2}
# returns: ("SELECT * FROM t WHERE a = ? AND b = ?", [1, 2])
# returns: ("SELECT * FROM t WHERE a = $1 AND b = $2", [1, 2])
"""
positions = {}
idx = 1
Expand Down Expand Up @@ -91,8 +98,8 @@ def _convert_named_bulk_params(
for row in seq_of_dicts:
if not isinstance(row, dict):
raise ProgrammingError(
"executemany() requires all parameter rows to be dicts "
"when the SQL uses pyformat (%(name)s) placeholders"
"All bulk parameter rows must be dicts when SQL uses "
"pyformat (%(name)s) placeholders; got a non-dict row"
)
positional: t.List[t.Any] = [None] * n
for name, pos in positions.items():
Expand Down Expand Up @@ -136,6 +143,13 @@ def execute(self, sql, parameters=None, bulk_parameters=None):

if isinstance(parameters, dict):
sql, parameters = _convert_named_to_positional(sql, parameters)
elif bulk_parameters is not None and _NAMED_PARAM_RE.search(sql):
if bulk_parameters and isinstance(bulk_parameters[0], dict):
sql, bulk_parameters = _convert_named_bulk_params(
sql, bulk_parameters
)
else:
sql = _rewrite_pyformat_sql(sql)

self._result = self.connection.client.sql(
sql, parameters, bulk_parameters
Expand Down
66 changes: 65 additions & 1 deletion tests/client/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def test_executemany_with_mixed_param_types(mocked_connection):
parameter sequence mixes dicts and non-dicts while the SQL uses pyformat.
"""
cursor = mocked_connection.cursor()
with pytest.raises(ProgrammingError, match="requires all parameter rows"):
with pytest.raises(
ProgrammingError, match="All bulk parameter rows must be dicts"
):
cursor.executemany(
"INSERT INTO characters (name) VALUES (%(name)s)",
[{"name": "Arthur"}, ["Trillian"]], # second row is a list
Expand Down Expand Up @@ -329,6 +331,68 @@ def test_execute_with_bulk_args(mocked_connection):
mocked_connection.client.sql.assert_called_once_with(statement, None, [[1]])


def test_execute_with_pyformat_sql_and_bulk_parameters(mocked_connection):
"""
cursor.execute() converts %(name)s SQL to $N when bulk_parameters is
provided. Rows are already positional; only the SQL needs conversion.
"""
cursor = mocked_connection.cursor()
sql = "INSERT INTO t (id, val) VALUES (%(id)s, %(val)s)"
bulk = [[1, "hello"], [2, "world"]]
cursor.execute(sql, bulk_parameters=bulk)
mocked_connection.client.sql.assert_called_once_with(
"INSERT INTO t (id, val) VALUES ($1, $2)", None, bulk
)


def test_execute_with_pyformat_sql_and_dict_bulk_parameters(mocked_connection):
"""
cursor.execute() with pyformat SQL and dict-format bulk_parameters converts
both the SQL template (%(x)s → $N) and the rows (dicts → positional lists).
"""
cursor = mocked_connection.cursor()
sql = "INSERT INTO t (id, val) VALUES (%(id)s, %(val)s)"
bulk = [{"id": 1, "val": "hello"}, {"id": 2, "val": "world"}]
cursor.execute(sql, bulk_parameters=bulk)
mocked_connection.client.sql.assert_called_once_with(
"INSERT INTO t (id, val) VALUES ($1, $2)",
None,
[[1, "hello"], [2, "world"]],
)


def test_execute_with_dict_bulk_parameters_mixed_types_raises(
mocked_connection,
):
"""
cursor.execute() raises ProgrammingError when bulk_parameters mixes
dict and non-dict rows with pyformat SQL.
"""
cursor = mocked_connection.cursor()
with pytest.raises(
ProgrammingError, match="All bulk parameter rows must be dicts"
):
cursor.execute(
"INSERT INTO t (id) VALUES (%(id)s)",
bulk_parameters=[{"id": 1}, [2]],
)
mocked_connection.client.sql.assert_not_called()


def test_execute_with_pyformat_sql_and_bulk_parameters_no_placeholders(
mocked_connection,
):
"""
SQL without %(name)s placeholders is passed through unchanged
even when bulk_parameters is provided.
"""
cursor = mocked_connection.cursor()
sql = "INSERT INTO t (id, val) VALUES (?, ?)"
bulk = [[1, "hello"], [2, "world"]]
cursor.execute(sql, bulk_parameters=bulk)
mocked_connection.client.sql.assert_called_once_with(sql, None, bulk)


def test_execute_custom_converter(mocked_connection):
"""
Verify that a custom converter is correctly applied when passed to a cursor.
Expand Down
Loading