diff --git a/cmd/picodns/main.go b/cmd/picodns/main.go index 95ec50b..40d3618 100644 --- a/cmd/picodns/main.go +++ b/cmd/picodns/main.go @@ -19,6 +19,7 @@ import ( func main() { cfg := config.Default() config.BindFlags(&cfg) + config.ParseFlags(&cfg) logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: parseLevel(cfg.LogLevel), diff --git a/internal/config/config.go b/internal/config/config.go index 500c55f..2a3109e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -38,16 +38,25 @@ func Default() Config { } } +// flagBindings holds the intermediate string values that bridge flag registration +// and post-parse processing. +type flagBindings struct { + listen string + upstreams string + statsInterval string +} + +var bindings flagBindings + +// BindFlags registers command-line flags on the default flag set, +// binding them to the fields of cfg. Call ParseFlags after to parse os.Args. func BindFlags(cfg *Config) { if cfg == nil { return } - var upstreams string - var listen string - var statsInterval string - flag.StringVar(&listen, "listen", strings.Join(cfg.ListenAddrs, ","), "comma-separated listen addresses") - flag.StringVar(&upstreams, "upstreams", strings.Join(cfg.Upstreams, ","), "comma-separated upstreams") + flag.StringVar(&bindings.listen, "listen", strings.Join(cfg.ListenAddrs, ","), "comma-separated listen addresses") + flag.StringVar(&bindings.upstreams, "upstreams", strings.Join(cfg.Upstreams, ","), "comma-separated upstreams") flag.IntVar(&cfg.Workers, "workers", cfg.Workers, "worker pool size") flag.IntVar(&cfg.CacheSize, "cache-size", cfg.CacheSize, "max cache entries") flag.StringVar(&cfg.LogLevel, "log-level", cfg.LogLevel, "log level (debug, info, warn, error)") @@ -56,18 +65,22 @@ func BindFlags(cfg *Config) { flag.BoolVar(&cfg.Prefetch, "prefetch", cfg.Prefetch, "proactively refresh hot cache entries") flag.BoolVar(&cfg.Stats, "stats", cfg.Stats, "emit one-time stats summary on shutdown") flag.StringVar(&cfg.PerfReport, "perf-report", cfg.PerfReport, "write perf JSON report to this path (perf builds only)") - flag.StringVar(&statsInterval, "stats-interval", cfg.StatsInterval.String(), "DEPRECATED: enables -stats when >0") + flag.StringVar(&bindings.statsInterval, "stats-interval", cfg.StatsInterval.String(), "DEPRECATED: enables -stats when >0") +} +// ParseFlags calls flag.Parse and applies post-processing to cfg. +// Must be called after BindFlags. +func ParseFlags(cfg *Config) { flag.Parse() - if listen != "" { - cfg.ListenAddrs = splitComma(listen) + if bindings.listen != "" { + cfg.ListenAddrs = splitComma(bindings.listen) } - if upstreams != "" { - cfg.Upstreams = splitComma(upstreams) + if bindings.upstreams != "" { + cfg.Upstreams = splitComma(bindings.upstreams) } - if strings.TrimSpace(statsInterval) != "" { - if d, err := time.ParseDuration(strings.TrimSpace(statsInterval)); err == nil { + if strings.TrimSpace(bindings.statsInterval) != "" { + if d, err := time.ParseDuration(strings.TrimSpace(bindings.statsInterval)); err == nil { cfg.StatsInterval = d if d > 0 { cfg.Stats = true diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..f62c6b7 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,47 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDefaultConfig(t *testing.T) { + cfg := Default() + require.Equal(t, []string{":53"}, cfg.ListenAddrs) + require.Equal(t, []string{"1.1.1.1:53"}, cfg.Upstreams) + require.Equal(t, 128, cfg.Workers) + require.Equal(t, 10000, cfg.CacheSize) + require.Equal(t, "info", cfg.LogLevel) + require.True(t, cfg.Prewarm) + require.True(t, cfg.Prefetch) + require.False(t, cfg.Stats) +} + +func TestBindFlagsDoesNotParse(t *testing.T) { + // BindFlags should only register flags, not call flag.Parse() + // We can verify this by checking that calling BindFlags alone + // doesn't panic or modify the config from defaults. + cfg := Default() + BindFlags(&cfg) + // Config should still have defaults since Parse wasn't called + require.Equal(t, []string{":53"}, cfg.ListenAddrs) + require.Equal(t, 128, cfg.Workers) +} + +func TestSplitComma(t *testing.T) { + tests := []struct { + input string + want []string + }{ + {"a,b,c", []string{"a", "b", "c"}}, + {" a , b , c ", []string{"a", "b", "c"}}, + {"", []string{}}, + {"single", []string{"single"}}, + {"a,,b", []string{"a", "b"}}, + } + for _, tt := range tests { + got := splitComma(tt.input) + require.Equal(t, tt.want, got, "splitComma(%q)", tt.input) + } +}