Skip to content

Commit e2eca93

Browse files
committed
fix: EXPLAIN ANALYZE VERBOSE handling, string literal paren bypass, commit logic for EXPLAIN ANALYZE
- EXPLAIN handler now consumes all known options (ANALYZE, ANALYSE, VERBOSE) before extracting the real command, fixing 'EXPLAIN ANALYZE VERBOSE SELECT' being blocked - Paren walker in _extract_main_query_after_cte now skips string literals, preventing 'WITH cte AS (SELECT '\''('\'' FROM t) DELETE FROM users' from bypassing detection - _is_write_stmt in execute_sql now resolves EXPLAIN ANALYZE to underlying command via _resolve_explain_command, ensuring session.commit() fires for write operations - 10 new tests covering all three fixes
1 parent 37e7a22 commit e2eca93

2 files changed

Lines changed: 163 additions & 17 deletions

File tree

lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -94,35 +94,95 @@ def _detect_writable_cte(stmt: str) -> str | None:
9494
return None
9595

9696

97+
def _skip_string_literal(stmt: str, pos: int) -> int:
98+
"""Skip past a string literal starting at pos (single-quoted).
99+
100+
Handles escaped quotes ('') inside the literal.
101+
Returns the index after the closing quote.
102+
"""
103+
quote_char = stmt[pos]
104+
i = pos + 1
105+
while i < len(stmt):
106+
if stmt[i] == quote_char:
107+
# Check for escaped quote ('')
108+
if i + 1 < len(stmt) and stmt[i + 1] == quote_char:
109+
i += 2
110+
continue
111+
return i + 1
112+
i += 1
113+
return i # Unterminated literal — return end
114+
115+
116+
def _find_matching_close_paren(stmt: str, start: int) -> int:
117+
"""Find the matching close paren, skipping string literals."""
118+
depth = 1
119+
i = start
120+
while i < len(stmt) and depth > 0:
121+
ch = stmt[i]
122+
if ch == "'":
123+
i = _skip_string_literal(stmt, i)
124+
continue
125+
if ch == "(":
126+
depth += 1
127+
elif ch == ")":
128+
depth -= 1
129+
i += 1
130+
return i
131+
132+
97133
def _extract_main_query_after_cte(stmt: str) -> str | None:
98134
"""Extract the main (outer) query that follows all CTE definitions.
99135
100136
For ``WITH cte AS (SELECT 1) DELETE FROM users``, returns ``DELETE FROM users``.
101137
Returns None if no main query is found after the last CTE body.
138+
Handles parentheses inside string literals (e.g., ``SELECT '(' FROM t``).
102139
"""
103-
# Walk through balanced parens after each AS( to find the end of CTE bodies.
104140
last_cte_end = 0
105141
for m in _AS_PAREN_RE.finditer(stmt):
106-
# Find the matching closing paren for this CTE body.
107-
depth = 1
108-
i = m.end()
109-
while i < len(stmt) and depth > 0:
110-
if stmt[i] == "(":
111-
depth += 1
112-
elif stmt[i] == ")":
113-
depth -= 1
114-
i += 1
115-
last_cte_end = i
142+
last_cte_end = _find_matching_close_paren(stmt, m.end())
116143

117144
if last_cte_end > 0:
118145
remainder = stmt[last_cte_end:].strip().lstrip(",").strip()
119-
# Skip additional CTE definitions (name AS (...))
120-
# The remainder after the last CTE closing paren is the main query
121146
if remainder:
122147
return remainder
123148
return None
124149

125150

151+
def _resolve_explain_command(stmt: str) -> str | None:
152+
"""Resolve the underlying command from an EXPLAIN [ANALYZE] [VERBOSE] statement.
153+
154+
Returns the real command (e.g., 'DELETE') if ANALYZE is present, else None.
155+
Handles both space-separated and parenthesized syntax.
156+
"""
157+
rest = stmt.strip()[len("EXPLAIN") :].strip()
158+
if not rest:
159+
return None
160+
161+
analyze_found = False
162+
explain_opts = {"ANALYZE", "ANALYSE", "VERBOSE"}
163+
164+
if rest.startswith("("):
165+
close = rest.find(")")
166+
if close != -1:
167+
options_str = rest[1:close].upper()
168+
analyze_found = any(
169+
opt.strip() in ("ANALYZE", "ANALYSE") for opt in options_str.split(",")
170+
)
171+
rest = rest[close + 1 :].strip()
172+
else:
173+
while rest:
174+
first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else ""
175+
if first_opt in ("ANALYZE", "ANALYSE"):
176+
analyze_found = True
177+
if first_opt not in explain_opts:
178+
break
179+
rest = rest[len(first_opt) :].strip()
180+
181+
if analyze_found and rest:
182+
return rest.split()[0].upper().rstrip(";")
183+
return None
184+
185+
126186
class NL2SQLToolInput(BaseModel):
127187
sql_query: str = Field(
128188
title="SQL Query",
@@ -260,10 +320,17 @@ def _validate_statement(self, stmt: str) -> None:
260320
)
261321
rest = rest[close + 1 :].strip()
262322
else:
263-
# Space-separated: EXPLAIN ANALYZE <stmt>
264-
first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else ""
265-
if first_opt in ("ANALYZE", "ANALYSE"):
266-
analyze_found = True
323+
# Space-separated: EXPLAIN [ANALYZE] [VERBOSE] <stmt>
324+
# Consume all known EXPLAIN options before extracting the real command.
325+
_explain_opts = {"ANALYZE", "ANALYSE", "VERBOSE"}
326+
while rest:
327+
first_opt = (
328+
rest.split()[0].upper().rstrip(";") if rest.split() else ""
329+
)
330+
if first_opt in ("ANALYZE", "ANALYSE"):
331+
analyze_found = True
332+
if first_opt not in _explain_opts:
333+
break
267334
rest = rest[len(first_opt) :].strip()
268335

