Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pkg/backfill/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,11 @@ func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, erro
if err != nil {
return 0, fmt.Errorf("getting current schema: %w", err)
}
defer rows.Close()

if err := db.ScanFirstValue(rows, &currentSchema); err != nil {
rows.Close()
return 0, fmt.Errorf("scanning current schema: %w", err)
}
rows.Close()

var total int64
rows, err = conn.QueryContext(ctx, `
Expand All @@ -220,8 +220,11 @@ func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, erro
return 0, fmt.Errorf("getting row count estimate for %q: %w", tableName, err)
}
if err := db.ScanFirstValue(rows, &total); err != nil {
rows.Close()
return 0, fmt.Errorf("scanning row count estimate for %q: %w", tableName, err)
}
rows.Close()

if total > 0 {
return total, nil
}
Expand All @@ -232,8 +235,10 @@ func getRowCount(ctx context.Context, conn db.DB, tableName string) (int64, erro
return 0, fmt.Errorf("getting row count for %q: %w", tableName, err)
}
if err := db.ScanFirstValue(rows, &total); err != nil {
rows.Close()
return 0, fmt.Errorf("scanning row count for %q: %w", tableName, err)
}
rows.Close()

return total, nil
}
Expand Down
50 changes: 50 additions & 0 deletions pkg/backfill/backfill_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-License-Identifier: Apache-2.0

package backfill_test

import (
"context"
"database/sql"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/xataio/pgroll/internal/testutils"
"github.com/xataio/pgroll/pkg/backfill"
"github.com/xataio/pgroll/pkg/db"
)

func TestMain(m *testing.M) {
testutils.SharedTestMain(m)
}

func TestGetRowCountDoesNotLeakConnections(t *testing.T) {
t.Parallel()

testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
// Use a short timeout so that if connections are leaked and the pool is
// exhausted, the test fails quickly instead of hanging.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

tableName := "test_row_count"
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY)", tableName))
require.NoError(t, err)

// Call getRowCount many times with a small pool to verify no connection leak.
// With only 5 max connections, leaked connections would quickly cause errors.
conn2, err := sql.Open("postgres", connStr)
require.NoError(t, err)
defer conn2.Close()

conn2.SetMaxOpenConns(5)
rdb := &db.RDB{DB: conn2}

for i := 0; i < 20; i++ {
_, err := backfill.GetRowCount(ctx, rdb, tableName)
require.NoError(t, err, "iteration %d: getRowCount should not fail due to connection leak", i)
}
})
}
5 changes: 5 additions & 0 deletions pkg/backfill/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// SPDX-License-Identifier: Apache-2.0

package backfill

var GetRowCount = getRowCount
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exported so we can use it in the test.

4 changes: 4 additions & 0 deletions pkg/migrations/dbactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ func (a *createUniqueIndexConcurrentlyAction) isIndexInProgress(ctx context.Cont
// In that case, we can safely return false.
return false, nil
}
defer rows.Close()

var isInProgress bool
if err := db.ScanFirstValue(rows, &isInProgress); err != nil {
return false, fmt.Errorf("scanning index in progress with name %q: %w", quotedQualifiedIndexName, err)
Expand All @@ -521,6 +523,8 @@ func (a *createUniqueIndexConcurrentlyAction) isIndexValid(ctx context.Context,
// In that case, we can safely return true.
return true, nil
}
defer rows.Close()

var isValid bool
if err := db.ScanFirstValue(rows, &isValid); err != nil {
return false, fmt.Errorf("scanning index with name %q: %w", quotedQualifiedIndexName, err)
Expand Down
1 change: 1 addition & 0 deletions pkg/state/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (s *State) SchemaHistory(ctx context.Context, schema string) ([]HistoryEntr
if err != nil {
return nil, err
}
defer rows.Close()

var entries []HistoryEntry
for rows.Next() {
Expand Down
Loading