Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions internal/resolver/warmup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package resolver

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

"picodns/internal/dns"
)

// mockTransport implements types.Transport for testing warmup.
type mockTransport struct {
queryCount int
resp []byte
}

func (m *mockTransport) Query(_ context.Context, server string, req []byte, timeout time.Duration) ([]byte, func(), error) {
m.queryCount++
if m.resp != nil {
return m.resp, nil, nil
}
// Build a minimal response echoing back the request
hdr, _ := dns.ReadHeader(req)
respBuf := make([]byte, len(req))
copy(respBuf, req)
respHdr := dns.Header{
ID: hdr.ID,
Flags: dns.FlagQR | dns.FlagRA,
QDCount: hdr.QDCount,
}
_ = dns.WriteHeader(respBuf, respHdr)
return respBuf, nil, nil
}

func TestWarmupRTTQueriesRootServers(t *testing.T) {
mt := &mockTransport{}
r := NewRecursive(WithTransport(mt), WithRootServers([]string{"1.1.1.1:53", "8.8.8.8:53"}))

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

r.warmupRTT(ctx)

// Should have queried both root servers
require.Equal(t, 2, mt.queryCount)
}

func TestWarmupPopulatesDelegationCache(t *testing.T) {
mt := &mockTransport{}
r := NewRecursive(WithTransport(mt), WithRootServers([]string{"1.1.1.1:53"}))

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

r.Warmup(ctx)

// Should have made queries for root servers + common TLDs
require.Greater(t, mt.queryCount, 0)
}

func TestWarmupRespectsContextCancellation(t *testing.T) {
mt := &mockTransport{}
r := NewRecursive(WithTransport(mt), WithRootServers([]string{"1.1.1.1:53"}))

ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

r.Warmup(ctx)

// Should have done minimal work since context was already cancelled.
// The RTT warmup goroutines may still fire due to timing,
// but the TLD warmup loop should exit immediately.
}
190 changes: 190 additions & 0 deletions internal/server/tcp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package server

import (
"context"
"encoding/binary"
"io"
"log/slog"
"net"
"testing"
"time"

"github.com/stretchr/testify/require"

"picodns/internal/config"
"picodns/internal/dns"
)

// mockResolver implements types.Resolver for testing.
type mockResolver struct {
resp []byte
err error
}

func (m *mockResolver) Resolve(_ context.Context, req []byte) ([]byte, func(), error) {
if m.err != nil {
return nil, nil, m.err
}
return m.resp, nil, nil
}

func newTestServer(resolver *mockResolver) *Server {
cfg := config.Default()
cfg.Workers = 1
logger := slog.Default()
return New(cfg, logger, resolver)
}

func writeTCPQuery(t *testing.T, conn net.Conn, query []byte) {
t.Helper()
var lenBuf [2]byte
binary.BigEndian.PutUint16(lenBuf[:], uint16(len(query)))
_, err := conn.Write(lenBuf[:])
require.NoError(t, err)
_, err = conn.Write(query)
require.NoError(t, err)
}

func readTCPResponse(t *testing.T, conn net.Conn) []byte {
t.Helper()
var lenBuf [2]byte
_, err := io.ReadFull(conn, lenBuf[:])
require.NoError(t, err)
respLen := binary.BigEndian.Uint16(lenBuf[:])
resp := make([]byte, respLen)
_, err = io.ReadFull(conn, resp)
require.NoError(t, err)
return resp
}

func makeTestQuery(name string) []byte {
buf := make([]byte, 512)
_ = dns.WriteHeader(buf, dns.Header{ID: 0xBEEF, Flags: dns.FlagRD, QDCount: 1})
end, _ := dns.WriteQuestion(buf, dns.HeaderLen, dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassIN})
return buf[:end]
}

func makeTestResponse(req []byte, ttl uint32) []byte {
resp, _ := dns.BuildResponse(req, []dns.Answer{
{Type: dns.TypeA, Class: dns.ClassIN, TTL: ttl, RData: []byte{1, 2, 3, 4}},
}, 0)
return resp
}

func TestTCPHandlerBasicQuery(t *testing.T) {
query := makeTestQuery("example.com")
resp := makeTestResponse(query, 60)
srv := newTestServer(&mockResolver{resp: resp})

client, server := net.Pipe()
defer func() { _ = client.Close() }()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

go srv.handleTCPConn(ctx, server)

writeTCPQuery(t, client, query)
got := readTCPResponse(t, client)

hdr, err := dns.ReadHeader(got)
require.NoError(t, err)
require.Equal(t, uint16(0xBEEF), hdr.ID)
require.True(t, hdr.Flags&dns.FlagQR != 0) // is a response
}

func TestTCPHandlerResolverError(t *testing.T) {
query := makeTestQuery("fail.com")
srv := newTestServer(&mockResolver{err: io.ErrUnexpectedEOF})

client, server := net.Pipe()
defer func() { _ = client.Close() }()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

go srv.handleTCPConn(ctx, server)

writeTCPQuery(t, client, query)
got := readTCPResponse(t, client)

// Should get a SERVFAIL response
hdr, err := dns.ReadHeader(got)
require.NoError(t, err)
require.Equal(t, uint16(dns.RcodeServer), hdr.Flags&dns.RcodeMask)
}

func TestTCPHandlerInvalidMessageSize(t *testing.T) {
srv := newTestServer(&mockResolver{})

client, server := net.Pipe()
defer func() { _ = client.Close() }()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

done := make(chan struct{})
go func() {
srv.handleTCPConn(ctx, server)
close(done)
}()

// Send a message size of 0
var lenBuf [2]byte
binary.BigEndian.PutUint16(lenBuf[:], 0)
_, err := client.Write(lenBuf[:])
require.NoError(t, err)

// Handler should close the connection
select {
case <-done:
// good - handler returned
case <-time.After(2 * time.Second):
t.Fatal("handler did not close connection for invalid size")
}
}

func TestTCPHandlerMultipleQueries(t *testing.T) {
query := makeTestQuery("multi.com")
resp := makeTestResponse(query, 300)
srv := newTestServer(&mockResolver{resp: resp})

client, server := net.Pipe()
defer func() { _ = client.Close() }()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

go srv.handleTCPConn(ctx, server)

// Send 3 queries on the same connection (TCP pipelining)
for i := 0; i < 3; i++ {
writeTCPQuery(t, client, query)
got := readTCPResponse(t, client)
hdr, err := dns.ReadHeader(got)
require.NoError(t, err)
require.Equal(t, uint16(0xBEEF), hdr.ID)
}
}

func TestTCPHandlerQueryCounting(t *testing.T) {
query := makeTestQuery("count.com")
resp := makeTestResponse(query, 60)
srv := newTestServer(&mockResolver{resp: resp})

client, server := net.Pipe()
defer func() { _ = client.Close() }()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

go srv.handleTCPConn(ctx, server)

writeTCPQuery(t, client, query)
_ = readTCPResponse(t, client)

writeTCPQuery(t, client, query)
_ = readTCPResponse(t, client)

require.Equal(t, uint64(2), srv.TotalQueries.Load())
}