269336
if analyze_found and rest:
@@ -406,6 +473,11 @@ def _is_write_stmt(s: str) -> bool:
406473
cmd = self._extract_command(s)
407474
if cmd in _WRITE_COMMANDS:
408475
return True
476+
if cmd == "EXPLAIN":
477+
# Resolve the underlying command for EXPLAIN ANALYZE
478+
resolved = _resolve_explain_command(s)
479+
if resolved and resolved in _WRITE_COMMANDS:
480+
return True
409481
if cmd == "WITH":
410482
if _detect_writable_cte(s):
411483
return True

lib/crewai-tools/tests/tools/test_nl2sql_security.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,3 +560,77 @@ def test_extended_write_commands_blocked_by_default(self, stmt: str):
560560
tool = _make_tool(allow_dml=False)
561561
with pytest.raises(ValueError, match="read-only mode"):
562562
tool._validate_query(stmt)
563+
564+
565+
# ---------------------------------------------------------------------------
566+
# EXPLAIN ANALYZE VERBOSE handling
567+
# ---------------------------------------------------------------------------
568+
569+
570+
class TestExplainAnalyzeVerbose:
571+
def test_explain_analyze_verbose_select_allowed(self):
572+
"""EXPLAIN ANALYZE VERBOSE SELECT should be allowed (read-only)."""
573+
tool = _make_tool(allow_dml=False)
574+
tool._validate_query("EXPLAIN ANALYZE VERBOSE SELECT * FROM users")
575+
576+
def test_explain_analyze_verbose_delete_blocked(self):
577+
"""EXPLAIN ANALYZE VERBOSE DELETE should be blocked."""
578+
tool = _make_tool(allow_dml=False)
579+
with pytest.raises(ValueError, match="read-only mode"):
580+
tool._validate_query("EXPLAIN ANALYZE VERBOSE DELETE FROM users")
581+
582+
def test_explain_verbose_select_allowed(self):
583+
"""EXPLAIN VERBOSE SELECT (no ANALYZE) should be allowed."""
584+
tool = _make_tool(allow_dml=False)
585+
tool._validate_query("EXPLAIN VERBOSE SELECT * FROM users")
586+
587+
588+
# ---------------------------------------------------------------------------
589+
# CTE with string literal parens
590+
# ---------------------------------------------------------------------------
591+
592+
593+
class TestCTEStringLiteralParens:
594+
def test_cte_string_paren_does_not_bypass(self):
595+
"""Parens inside string literals should not confuse the paren walker."""
596+
tool = _make_tool(allow_dml=False)
597+
with pytest.raises(ValueError, match="read-only mode"):
598+
tool._validate_query(
599+
"WITH cte AS (SELECT '(' FROM t) DELETE FROM users"
600+
)
601+
602+
def test_cte_string_paren_read_only_allowed(self):
603+
"""Read-only CTE with string literal parens should be allowed."""
604+
tool = _make_tool(allow_dml=False)
605+
tool._validate_query(
606+
"WITH cte AS (SELECT '(' FROM t) SELECT * FROM cte"
607+
)
608+
609+
610+
# ---------------------------------------------------------------------------
611+
# EXPLAIN ANALYZE commit logic
612+
# ---------------------------------------------------------------------------
613+
614+
615+
class TestExplainAnalyzeCommit:
616+
def test_explain_analyze_delete_triggers_commit(self):
617+
"""EXPLAIN ANALYZE DELETE should trigger commit when allow_dml=True."""
618+
tool = _make_tool(allow_dml=True)
619+
620+
mock_session = MagicMock()
621+
mock_result = MagicMock()
622+
mock_result.returns_rows = True
623+
mock_result.keys.return_value = ["QUERY PLAN"]
624+
mock_result.fetchall.return_value = [("Delete on users",)]
625+
mock_session.execute.return_value = mock_result
626+
mock_session_cls = MagicMock(return_value=mock_session)
627+
628+
with (
629+
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
630+
patch(
631+
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
632+
return_value=mock_session_cls,
633+
),
634+
):
635+
tool.execute_sql("EXPLAIN ANALYZE DELETE FROM users")
636+
mock_session.commit.assert_called_once()

0 commit comments

Comments
 (0)