Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions internal/api/modules/amp/response_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {

var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}

// ampCanonicalToolNames maps tool names to the exact casing expected by the
// Amp mode tool whitelist (case-sensitive match).
var ampCanonicalToolNames = map[string]string{
"bash": "Bash",
"read": "Read",
"grep": "Grep",
"glob": "glob",
"task": "Task",
"check": "Check",
}

// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
// which causes Amp's case-sensitive mode whitelist to reject them.
func normalizeAmpToolNames(data []byte) []byte {
// Non-streaming: content[].name in tool_use blocks
for index, block := range gjson.GetBytes(data, "content").Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
path := fmt.Sprintf("content.%d.name", index)
var err error
data, err = sjson.SetBytes(data, path, canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
}
}
}

// Streaming: content_block.name in content_block_start events
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
name := gjson.GetBytes(data, "content_block.name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
Comment on lines +158 to +160
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The code performs redundant JSON parsing by calling gjson.GetBytes twice for the same content_block object. It is more efficient to retrieve the content_block once and then use Get on the resulting gjson.Result to access its fields. This is particularly relevant in streaming mode where this logic is executed for every chunk.

	// Streaming: content_block.name in content_block_start events
	contentBlock := gjson.GetBytes(data, "content_block")
	if contentBlock.Get("type").String() == "tool_use" {
		name := contentBlock.Get("name").String()

var err error
data, err = sjson.SetBytes(data, "content_block.name", canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
}
}
}

return data
}

// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
Expand Down Expand Up @@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {

func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = normalizeAmpToolNames(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
Expand Down Expand Up @@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)

// Normalize tool names to canonical casing
data = normalizeAmpToolNames(data)

// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
Expand Down
51 changes: 51 additions & 0 deletions internal/api/modules/amp/response_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
}
}

func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
result := normalizeAmpToolNames(input)

if !contains(result, []byte(`"name":"Bash"`)) {
t.Errorf("expected bash->Bash, got %s", string(result))
}
if !contains(result, []byte(`"name":"Read"`)) {
t.Errorf("expected read->Read, got %s", string(result))
}
if contains(result, []byte(`"name":"bash"`)) {
t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
}
}

func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
result := normalizeAmpToolNames(input)

if !contains(result, []byte(`"name":"Grep"`)) {
t.Errorf("expected grep->Grep in streaming, got %s", string(result))
}
}

func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
result := normalizeAmpToolNames(input)

if string(result) != string(input) {
t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
}
}

func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
result := normalizeAmpToolNames(input)

if string(result) != string(input) {
t.Errorf("expected glob to remain lowercase, got %s", string(result))
}
}

func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
result := normalizeAmpToolNames(input)

if string(result) != string(input) {
t.Errorf("expected no modification for unknown tool, got %s", string(result))
}
}

func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {
Expand Down
Loading