Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/dbagent/src/evals/chat/tool-choice.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ describe.concurrent('tool_choice', () => {
{
id: 'explain_query',
prompt: 'Explain SELECT * FROM dogs',
expectedToolCalls: ['unsafeExplainQuery'],
expectedToolCalls: ['safeExplainQuery'],
allowOtherTools: false
},
{
Expand Down
2 changes: 1 addition & 1 deletion apps/dbagent/src/lib/ai/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ If the user asks for something that is not related to PostgreSQL or database adm
export const chatSystemPrompt = `
Provide clear, concise, and accurate responses to questions.
Use the provided tools to get context from the PostgreSQL database to answer questions.
When asked why a query is slow, call the unsafeExplainQuery tool and also take into account the table sizes.
When asked why a query is slow, call the safeExplainQuery tool and also take into account the table sizes.
During the initial assessment use the getTablesInfo, getPerfromanceAndVacuumSettings, getConnectionsStats, and getPostgresExtensions, and others if you want.
When asked to run a playbook, use the getPlaybook tool to get the playbook contents. Then use the contents of the playbook
as an action plan. Execute the plan step by step.
Expand Down
30 changes: 28 additions & 2 deletions apps/dbagent/src/lib/ai/tools/db.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import { Tool, tool } from 'ai';
import { z } from 'zod';
import { getPerformanceAndVacuumSettings, toolFindTableSchema } from '~/lib/tools/dbinfo';
import { toolDescribeTable, toolGetSlowQueries, toolUnsafeExplainQuery } from '~/lib/tools/slow-queries';
import {
toolDescribeTable,
toolGetSlowQueries,
toolSafeExplainQuery,
toolUnsafeExplainQuery
} from '~/lib/tools/slow-queries';
import {
toolCurrentActiveQueries,
toolGetConnectionsGroups,
Expand Down Expand Up @@ -30,6 +35,7 @@ export class DBSQLTools implements ToolsetGroup {
return {
getSlowQueries: this.getSlowQueries(),
unsafeExplainQuery: this.unsafeExplainQuery(),
safeExplainQuery: this.safeExplainQuery(),
describeTable: this.describeTable(),
findTableSchema: this.findTableSchema(),
getCurrentActiveQueries: this.getCurrentActiveQueries(),
Expand All @@ -46,7 +52,7 @@ export class DBSQLTools implements ToolsetGroup {
return tool({
description: `Get a list of slow queries formatted as a JSON array. Contains how many times the query was called,
the max execution time in seconds, the mean execution time in seconds, the total execution time
(all calls together) in seconds, and the query itself.`,
(all calls together) in seconds, the query itself, and the queryid for use with safeExplainQuery.`,
parameters: z.object({}),
execute: async () => {
try {
Expand Down Expand Up @@ -85,6 +91,26 @@ If you know the schema, pass it in as well.`,
});
}

safeExplainQuery(): Tool {
const pool = this.#pool;
return tool({
description: `Safely run EXPLAIN on a query by fetching it from pg_stat_statements using queryId.
This prevents SQL injection by not accepting raw SQL queries. Returns the explain plan as received from PostgreSQL.
Use the queryid field from the getSlowQueries tool output as the queryId parameter.`,
parameters: z.object({
schema: z.string(),
queryId: z.string().describe('The query ID from pg_stat_statements (use the queryid field from getSlowQueries)')
}),
execute: async ({ schema = 'public', queryId }) => {
try {
return await withPoolConnection(pool, async (client) => await toolSafeExplainQuery(client, schema, queryId));
} catch (error) {
return `Error running safe EXPLAIN on the query: ${error}`;
}
}
});
}

describeTable(): Tool {
const pool = this.#pool;
return tool({
Expand Down
4 changes: 3 additions & 1 deletion apps/dbagent/src/lib/targetdb/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ interface SlowQuery {
mean_exec_secs: number;
total_exec_secs: number;
query: string;
queryid: string;
}

export async function getSlowQueries(client: ClientBase, thresholdMs: number): Promise<SlowQuery[]> {
Expand All @@ -370,7 +371,8 @@ export async function getSlowQueries(client: ClientBase, thresholdMs: number): P
round(max_exec_time/1000) max_exec_secs,
round(mean_exec_time/1000) mean_exec_secs,
round(total_exec_time/1000) total_exec_secs,
query
query,
queryid::text as queryid
FROM pg_stat_statements
WHERE max_exec_time > $1
ORDER BY total_exec_time DESC
Expand Down
42 changes: 42 additions & 0 deletions apps/dbagent/src/lib/targetdb/safe-explain.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { ClientBase } from './db';
import { isSingleStatement } from './unsafe-explain';

export async function safeExplainQuery(client: ClientBase, schema: string, queryId: string): Promise<string> {
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(schema)) {
return 'Invalid schema name. Only alphanumeric characters and underscores are allowed.';
}

// First, fetch the query from pg_stat_statements
const queryResult = await client.query('SELECT query FROM pg_stat_statements WHERE queryid = $1', [queryId]);

if (queryResult.rows.length === 0) {
return 'Query not found in pg_stat_statements for the given queryId';
}

const query = queryResult.rows[0].query;

if (!isSingleStatement(query)) {
return 'The query is not a single safe statement. Only SELECT, INSERT, UPDATE, DELETE, and WITH statements are allowed.';
}

const hasPlaceholders = /\$\d+/.test(query);

let toReturn = '';
try {
await client.query('BEGIN');
await client.query("SET LOCAL statement_timeout = '2000ms'");
await client.query("SET LOCAL lock_timeout = '200ms'");
await client.query(`SET search_path TO ${schema}`);
const explainQuery = hasPlaceholders ? `EXPLAIN (GENERIC_PLAN true) ${query}` : `EXPLAIN ${query}`;
console.log(schema);
console.log(explainQuery);
const result = await client.query(explainQuery);
console.log(result.rows);
toReturn = result.rows.map((row: { [key: string]: string }) => row['QUERY PLAN']).join('\n');
} catch (error) {
console.error('Error explaining query', error);
toReturn = 'I could not run EXPLAIN on that query. Try a different method.';
}
await client.query('ROLLBACK');
return toReturn;
}
2 changes: 1 addition & 1 deletion apps/dbagent/src/lib/tools/playbooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Use the tool findTableSchema to find the schema of the table involved in the slo
Use the tool describeTable to describe the table you found.

Step 4:
Use the tool unsafeExplainQuery to explain the slow queries. Make sure to pass the schema you found to the tool.
Use the tool safeExplainQuery to explain the slow queries. Make sure to pass the schema you found to the tool.
Also, it's very important to replace the query parameters ($1, $2, etc) with the actual values. Generate your own values, but
take into account the data types of the columns.

Expand Down
6 changes: 6 additions & 0 deletions apps/dbagent/src/lib/tools/slow-queries.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { ClientBase, describeTable, getSlowQueries } from '../targetdb/db';
import { safeExplainQuery } from '../targetdb/safe-explain';
import { unsafeExplainQuery } from '../targetdb/unsafe-explain';

export async function toolGetSlowQueries(client: ClientBase, thresholdMs: number): Promise<string> {
Expand All @@ -25,6 +26,11 @@ export async function toolUnsafeExplainQuery(client: ClientBase, schema: string,
return JSON.stringify(result);
}

export async function toolSafeExplainQuery(client: ClientBase, schema: string, queryId: string): Promise<string> {
const result = await safeExplainQuery(client, schema, queryId);
return JSON.stringify(result);
}

export async function toolDescribeTable(client: ClientBase, schema: string, table: string): Promise<string> {
const result = await describeTable(client, schema, table);
return JSON.stringify(result);
Expand Down
Loading