Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions cmd/run/run.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package run

import (
"fmt"
"os"
"path/filepath"

Expand Down Expand Up @@ -41,11 +42,19 @@ func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string
} else if !filepath.IsAbs(debugDir) {
debugDir = filepath.Join(workDir, debugDir)
}
if _, err := os.Stat(debugDir); os.IsNotExist(err) {
if err := os.MkdirAll(debugDir, 0777); err != nil {
logger.Error("Failed to create debug directory", "error", err)
if info, err := os.Stat(debugDir); err != nil {
if os.IsNotExist(err) {
if mkErr := os.MkdirAll(debugDir, 0777); mkErr != nil {
logger.Error("Failed to create debug directory", "error", mkErr)
return mkErr
}
} else {
logger.Error("Failed to access debug directory", "path", debugDir, "error", err)
return err
}
} else if !info.IsDir() {
logger.Error("Debug path exists but is not a directory", "path", debugDir)
return fmt.Errorf("debug path %s exists but is not a directory", debugDir)
}
logger.Info("Debug capture enabled", "dir", debugDir, "limit", debugLimit)
}
Expand All @@ -57,6 +66,6 @@ func Execute(pb playbook.Playbook, baseDir string, workDir string, outDir string
logger.Debug("Running", "outDir", outDir)

t := task.NewTask(pb.Name, baseDir, workDir, outDir, nInputs)
err = pb.Execute(t, debugDir, debugLimit)
err = pb.ExecuteWithCapture(t, debugDir, debugLimit)
return err
}
77 changes: 58 additions & 19 deletions playbook/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,22 @@ type stepCaptureState struct {

// captureRecord writes a debug record to the capture file
func (s *stepCaptureState) captureRecord(record map[string]any) {
if s.limit > 0 {
var recordNum uint64

for {
currentCount := atomic.LoadUint64(&s.count)
if currentCount >= uint64(s.limit) {

// Enforce limit strictly under concurrency
if s.limit > 0 && currentCount >= uint64(s.limit) {
return
}
}

recordNum := atomic.AddUint64(&s.count, 1)
next := currentCount + 1
if atomic.CompareAndSwapUint64(&s.count, currentCount, next) {
recordNum = next
break
}
}

envelope := map[string]any{
"pipeline": s.pipelineName,
Expand All @@ -161,15 +169,28 @@ func (s *stepCaptureState) captureRecord(record map[string]any) {
if s.file != nil {
data, err := json.Marshal(envelope)
if err == nil {
s.file.Write(data)
s.file.Write([]byte("\n"))
if _, writeErr := s.file.Write(data); writeErr != nil {
logger.Error("Failed to write debug record data", "error", writeErr)
return
}
if _, writeErr := s.file.Write([]byte("\n")); writeErr != nil {
logger.Error("Failed to write debug record newline", "error", writeErr)
return
}
} else {
logger.Error("Failed to marshal debug record", "error", err)
}
}
}

func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLimit int) error {
// Execute runs the playbook without debug capture.
//
// It preserves the original single-argument public API by delegating to
// ExecuteWithCapture with an empty capture directory and a zero record
// limit, which disables capture entirely.
func (pb *Playbook) Execute(task task.RuntimeTask) error {
	return pb.ExecuteWithCapture(task, "", 0)
}

// ExecuteWithCapture runs the playbook with optional debug capture configuration.
func (pb *Playbook) ExecuteWithCapture(task task.RuntimeTask, captureDir string, captureLimit int) error {
logger.Debug("Running playbook")
logger.Debug("Inputs", "config", task.GetConfig())

Expand All @@ -183,6 +204,13 @@ func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLim
procs := []transform.Processor{}
joins := []joinStruct{}
captureFiles := []*os.File{} // Track all open capture files for cleanup
defer func() {
for _, f := range captureFiles {
if f != nil {
_ = f.Close()
}
}
}()

// Helper function to sanitize filename components
sanitizeFilename := func(s string) string {
Expand All @@ -194,23 +222,41 @@ func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLim
return s
}

// Helper function to sanitize pipeline names used in filenames
sanitizePipelineName := func(s string) string {
// Use only the last path element to avoid directory traversal
s = filepath.Base(s)

// Treat empty, current-dir, parent-dir, or bare separator as invalid and use a default
if s == "" || s == "." || s == ".." || s == string(os.PathSeparator) {
s = "pipeline"
}

// Replace any remaining path separators with underscores
s = strings.ReplaceAll(s, string(os.PathSeparator), "_")
s = strings.ReplaceAll(s, "/", "_")
s = strings.ReplaceAll(s, "\\", "_")

return s
}

// Helper function to create capture state for a step
createCaptureState := func(pipelineName string, stepIndex int, stepType string) *stepCaptureState {
if captureDir == "" {
return nil
}

filename := fmt.Sprintf("%s.%d.%s.ndjson", pipelineName, stepIndex, sanitizeFilename(stepType))
filepath := filepath.Join(captureDir, filename)
filename := fmt.Sprintf("%s.%d.%s.ndjson", sanitizePipelineName(pipelineName), stepIndex, sanitizeFilename(stepType))
filePath := filepath.Join(captureDir, filename)

file, err := os.Create(filepath)
file, err := os.Create(filePath)
if err != nil {
logger.Error("Failed to create debug capture file", "path", filepath, "error", err)
logger.Error("Failed to create debug capture file", "path", filePath, "error", err)
return nil
}

captureFiles = append(captureFiles, file)
logger.Debug("Created debug capture file", "path", filepath)
logger.Debug("Created debug capture file", "path", filePath)

return &stepCaptureState{
pipelineName: pipelineName,
Expand Down Expand Up @@ -590,13 +636,6 @@ func (pb *Playbook) Execute(task task.RuntimeTask, captureDir string, captureLim
outputs[k].Close()
}

// Close all debug capture files
for _, f := range captureFiles {
if f != nil {
f.Close()
}
}

task.Close()
return nil
}
69 changes: 69 additions & 0 deletions test/command_line_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package test

import (
"bufio"
"bytes"
"compress/gzip"
"fmt"
Expand Down Expand Up @@ -101,3 +102,71 @@ func TestCommandLines(t *testing.T) {

}
}

// TestCaptureMode verifies that --capture-dir and --capture-limit flags create
// NDJSON capture files and respect the record limit.
func TestCaptureMode(t *testing.T) {
	// t.TempDir is auto-cleaned and fails the test on error, replacing the
	// manual MkdirTemp + defer RemoveAll bookkeeping.
	captureDir := t.TempDir()

	playbook := "examples/gene-table/gene-table.yaml"
	limit := 3

	cmd := exec.Command("../sifter", "run",
		"--capture-dir", captureDir,
		"--capture-limit", fmt.Sprintf("%d", limit),
		playbook,
	)
	t.Logf("Running: %s with capture-dir=%s capture-limit=%d", playbook, captureDir, limit)
	// CombinedOutput surfaces the tool's stdout/stderr in the failure
	// message; cmd.Run alone would discard it and leave failures opaque.
	if out, err := cmd.CombinedOutput(); err != nil {
		t.Fatalf("Failed running %s: %s\noutput:\n%s", playbook, err, out)
	}

	// Check that at least one .ndjson file was created.
	entries, err := os.ReadDir(captureDir)
	if err != nil {
		t.Fatalf("Failed to read capture dir: %s", err)
	}
	if len(entries) == 0 {
		t.Errorf("Expected capture NDJSON files in %s, but directory is empty", captureDir)
		return
	}

	// Verify each capture file has at most `limit` records.
	for _, entry := range entries {
		if !strings.HasSuffix(entry.Name(), ".ndjson") {
			continue
		}
		filePath := filepath.Join(captureDir, entry.Name())
		lineCount, err := countNonEmptyLines(filePath)
		if err != nil {
			t.Errorf("Error reading capture file %s: %s", filePath, err)
			continue
		}
		if lineCount > limit {
			t.Errorf("Capture file %s has %d records, expected at most %d", entry.Name(), lineCount, limit)
		}
		t.Logf("Capture file %s has %d records (limit=%d)", entry.Name(), lineCount, limit)
	}

	// Clean up the output directory the run produces next to the playbook.
	outputDir := filepath.Join(filepath.Dir(playbook), "output")
	os.RemoveAll(outputDir)
}

// countNonEmptyLines opens the file at path and returns the number of
// non-empty lines it contains, closing the file before returning.
func countNonEmptyLines(path string) (int, error) {
	f, err := os.Open(path)
	if err != nil {
		return 0, err
	}
	defer f.Close()

	n := 0
	scanner := bufio.NewScanner(f)
	// Raise the scanner's line-length cap: NDJSON records can exceed the
	// default 64 KiB token limit, which would otherwise abort the scan.
	scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)
	for scanner.Scan() {
		if scanner.Text() != "" {
			n++
		}
	}
	return n, scanner.Err()
}
Loading