diff --git a/apps/dbagent/src/evals/chat/tool-choice.test.ts b/apps/dbagent/src/evals/chat/tool-choice.test.ts
index e1a326ca..995388c4 100644
--- a/apps/dbagent/src/evals/chat/tool-choice.test.ts
+++ b/apps/dbagent/src/evals/chat/tool-choice.test.ts
@@ -153,7 +153,7 @@ describe.concurrent('tool_choice', () => {
     {
       id: 'explain_query',
       prompt: 'Explain SELECT * FROM dogs',
-      expectedToolCalls: ['unsafeExplainQuery'],
+      expectedToolCalls: ['safeExplainQuery'],
       allowOtherTools: false
     },
     {
diff --git a/apps/dbagent/src/lib/ai/prompts.ts b/apps/dbagent/src/lib/ai/prompts.ts
index 9c06bb35..41f92240 100644
--- a/apps/dbagent/src/lib/ai/prompts.ts
+++ b/apps/dbagent/src/lib/ai/prompts.ts
@@ -11,7 +11,7 @@ If the user asks for something that is not related to PostgreSQL or database adm
 export const chatSystemPrompt = `
 Provide clear, concise, and accurate responses to questions.
 Use the provided tools to get context from the PostgreSQL database to answer questions.
-When asked why a query is slow, call the unsafeExplainQuery tool and also take into account the table sizes.
+When asked why a query is slow, call the safeExplainQuery tool and also take into account the table sizes.
 During the initial assessment use the getTablesInfo, getPerfromanceAndVacuumSettings, getConnectionsStats, and getPostgresExtensions, and others if you want.
 When asked to run a playbook, use the getPlaybook tool to get the playbook contents. Then use the contents of the playbook as an action plan. Execute the plan step by step.
 
diff --git a/apps/dbagent/src/lib/ai/tools/db.ts b/apps/dbagent/src/lib/ai/tools/db.ts
index 1cbf8277..35fd3e9c 100644
--- a/apps/dbagent/src/lib/ai/tools/db.ts
+++ b/apps/dbagent/src/lib/ai/tools/db.ts
@@ -1,7 +1,12 @@
 import { Tool, tool } from 'ai';
 import { z } from 'zod';
 import { getPerformanceAndVacuumSettings, toolFindTableSchema } from '~/lib/tools/dbinfo';
-import { toolDescribeTable, toolGetSlowQueries, toolUnsafeExplainQuery } from '~/lib/tools/slow-queries';
+import {
+  toolDescribeTable,
+  toolGetSlowQueries,
+  toolSafeExplainQuery,
+  toolUnsafeExplainQuery
+} from '~/lib/tools/slow-queries';
 import {
   toolCurrentActiveQueries,
   toolGetConnectionsGroups,
@@ -30,6 +35,7 @@ export class DBSQLTools implements ToolsetGroup {
     return {
       getSlowQueries: this.getSlowQueries(),
       unsafeExplainQuery: this.unsafeExplainQuery(),
+      safeExplainQuery: this.safeExplainQuery(),
       describeTable: this.describeTable(),
       findTableSchema: this.findTableSchema(),
       getCurrentActiveQueries: this.getCurrentActiveQueries(),
@@ -46,7 +52,7 @@ export class DBSQLTools implements ToolsetGroup {
     return tool({
       description: `Get a list of slow queries formatted as a JSON array. Contains how many times the query was called,
 the max execution time in seconds, the mean execution time in seconds, the total execution time
-(all calls together) in seconds, and the query itself.`,
+(all calls together) in seconds, the query itself, and the queryid for use with safeExplainQuery.`,
       parameters: z.object({}),
       execute: async () => {
         try {
@@ -85,6 +91,26 @@ If you know the schema, pass it in as well.`,
     });
   }
 
+  safeExplainQuery(): Tool {
+    const pool = this.#pool;
+    return tool({
+      description: `Safely run EXPLAIN on a query by fetching it from pg_stat_statements using queryId.
+This prevents SQL injection by not accepting raw SQL queries. Returns the explain plan as received from PostgreSQL.
+Use the queryid field from the getSlowQueries tool output as the queryId parameter.`,
+      parameters: z.object({
+        schema: z.string(),
+        queryId: z.string().describe('The query ID from pg_stat_statements (use the queryid field from getSlowQueries)')
+      }),
+      execute: async ({ schema = 'public', queryId }) => {
+        try {
+          return await withPoolConnection(pool, async (client) => await toolSafeExplainQuery(client, schema, queryId));
+        } catch (error) {
+          return `Error running safe EXPLAIN on the query: ${error}`;
+        }
+      }
+    });
+  }
+
   describeTable(): Tool {
     const pool = this.#pool;
     return tool({
diff --git a/apps/dbagent/src/lib/targetdb/db.ts b/apps/dbagent/src/lib/targetdb/db.ts
index 00cdf628..9b82a16b 100644
--- a/apps/dbagent/src/lib/targetdb/db.ts
+++ b/apps/dbagent/src/lib/targetdb/db.ts
@@ -361,6 +361,7 @@ interface SlowQuery {
   mean_exec_secs: number;
   total_exec_secs: number;
   query: string;
+  queryid: string;
 }
 
 export async function getSlowQueries(client: ClientBase, thresholdMs: number): Promise {
@@ -370,7 +371,8 @@ export async function getSlowQueries(client: ClientBase, thresholdMs: number): P
       round(max_exec_time/1000) max_exec_secs,
       round(mean_exec_time/1000) mean_exec_secs,
       round(total_exec_time/1000) total_exec_secs,
-      query
+      query,
+      queryid::text as queryid
     FROM pg_stat_statements
     WHERE max_exec_time > $1
     ORDER BY total_exec_time DESC
diff --git a/apps/dbagent/src/lib/targetdb/safe-explain.ts b/apps/dbagent/src/lib/targetdb/safe-explain.ts
new file mode 100644
index 00000000..1f43a608
--- /dev/null
+++ b/apps/dbagent/src/lib/targetdb/safe-explain.ts
@@ -0,0 +1,65 @@
+import { ClientBase } from './db';
+import { isSingleStatement } from './unsafe-explain';
+
+/** Decimal text form of pg_stat_statements.queryid, which is a signed bigint. */
+const QUERY_ID_PATTERN = /^-?\d+$/;
+
+/**
+ * Runs EXPLAIN on a statement looked up in pg_stat_statements by its queryid.
+ *
+ * The SQL text never comes from the caller: it is read back from pg_stat_statements,
+ * checked with isSingleStatement, and explained inside a transaction that is always
+ * rolled back.
+ *
+ * @param client - Connection to the target database.
+ * @param schema - Schema to use as search_path; must be a plain unquoted identifier.
+ * @param queryId - The queryid (as text) of the statement to explain.
+ * @returns The plan, one line per "QUERY PLAN" row, or a human-readable message when the
+ *   input is rejected, the statement is not found, or EXPLAIN fails.
+ */
+export async function safeExplainQuery(client: ClientBase, schema: string, queryId: string): Promise<string> {
+  if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(schema)) {
+    return 'Invalid schema name. Only alphanumeric characters and underscores are allowed.';
+  }
+
+  // Reject malformed ids up front instead of letting the bigint comparison fail in PostgreSQL.
+  const normalizedQueryId = queryId.trim();
+  if (!QUERY_ID_PATTERN.test(normalizedQueryId)) {
+    return 'Invalid queryId. Use the queryid value returned by getSlowQueries.';
+  }
+
+  // First, fetch the query from pg_stat_statements
+  const queryResult = await client.query('SELECT query FROM pg_stat_statements WHERE queryid = $1', [
+    normalizedQueryId
+  ]);
+
+  if (queryResult.rows.length === 0) {
+    return 'Query not found in pg_stat_statements for the given queryId';
+  }
+
+  const query = queryResult.rows[0].query;
+
+  if (!isSingleStatement(query)) {
+    return 'The query is not a single safe statement. Only SELECT, INSERT, UPDATE, DELETE, and WITH statements are allowed.';
+  }
+
+  // Normalized statements carry $n placeholders; GENERIC_PLAN (PostgreSQL 16+) plans them
+  // without needing parameter values.
+  const hasPlaceholders = /\$\d+/.test(query);
+  const explainQuery = hasPlaceholders ? `EXPLAIN (GENERIC_PLAN true) ${query}` : `EXPLAIN ${query}`;
+
+  try {
+    await client.query('BEGIN');
+    await client.query("SET LOCAL statement_timeout = '2000ms'");
+    await client.query("SET LOCAL lock_timeout = '200ms'");
+    await client.query(`SET LOCAL search_path TO ${schema}`);
+    const result = await client.query(explainQuery);
+    return result.rows.map((row: { [key: string]: string }) => row['QUERY PLAN']).join('\n');
+  } catch (error) {
+    console.error('Error explaining query', error);
+    return 'I could not run EXPLAIN on that query. Try a different method.';
+  } finally {
+    // Always end the transaction so the connection is never left inside one.
+    await client.query('ROLLBACK');
+  }
+}
diff --git a/apps/dbagent/src/lib/tools/playbooks.ts b/apps/dbagent/src/lib/tools/playbooks.ts
index 8ed6b157..cf4fc470 100644
--- a/apps/dbagent/src/lib/tools/playbooks.ts
+++ b/apps/dbagent/src/lib/tools/playbooks.ts
@@ -24,7 +24,7 @@ Use the tool findTableSchema to find the schema of the table involved in the slo
 Use the tool describeTable to describe the table you found.
 
 Step 4:
-Use the tool unsafeExplainQuery to explain the slow queries. Make sure to pass the schema you found to the tool.
+Use the tool safeExplainQuery to explain the slow queries. Make sure to pass the schema you found to the tool.
 Also, it's very important to replace the query parameters ($1, $2, etc) with the actual values.
 Generate your own values, but take into account the data types of the columns.
 
diff --git a/apps/dbagent/src/lib/tools/slow-queries.ts b/apps/dbagent/src/lib/tools/slow-queries.ts
index 7c264b05..1729dc7d 100644
--- a/apps/dbagent/src/lib/tools/slow-queries.ts
+++ b/apps/dbagent/src/lib/tools/slow-queries.ts
@@ -1,4 +1,5 @@
 import { ClientBase, describeTable, getSlowQueries } from '../targetdb/db';
+import { safeExplainQuery } from '../targetdb/safe-explain';
 import { unsafeExplainQuery } from '../targetdb/unsafe-explain';
 
 export async function toolGetSlowQueries(client: ClientBase, thresholdMs: number): Promise {
@@ -25,6 +26,11 @@ export async function toolUnsafeExplainQuery(client: ClientBase, schema: string,
   return JSON.stringify(result);
 }
 
+export async function toolSafeExplainQuery(client: ClientBase, schema: string, queryId: string): Promise<string> {
+  const result = await safeExplainQuery(client, schema, queryId);
+  return JSON.stringify(result);
+}
+
 export async function toolDescribeTable(client: ClientBase, schema: string, table: string): Promise {
   const result = await describeTable(client, schema, table);
   return JSON.stringify(result);