diff --git a/pkg/backfill/backfill.go b/pkg/backfill/backfill.go index 131034461..c8f266062 100644 --- a/pkg/backfill/backfill.go +++ b/pkg/backfill/backfill.go @@ -205,11 +205,11 @@ func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, erro if err != nil { return 0, fmt.Errorf("getting current schema: %w", err) } - defer rows.Close() - if err := db.ScanFirstValue(rows, ¤tSchema); err != nil { + rows.Close() return 0, fmt.Errorf("scanning current schema: %w", err) } + rows.Close() var total int64 rows, err = conn.QueryContext(ctx, ` @@ -220,8 +220,11 @@ func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, erro return 0, fmt.Errorf("getting row count estimate for %q: %w", tableName, err) } if err := db.ScanFirstValue(rows, &total); err != nil { + rows.Close() return 0, fmt.Errorf("scanning row count estimate for %q: %w", tableName, err) } + rows.Close() + if total > 0 { return total, nil } @@ -232,8 +235,10 @@ func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, erro return 0, fmt.Errorf("getting row count for %q: %w", tableName, err) } if err := db.ScanFirstValue(rows, &total); err != nil { + rows.Close() return 0, fmt.Errorf("scanning row count for %q: %w", tableName, err) } + rows.Close() return total, nil } diff --git a/pkg/backfill/backfill_test.go b/pkg/backfill/backfill_test.go new file mode 100644 index 000000000..936f695c7 --- /dev/null +++ b/pkg/backfill/backfill_test.go @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 + +package backfill_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" + "github.com/xataio/pgroll/pkg/db" +) + +func TestMain(m *testing.M) { + testutils.SharedTestMain(m) +} + +func TestGetRowCountDoesNotLeakConnections(t *testing.T) { + t.Parallel() + + testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { + // Use a short timeout so that if connections are leaked and the pool is + // exhausted, the test fails quickly instead of hanging. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tableName := "test_row_count" + _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY)", tableName)) + require.NoError(t, err) + + // Call getRowCount many times with a small pool to verify no connection leak. + // With only 5 max connections, leaked connections would quickly cause errors. + conn2, err := sql.Open("postgres", connStr) + require.NoError(t, err) + defer conn2.Close() + + conn2.SetMaxOpenConns(5) + rdb := &db.RDB{DB: conn2} + + for i := 0; i < 20; i++ { + _, err := backfill.GetRowCount(ctx, rdb, tableName) + require.NoError(t, err, "iteration %d: getRowCount should not fail due to connection leak", i) + } + }) +} diff --git a/pkg/backfill/export_test.go b/pkg/backfill/export_test.go new file mode 100644 index 000000000..a930c18a3 --- /dev/null +++ b/pkg/backfill/export_test.go @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + +package backfill + +var GetRowCount = getRowCount diff --git a/pkg/migrations/dbactions.go b/pkg/migrations/dbactions.go index 661a2549c..0f3f0177e 100644 --- a/pkg/migrations/dbactions.go +++ b/pkg/migrations/dbactions.go @@ -500,6 +500,8 @@ func (a *createUniqueIndexConcurrentlyAction) isIndexInProgress(ctx context.Cont // In that case, we can safely return false. return false, nil } + defer rows.Close() + var isInProgress bool if err := db.ScanFirstValue(rows, &isInProgress); err != nil { return false, fmt.Errorf("scanning index in progress with name %q: %w", quotedQualifiedIndexName, err) @@ -521,6 +523,8 @@ func (a *createUniqueIndexConcurrentlyAction) isIndexValid(ctx context.Context, // In that case, we can safely return true. return true, nil } + defer rows.Close() + var isValid bool if err := db.ScanFirstValue(rows, &isValid); err != nil { return false, fmt.Errorf("scanning index with name %q: %w", quotedQualifiedIndexName, err) diff --git a/pkg/state/history.go b/pkg/state/history.go index e441f67ce..64d0310b1 100644 --- a/pkg/state/history.go +++ b/pkg/state/history.go @@ -47,6 +47,7 @@ func (s *State) SchemaHistory(ctx context.Context, schema string) ([]HistoryEntr if err != nil { return nil, err } + defer rows.Close() var entries []HistoryEntry for rows.Next() {