From f2df0bfac474456af637870c437629c9f705799d Mon Sep 17 00:00:00 2001 From: edlsh Date: Mon, 13 Apr 2026 09:39:01 -0400 Subject: [PATCH] fix(amp): preserve lowercase glob tool name --- internal/api/modules/amp/response_rewriter.go | 50 ++++++++++++++++++ .../api/modules/amp/response_rewriter_test.go | 51 +++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index 707fe576b4..895c494e74 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() { var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"} +// ampCanonicalToolNames maps tool names to the exact casing expected by the +// Amp mode tool whitelist (case-sensitive match). +var ampCanonicalToolNames = map[string]string{ + "bash": "Bash", + "read": "Read", + "grep": "Grep", + "glob": "glob", + "task": "Task", + "check": "Check", +} + +// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing. +// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash") +// which causes Amp's case-sensitive mode whitelist to reject them. +func normalizeAmpToolNames(data []byte) []byte { + // Non-streaming: content[].name in tool_use blocks + for index, block := range gjson.GetBytes(data, "content").Array() { + if block.Get("type").String() != "tool_use" { + continue + } + name := block.Get("name").String() + if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical { + path := fmt.Sprintf("content.%d.name", index) + var err error + data, err = sjson.SetBytes(data, path, canonical) + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err) + } + } + } + + // Streaming: content_block.name in content_block_start events + if gjson.GetBytes(data, "content_block.type").String() == "tool_use" { + name := gjson.GetBytes(data, "content_block.name").String() + if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical { + var err error + data, err = sjson.SetBytes(data, "content_block.name", canonical) + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err) + } + } + } + + return data +} + // ensureAmpSignature injects empty signature fields into tool_use/thinking blocks // in API responses so that the Amp TUI does not crash on P.signature.length. func ensureAmpSignature(data []byte) []byte { @@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { data = ensureAmpSignature(data) + data = normalizeAmpToolNames(data) data = rw.suppressAmpThinking(data) if len(data) == 0 { return data @@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { // Inject empty signature where needed data = ensureAmpSignature(data) + // Normalize tool names to canonical casing + data = normalizeAmpToolNames(data) + // Rewrite model name if rw.originalModel != "" { for _, path := range modelFieldPaths { diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index ac95dfc64f..a3a350cb23 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi } } +func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`) + result := normalizeAmpToolNames(input) + + if !contains(result, []byte(`"name":"Bash"`)) { + t.Errorf("expected bash->Bash, got %s", string(result)) + } + if !contains(result, []byte(`"name":"Read"`)) { + t.Errorf("expected read->Read, got %s", string(result)) + } + if contains(result, []byte(`"name":"bash"`)) { + t.Errorf("expected lowercase bash to be replaced, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_Streaming(t *testing.T) { + input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`) + result := normalizeAmpToolNames(input) + + if !contains(result, []byte(`"name":"Grep"`)) { + t.Errorf("expected grep->Grep in streaming, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`) + result := normalizeAmpToolNames(input) + + if string(result) != string(input) { + t.Errorf("expected no modification for correctly-cased tool, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`) + result := normalizeAmpToolNames(input) + + if string(result) != string(input) { + t.Errorf("expected glob to remain lowercase, got %s", string(result)) + } +} + +func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) { + input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`) + result := normalizeAmpToolNames(input) + + if string(result) != string(input) { + t.Errorf("expected no modification for unknown tool, got %s", string(result)) + } +} + func contains(data, substr []byte) bool { for i := 0; i <= len(data)-len(substr); i++ { if string(data[i:i+len(substr)]) == string(substr) {