diff --git a/DESIGN.md b/DESIGN.md
index 69b6759..3753d06 100644
--- a/DESIGN.md
+++ b/DESIGN.md
@@ -79,6 +79,55 @@ The main entry point. `Forward(ctx, req)` orchestrates the full request lifecycl
 - `WithInterceptor(i)` — adds an interceptor to the chain
 - `WithHTTPClient(c)` — sets a custom `*http.Client` for upstream calls
 
+#### AutoRouter
+
+An HTTP handler that provides automatic provider and API type detection from a single endpoint. Implements `http.Handler` for easy integration.
+
+```text
+Forward(ctx, req) -> (resp, meta, err)
+ServeHTTP(w, r)
+```
+
+**Detection Flow:**
+
+1. **Parse body** - Extract model name and request structure
+2. **Detect provider** - From `X-Provider` header, model prefix (`openai/gpt-4`), or model pattern (`gpt-*`)
+3. **Strip provider prefix** - If the model has a known provider prefix, strip it before forwarding
+4. **Detect API type** - From path (`/v1/messages`) or body+provider (`input` → Responses)
+5. **Route to provider** - Forward to the detected provider with the correct endpoint
+
+**Configuration options:**
+
+- `WithAutoRouterRegistry(r)` — Use custom registry
+- `WithAutoRouterDetector(d)` — Custom provider detection logic
+- `WithAutoRouterModelProviderLookup(lookup)` — Hook for model→provider mapping (e.g., models.dev-backed detection); called when header- and pattern-based detection both fail
+- `WithAutoRouterInterceptor(i)` — Add interceptor to chain
+- `WithAutoRouterHTTPClient(c)` — Custom HTTP client
+- `WithAutoRouterFallbackProvider(p)` — Provider used when detection fails
+
+**Example:**
+
+```go
+// Basic setup
+router := llmproxy.NewAutoRouter(
+	llmproxy.WithAutoRouterFallbackProvider(openaiProvider),
+	llmproxy.WithAutoRouterInterceptor(interceptors.NewLogging(logger)),
+)
+router.RegisterProvider(openaiProvider)
+router.RegisterProvider(anthropicProvider)
+
+http.Handle("/", router)
+```
+
+```go
+// With models.dev-backed provider detection
+adapter, _ := modelsdev.LoadFromURL()
+router := llmproxy.NewAutoRouter(
+	llmproxy.WithAutoRouterModelProviderLookup(adapter.FindProviderForModel),
+	llmproxy.WithAutoRouterFallbackProvider(openaiProvider),
+)
+```
+
 ---
 
 ## Data Types
@@ -246,6 +295,105 @@ Steps in detail:
 
 ---
 
+## Auto-Routing
+
+The `AutoRouter` enables automatic provider and API detection from a single endpoint. POST any LLM request to `/` and the router selects the provider, API type, and upstream endpoint automatically.
+
+### API Type Detection
+
+Detection happens in two phases:
+
+**Phase 1: Path-based detection**
+
+| Path | API Type |
+|------|----------|
+| `/v1/chat/completions` | Chat Completions |
+| `/v1/responses` | Responses |
+| `/v1/completions` | Legacy Completions |
+| `/v1/messages` | Anthropic Messages |
+| `:generateContent` | Gemini GenerateContent |
+| `/converse` | Bedrock Converse |
+
+**Phase 2: Body + Provider detection** (when path is `/` or unknown)
+
+| Body Field | Provider | API Type |
+|------------|----------|----------|
+| `input` | any | Responses |
+| `prompt` | any | Completions |
+| `contents` | any | GenerateContent |
+| `messages` | anthropic | Messages |
+| `messages` | bedrock | Converse |
+| `messages` | other | Chat Completions |
+
+### Provider Detection
+
+Provider is detected in priority order:
+
+1. **X-Provider header** — Explicit override
+   ```bash
+   curl -X POST http://localhost:8080/ \
+     -H 'X-Provider: anthropic' \
+     -d '{"model":"claude-3-opus",...}'
+   ```
+
+2. 
**Model prefix** — Provider prefix in model name (stripped before forwarding)
+   ```bash
+   # Model "anthropic/claude-3-opus" routes to Anthropic, forwards "claude-3-opus"
+   curl -X POST http://localhost:8080/ \
+     -d '{"model":"anthropic/claude-3-opus",...}'
+   ```
+
+3. **Model pattern** — Match against known patterns
+   | Pattern | Provider |
+   |---------|----------|
+   | `gpt-*`, `o1-*`, `o3-*`, `o4-*`, `chatgpt-*` | OpenAI |
+   | `claude-*` | Anthropic |
+   | `gemini-*`, `gemma-*` | Google AI |
+   | `grok-*` | x.AI |
+   | `accounts/fireworks/*` | Fireworks |
+   | `sonar*` | Perplexity |
+   | `anthropic.claude-*`, `amazon.*` | Bedrock |
+
+### Provider Prefix Stripping
+
+Only known provider prefixes are stripped:
+
+```go
+// Stripped (known providers)
+"openai/gpt-4" -> "gpt-4"
+"anthropic/claude-3" -> "claude-3"
+"fireworks/models/llama" -> "models/llama"
+
+// Preserved (unknown or model-native paths)
+"accounts/fireworks/models/llama" -> "accounts/fireworks/models/llama"
+"some-unknown/model" -> "some-unknown/model"
+```
+
+### Usage Examples
+
+```bash
+# Auto-detect everything - POST to /
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}'
+
+# Auto-detect Anthropic from model name
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"claude-3-opus","max_tokens":1024,"messages":[{"role":"user","content":"Hello"}]}'
+
+# Auto-detect Responses API from input field
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"gpt-4o","input":"Hello"}'
+
+# Traditional path-based routing still works
+curl -X POST http://localhost:8080/v1/chat/completions \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}'
+```
+
+---
+
 ## Providers
 
 Nine providers are included. Six share the OpenAI-compatible base; three have fully custom implementations.
@@ -254,11 +402,13 @@ Nine providers are included. Six share the OpenAI-compatible base; three have fu
 
 `providers/openai_compatible` implements `BodyParser`, `ResponseExtractor`, `URLResolver`, and `RequestEnricher` for the OpenAI chat completions format. Providers that speak this protocol embed the base and override only what differs (name, base URL, auth configuration).
 
+The OpenAI provider also supports the **Responses API** (`/v1/responses`) with automatic detection based on the `input` field in the request body.
+
 ### Provider Table
 
 | Provider | Package | Auth | API Format | Notes |
 |----------|---------|------|------------|-------|
-| OpenAI | `providers/openai` | Bearer token | OpenAI chat completions | Wraps `openai_compatible` |
+| OpenAI | `providers/openai` | Bearer token | Chat completions, Responses | Supports both APIs with auto-detection |
 | Anthropic | `providers/anthropic` | `x-api-key` header + `anthropic-version` | Anthropic Messages API | Custom parser/extractor |
 | Groq | `providers/groq` | Bearer token | OpenAI-compatible | Wraps `openai_compatible` |
 | Fireworks | `providers/fireworks` | Bearer token | OpenAI-compatible | Wraps `openai_compatible` |
@@ -270,7 +420,15 @@ Nine providers are included. Six share the OpenAI-compatible base; three have fu
 
 ### Provider Details
 
-**OpenAI** — Thin wrapper over `openai_compatible`. Sets the base URL to `https://api.openai.com` and the provider name to `openai`. 
+**OpenAI** — Wraps `openai_compatible` with support for multiple APIs: +- **Chat Completions** (`/v1/chat/completions`) — Standard messages-based API +- **Responses** (`/v1/responses`) — Newer API with `input` field, built-in tools support +- **Legacy Completions** (`/v1/completions`) — Older prompt-based API + +The provider auto-detects the API type from the request body: +- `input` field → Responses API +- `prompt` field → Completions API +- `messages` field → Chat Completions API **Anthropic** — Custom body parser translates between the proxy's canonical format and Anthropic's Messages API. Custom extractor maps Anthropic's response shape (content blocks, stop_reason) back to `ResponseMetadata`. Auth uses the `x-api-key` header alongside an `anthropic-version` header. @@ -357,6 +515,12 @@ retry := interceptors.NewRetryWithRateLimitHeaders(3, time.Second) - Detects the provider from the model name - Computes input/output/cache costs based on token usage - Calls the `onResult` callback with a `BillingResult` after each request +- Stores `BillingResult` in `ResponseMetadata.Custom["billing_result"]` for downstream access + +When using `AutoRouter`, billing results are automatically added as response headers: +- `X-Gateway-Cost` — Total cost in USD +- `X-Gateway-Prompt-Tokens` — Input token count +- `X-Gateway-Completion-Tokens` — Output token count ### Tracing @@ -719,7 +883,10 @@ Matches the signature of `github.com/agentuity/go-common/logger` without requiri ``` llmproxy/ +├── apitype.go # API type detection and constants +├── autorouter.go # AutoRouter, provider/API auto-detection ├── billing.go # CostInfo, CostLookup, BillingResult, CalculateCost +├── detection.go # Provider detection from model/header ├── enricher.go # RequestEnricher interface ├── extractor.go # ResponseExtractor interface ├── interceptor.go # Interceptor, InterceptorChain, RoundTripFunc @@ -749,9 +916,12 @@ llmproxy/ │ ├── fireworks/ # Fireworks (OpenAI-compatible) │ ├── googleai/ # Google AI Gemini │ ├── groq/ # Groq (OpenAI-compatible) -│ ├── openai/ # OpenAI +│ ├── openai/ # OpenAI (Chat Completions + Responses) │ ├── openai_compatible/ # Base for OpenAI-compatible providers +│ │ ├── multiapi.go # Multi-API parser/extractor +│ │ ├── responses_parser.go # Responses API parser +│ │ └── responses_extractor.go # Responses API extractor │ └── xai/ # x.AI (OpenAI-compatible) └── examples/ - └── basic/ # Multi-provider proxy server example + └── basic/ # Multi-provider proxy server example (uses AutoRouter) ``` diff --git a/README.md b/README.md index c4d869e..1219276 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ go get github.com/agentuity/llmproxy ## Quick Start +### Simple Proxy + ```go package main @@ -51,26 +53,118 @@ func main() { } ``` +### AutoRouter (Recommended) + +Single endpoint that auto-detects provider and API type: + +```go +package main + +import ( + "net/http" + + "github.com/agentuity/llmproxy" + "github.com/agentuity/llmproxy/providers/openai" + "github.com/agentuity/llmproxy/providers/anthropic" +) + +func main() { + openaiProvider, _ := openai.New("sk-openai-key") + anthropicProvider, _ := anthropic.New("sk-ant-key") + + router := llmproxy.NewAutoRouter( + llmproxy.WithAutoRouterFallbackProvider(openaiProvider), + ) + router.RegisterProvider(openaiProvider) + router.RegisterProvider(anthropicProvider) + + // Single endpoint handles all providers and APIs + http.Handle("/", router) + http.ListenAndServe(":8080", nil) +} +``` + +POST to `/` with any model - provider and API are 
auto-detected:
+
+```bash
+# Auto-detect OpenAI from gpt-4 model name
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}'
+
+# Auto-detect Anthropic from claude model name
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"claude-3-opus","max_tokens":1024,"messages":[{"role":"user","content":"Hello"}]}'
+
+# Auto-detect Responses API from input field
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"gpt-4o","input":"Hello"}'
+```
+
 ## Features
 
 - **9 Provider Implementations**: OpenAI, Anthropic, Groq, Fireworks, x.AI, Google AI, AWS Bedrock, Azure OpenAI, OpenAI-compatible base
+- **AutoRouter**: Single endpoint with automatic provider/API detection
+- **Responses API**: Full support for OpenAI's new Responses API
 - **8 Built-in Interceptors**: Logging, Metrics, Retry, Billing, Tracing (OTel), HeaderBan, AddHeader, PromptCaching
 - **Pricing Integration**: models.dev adapter with markup support
 - **Prompt Caching**: prompt caching support for Anthropic, OpenAI, xAI, Fireworks, and Bedrock
 - **Raw Body Preservation**: Custom JSON fields pass through unchanged
 
+## AutoRouter
+
+The `AutoRouter` provides automatic routing from a single endpoint:
+
+### Detection Order
+
+1. **Path-based** - `/v1/messages` → Messages API, `/v1/responses` → Responses API
+2. **Body + Provider** - When path is `/` or unknown:
+   - `input` field → Responses API
+   - `prompt` field → Completions API
+   - `contents` field → GenerateContent API
+   - `messages` + Anthropic → Messages API
+   - `messages` + Bedrock → Converse API
+   - `messages` + other → Chat Completions
+
+### Provider Detection
+
+1. **X-Provider header** - Explicit override
+2. **Model prefix** - `openai/gpt-4` → OpenAI (strips prefix before forwarding)
+3. **Model pattern** - `gpt-*` → OpenAI, `claude-*` → Anthropic, etc.
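+
+The detection chain itself is pluggable via `WithAutoRouterDetector`. A minimal sketch using this package's `ProviderDetectorFunc`; the `X-Team-Provider` header is a hypothetical convention, and anything it doesn't match falls through to `DefaultProviderDetector`:
+
+```go
+detector := llmproxy.ProviderDetectorFunc(func(hint llmproxy.ProviderHint) string {
+	// Hypothetical org-specific override header, checked before the defaults.
+	if p := hint.Headers.Get("X-Team-Provider"); p != "" {
+		return p
+	}
+	// Fall back to the built-in header/prefix/pattern detection chain.
+	return llmproxy.DefaultProviderDetector.Detect(hint)
+})
+
+router := llmproxy.NewAutoRouter(
+	llmproxy.WithAutoRouterDetector(detector),
+	llmproxy.WithAutoRouterFallbackProvider(openaiProvider),
+)
+http.Handle("/", router)
+```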
+
+### Examples
+
+```bash
+# Explicit provider via header
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -H 'X-Provider: anthropic' \
+  -d '{"model":"claude-3-opus","max_tokens":1024,"messages":[{"role":"user","content":"Hello"}]}'
+
+# Provider prefix in model (gets stripped)
+curl -X POST http://localhost:8080/ \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"anthropic/claude-3-opus","max_tokens":1024,"messages":[{"role":"user","content":"Hello"}]}'
+
+# Traditional path still works
+curl -X POST http://localhost:8080/v1/chat/completions \
+  -H 'Content-Type: application/json' \
+  -d '{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}'
+```
+
 ## Providers
 
-| Provider     | Auth                  | API Format                     |
-| ------------ | --------------------- | ------------------------------ |
-| OpenAI       | Bearer token          | Chat completions               |
-| Anthropic    | `x-api-key`           | Messages API                   |
-| Groq         | Bearer token          | OpenAI-compatible              |
-| Fireworks    | Bearer token          | OpenAI-compatible              |
-| x.AI         | Bearer token          | OpenAI-compatible              |
-| Google AI    | API key query param   | Gemini generateContent         |
-| AWS Bedrock  | AWS Signature V4      | Converse API                   |
-| Azure OpenAI | `api-key` or Azure AD | Chat completions (deployments) |
+| Provider     | Auth                  | API Format                     | Notes |
+| ------------ | --------------------- | ------------------------------ | ----- |
+| OpenAI       | Bearer token          | Chat completions, Responses    | Supports both `/v1/chat/completions` and `/v1/responses` |
+| Anthropic    | `x-api-key`           | Messages API                   |       |
+| Groq         | Bearer token          | OpenAI-compatible              |       |
+| Fireworks    | Bearer token          | OpenAI-compatible              |       |
+| x.AI         | Bearer token          | OpenAI-compatible              |       |
+| Perplexity   | Bearer token          | OpenAI-compatible              |       |
+| Google AI    | API key query param   | Gemini generateContent         |       |
+| AWS Bedrock  | AWS Signature V4      | Converse API                   |       |
+| Azure OpenAI | `api-key` or Azure AD | Chat completions (deployments) |       |
 
 ## Interceptors
 
diff --git a/apitype.go b/apitype.go
new file mode 100644
index 0000000..0de2b48
--- /dev/null
+++ b/apitype.go
@@ -0,0 +1,105 @@
+package llmproxy
+
+import (
+	"encoding/json"
+)
+
+// APIType identifies the request/response format of an upstream LLM API.
+type APIType string
+
+const (
+	APITypeChatCompletions APIType = "chat_completions"
+	APITypeResponses       APIType = "responses"
+	APITypeCompletions     APIType = "completions"
+	APITypeMessages        APIType = "messages"
+	APITypeGenerateContent APIType = "generate_content"
+	APITypeConverse        APIType = "converse"
+)
+
+// DetectAPIType infers the API type from the request body alone,
+// defaulting to Chat Completions when the body carries no signal.
+func DetectAPIType(body []byte) APIType {
+	var raw map[string]any
+	if err := json.Unmarshal(body, &raw); err != nil {
+		return APITypeChatCompletions
+	}
+
+	if _, hasInput := raw["input"]; hasInput {
+		if _, hasMessages := raw["messages"]; !hasMessages {
+			return APITypeResponses
+		}
+	}
+
+	if _, hasPrompt := raw["prompt"]; hasPrompt {
+		if _, hasMessages := raw["messages"]; !hasMessages {
+			return APITypeCompletions
+		}
+	}
+
+	return APITypeChatCompletions
+}
+
+// DetectAPITypeFromPath maps well-known endpoint suffixes to API types,
+// returning "" when the path is unrecognized.
+func DetectAPITypeFromPath(path string) APIType {
+	switch {
+	case containsPath(path, "/v1/chat/completions"):
+		return APITypeChatCompletions
+	case containsPath(path, "/v1/responses"):
+		return APITypeResponses
+	case containsPath(path, "/v1/completions"):
+		return APITypeCompletions
+	case containsPath(path, "/v1/messages"):
+		return APITypeMessages
+	case containsPath(path, ":generateContent"):
+		return APITypeGenerateContent
+	case containsPath(path, "/converse"):
+		return APITypeConverse
+	default:
+		return ""
+	}
+}
+
+// DetectAPITypeFromBodyAndProvider infers the API type from the body shape,
+// using the provider name to disambiguate `messages`-style payloads.
+func DetectAPITypeFromBodyAndProvider(body []byte, provider string) APIType {
+	var raw map[string]any
+	if err := json.Unmarshal(body, &raw); err != nil {
+		// Body isn't valid JSON; default to the most common API shape.
+		
return APITypeChatCompletions + } + + if _, hasInput := raw["input"]; hasInput { + if _, hasMessages := raw["messages"]; !hasMessages { + return APITypeResponses + } + } + + if _, hasPrompt := raw["prompt"]; hasPrompt { + if _, hasMessages := raw["messages"]; !hasMessages { + return APITypeCompletions + } + } + + if _, hasContents := raw["contents"]; hasContents { + return APITypeGenerateContent + } + + if _, hasMessages := raw["messages"]; hasMessages { + switch provider { + case "anthropic": + return APITypeMessages + case "googleai": + if _, hasContents := raw["contents"]; hasContents { + return APITypeGenerateContent + } + return APITypeMessages + case "bedrock": + return APITypeConverse + } + } + + if _, hasSystem := raw["system"]; hasSystem { + if _, hasMessages := raw["messages"]; hasMessages { + return APITypeMessages + } + } + + return APITypeChatCompletions +} + +func containsPath(path, substr string) bool { + return len(path) >= len(substr) && path[len(path)-len(substr):] == substr +} diff --git a/autorouter.go b/autorouter.go new file mode 100644 index 0000000..5c23378 --- /dev/null +++ b/autorouter.go @@ -0,0 +1,255 @@ +package llmproxy + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +type AutoRouter struct { + registry Registry + detector ProviderDetector + modelProviderLookup ModelProviderLookup + interceptors InterceptorChain + client *http.Client + fallbackProvider Provider +} + +type AutoRouterOption func(*AutoRouter) + +func WithAutoRouterRegistry(r Registry) AutoRouterOption { + return func(a *AutoRouter) { a.registry = r } +} + +func WithAutoRouterDetector(d ProviderDetector) AutoRouterOption { + return func(a *AutoRouter) { a.detector = d } +} + +func WithAutoRouterInterceptor(i Interceptor) AutoRouterOption { + return func(a *AutoRouter) { a.interceptors = append(a.interceptors, i) } +} + +func WithAutoRouterHTTPClient(c *http.Client) AutoRouterOption { + return func(a *AutoRouter) { a.client = c } +} + +func WithAutoRouterFallbackProvider(p Provider) AutoRouterOption { + return func(a *AutoRouter) { a.fallbackProvider = p } +} + +func WithAutoRouterModelProviderLookup(lookup ModelProviderLookup) AutoRouterOption { + return func(a *AutoRouter) { a.modelProviderLookup = lookup } +} + +func NewAutoRouter(opts ...AutoRouterOption) *AutoRouter { + a := &AutoRouter{ + registry: NewRegistry(), + detector: DefaultProviderDetector, + client: http.DefaultClient, + } + for _, opt := range opts { + opt(a) + } + return a +} + +func (a *AutoRouter) RegisterProvider(p Provider) { + a.registry.Register(p) +} + +func (a *AutoRouter) GetProvider(name string) Provider { + p, _ := a.registry.Get(name) + return p +} + +func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Response, ResponseMetadata, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, ResponseMetadata{}, err + } + req.Body.Close() + + var raw map[string]any + var model string + if err := json.Unmarshal(body, &raw); err == nil { + if m, ok := raw["model"].(string); ok { + model = m + } + } + + hint := ProviderHint{ + Model: model, + Headers: req.Header, + } + providerName := a.detector.Detect(hint) + + // If no provider detected and we have a model provider lookup, try that + if providerName == "" && a.modelProviderLookup != nil && model != "" { + providerName = a.modelProviderLookup(model) + } + + var provider Provider + if providerName != "" { + provider, _ = a.registry.Get(providerName) + if provider == nil { + // Explicit 
provider name was provided but not found in registry + return nil, ResponseMetadata{}, ErrNoProvider + } + } else { + // No provider detected, use fallback + provider = a.fallbackProvider + if provider == nil { + return nil, ResponseMetadata{}, ErrNoProvider + } + } + + // Strip provider prefix from model name (e.g., "openai/gpt-4" -> "gpt-4") + if raw != nil { + if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { + raw["model"] = strippedModel + model = strippedModel + var err error + body, err = json.Marshal(raw) + if err != nil { + return nil, ResponseMetadata{}, fmt.Errorf("failed to marshal request body: %w", err) + } + } + } + + // Detect API type: path takes precedence, then body+provider detection + apiType := DetectAPITypeFromPath(req.URL.Path) + if apiType == "" { + apiType = DetectAPITypeFromBodyAndProvider(body, providerName) + } + + meta, _, err := provider.BodyParser().Parse(io.NopCloser(bytes.NewReader(body))) + if err != nil { + return nil, ResponseMetadata{}, err + } + + if meta.Custom == nil { + meta.Custom = make(map[string]any) + } + meta.Custom["api_type"] = apiType + meta.Custom["provider"] = providerName + + upstreamURL, err := provider.URLResolver().Resolve(meta) + if err != nil { + return nil, ResponseMetadata{}, err + } + + upstreamReq, err := http.NewRequestWithContext(ctx, req.Method, upstreamURL.String(), bytes.NewReader(body)) + if err != nil { + return nil, ResponseMetadata{}, err + } + + for k, v := range req.Header { + upstreamReq.Header[k] = v + } + + if err := provider.RequestEnricher().Enrich(upstreamReq, meta, body); err != nil { + return nil, ResponseMetadata{}, err + } + + ctxValue := MetaContextValue{Meta: meta, RawBody: body} + upstreamReq = upstreamReq.WithContext(context.WithValue(upstreamReq.Context(), MetaContextKey{}, ctxValue)) + + chain := a.interceptors + roundTrip := func(req *http.Request) (*http.Response, ResponseMetadata, []byte, error) { + return a.roundTrip(provider, req) + } + + if len(chain) > 0 { + roundTrip = chain.Wrap(roundTrip) + } + + resp, respMeta, rawRespBody, err := roundTrip(upstreamReq) + if err != nil { + return nil, respMeta, err + } + + resp.Body = io.NopCloser(bytes.NewReader(rawRespBody)) + return resp, respMeta, nil +} + +func (a *AutoRouter) roundTrip(provider Provider, req *http.Request) (*http.Response, ResponseMetadata, []byte, error) { + resp, err := a.client.Do(req) + if err != nil { + return nil, ResponseMetadata{}, nil, err + } + + respMeta, rawBody, err := provider.ResponseExtractor().Extract(resp) + if err != nil { + return nil, ResponseMetadata{}, nil, err + } + + return resp, respMeta, rawBody, nil +} + +func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + resp, meta, err := a.Forward(r.Context(), r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + for k, v := range resp.Header { + w.Header()[k] = v + } + + if billing, ok := meta.Custom["billing_result"].(BillingResult); ok { + w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost)) + w.Header().Set("X-Gateway-Prompt-Tokens", fmt.Sprintf("%d", billing.PromptTokens)) + w.Header().Set("X-Gateway-Completion-Tokens", fmt.Sprintf("%d", billing.CompletionTokens)) + } + + w.WriteHeader(resp.StatusCode) + if _, err := w.Write(body); err != nil { + // Headers already sent, can't report error to client + } 
+} + +var ErrNoProvider = &ProviderError{Message: "no provider available for request"} + +type ProviderError struct { + Message string +} + +func (e *ProviderError) Error() string { + return e.Message +} + +var knownProviderPrefixes = map[string]bool{ + "openai": true, + "anthropic": true, + "googleai": true, + "groq": true, + "fireworks": true, + "xai": true, + "perplexity": true, + "bedrock": true, + "azure": true, +} + +func stripProviderPrefix(model string) (stripped string, hasPrefix bool) { + idx := strings.Index(model, "/") + if idx < 0 { + return model, false + } + prefix := model[:idx] + if knownProviderPrefixes[prefix] { + return model[idx+1:], true + } + return model, false +} diff --git a/autorouter_test.go b/autorouter_test.go new file mode 100644 index 0000000..f4cc319 --- /dev/null +++ b/autorouter_test.go @@ -0,0 +1,392 @@ +package llmproxy + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type mockProvider struct { + name string + parseFn func(io.ReadCloser) (BodyMetadata, []byte, error) + enrichFn func(*http.Request, BodyMetadata, []byte) error + resolveFn func(BodyMetadata) (*url.URL, error) + extractFn func(*http.Response) (ResponseMetadata, []byte, error) +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) BodyParser() BodyParser { + return &mockBodyParser{parseFn: m.parseFn} +} +func (m *mockProvider) RequestEnricher() RequestEnricher { + return &mockEnricher{enrichFn: m.enrichFn} +} +func (m *mockProvider) ResponseExtractor() ResponseExtractor { + return &mockExtractor{extractFn: m.extractFn} +} +func (m *mockProvider) URLResolver() URLResolver { + return &mockResolver{resolveFn: m.resolveFn} +} + +type mockBodyParser struct { + parseFn func(io.ReadCloser) (BodyMetadata, []byte, error) +} + +func (m *mockBodyParser) Parse(body io.ReadCloser) (BodyMetadata, []byte, error) { + return m.parseFn(body) +} + +type mockEnricher struct { + enrichFn func(*http.Request, BodyMetadata, []byte) error +} + +func (m *mockEnricher) Enrich(req *http.Request, meta BodyMetadata, body []byte) error { + return m.enrichFn(req, meta, body) +} + +type mockResolver struct { + resolveFn func(BodyMetadata) (*url.URL, error) +} + +func (m *mockResolver) Resolve(meta BodyMetadata) (*url.URL, error) { + return m.resolveFn(meta) +} + +type mockExtractor struct { + extractFn func(*http.Response) (ResponseMetadata, []byte, error) +} + +func (m *mockExtractor) Extract(resp *http.Response) (ResponseMetadata, []byte, error) { + return m.extractFn(resp) +} + +type mockDetector struct{ detectFn func(ProviderHint) string } + +func (m *mockDetector) Detect(hint ProviderHint) string { return m.detectFn(hint) } + +func TestAutoRouter_Forward(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"test","model":"gpt-4","choices":[{"message":{"role":"assistant","content":"Hello"}}]}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "test-provider", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4"}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { + req.Header.Set("Authorization", "Bearer test-key") + return nil + }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return 
ParseURL(upstream.URL + "/v1/chat/completions") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test", Model: "gpt-4"}, body, nil + }, + } + + detector := &mockDetector{ + detectFn: func(hint ProviderHint) string { return "test-provider" }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(detector), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + + resp, meta, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", resp.StatusCode) + } + + if meta.ID != "test" { + t.Errorf("ID = %q, want test", meta.ID) + } +} + +func TestAutoRouter_NoProvider(t *testing.T) { + detector := &mockDetector{ + detectFn: func(hint ProviderHint) string { return "" }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(detector), + ) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown-model"}`))) + + _, _, err := router.Forward(context.Background(), req) + if err == nil { + t.Fatal("Forward() expected error, got nil") + } + if err != ErrNoProvider { + t.Errorf("error = %v, want ErrNoProvider", err) + } +} + +func TestAutoRouter_FallbackProvider(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"fallback"}`)) + })) + defer upstream.Close() + + fallback := &mockProvider{ + name: "fallback", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return ParseURL(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "fallback"}, body, nil + }, + } + + detector := &mockDetector{ + detectFn: func(hint ProviderHint) string { return "" }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(detector), + WithAutoRouterFallbackProvider(fallback), + ) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown"}`))) + + resp, meta, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", resp.StatusCode) + } + + if meta.ID != "fallback" { + t.Errorf("ID = %q, want fallback", meta.ID) + } +} + +func TestAutoRouter_ServeHTTP(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Custom", "value") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"test"}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: 
func(meta BodyMetadata) (*url.URL, error) { + return ParseURL(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{}, body, nil + }, + } + + detector := &mockDetector{ + detectFn: func(hint ProviderHint) string { return "test" }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(detector), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", w.Code) + } + + if w.Header().Get("X-Custom") != "value" { + t.Errorf("X-Custom header = %q, want value", w.Header().Get("X-Custom")) + } + + if w.Body.String() != `{"id":"test"}` { + t.Errorf("Body = %q, want {\"id\":\"test\"}", w.Body.String()) + } +} + +func ParseURL(s string) (*url.URL, error) { + return url.Parse(s) +} + +func TestStripProviderPrefix(t *testing.T) { + tests := []struct { + name string + model string + wantStripped string + wantHasPrefix bool + }{ + {"no prefix", "gpt-4", "gpt-4", false}, + {"openai prefix", "openai/gpt-4", "gpt-4", true}, + {"anthropic prefix", "anthropic/claude-3-opus", "claude-3-opus", true}, + {"googleai prefix", "googleai/gemini-pro", "gemini-pro", true}, + {"groq prefix", "groq/llama-3-70b", "llama-3-70b", true}, + {"fireworks prefix", "fireworks/accounts/fireworks/models/llama", "accounts/fireworks/models/llama", true}, + {"xai prefix", "xai/grok-1", "grok-1", true}, + {"perplexity prefix", "perplexity/sonar-small", "sonar-small", true}, + {"bedrock prefix", "bedrock/anthropic.claude-3", "anthropic.claude-3", true}, + {"azure prefix", "azure/gpt-4-deployment", "gpt-4-deployment", true}, + {"multiple slashes preserved", "openai/gpt-4/turbo", "gpt-4/turbo", true}, + {"empty string", "", "", false}, + {"slash only - not a provider", "/", "/", false}, + {"openai slash at end", "openai/", "", true}, + {"non-provider prefix preserved", "accounts/fireworks/models/llama", "accounts/fireworks/models/llama", false}, + {"unknown prefix", "unknown/model-name", "unknown/model-name", false}, + {"model with slash not stripped", "some/path/to/model", "some/path/to/model", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stripped, hasPrefix := stripProviderPrefix(tt.model) + if stripped != tt.wantStripped { + t.Errorf("stripProviderPrefix(%q) stripped = %q, want %q", tt.model, stripped, tt.wantStripped) + } + if hasPrefix != tt.wantHasPrefix { + t.Errorf("stripProviderPrefix(%q) hasPrefix = %v, want %v", tt.model, hasPrefix, tt.wantHasPrefix) + } + }) + } +} + +func TestAutoRouter_StripsProviderPrefixFromBody(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]any + json.Unmarshal(body, &req) + model := req["model"].(string) + if strings.Contains(model, "/") { + t.Errorf("model sent to upstream contains slash: %q", model) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"test","model":"gpt-4","choices":[]}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4"}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body 
[]byte) error { + req.Header.Set("Authorization", "Bearer test-key") + return nil + }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL + "/v1/chat/completions") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { + return "openai" + })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", resp.StatusCode) + } +} + +func TestAutoRouter_PreservesModelWithoutPrefix(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]any + json.Unmarshal(body, &req) + model := req["model"].(string) + if model != "gpt-4" { + t.Errorf("model sent to upstream = %q, want gpt-4", model) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"test","model":"gpt-4","choices":[]}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "openai", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4"}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL + "/v1/chat/completions") + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test"}, body, nil + }, + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { + return "openai" + })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + + resp, _, err := router.Forward(context.Background(), req) + if err != nil { + t.Fatalf("Forward() error = %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", resp.StatusCode) + } +} diff --git a/detection.go b/detection.go new file mode 100644 index 0000000..879b0ed --- /dev/null +++ b/detection.go @@ -0,0 +1,141 @@ +package llmproxy + +import ( + "net/http" + "strings" +) + +// ProviderHint contains information that can be used to detect the provider. +type ProviderHint struct { + Model string + Headers http.Header +} + +// ProviderDetector determines the upstream provider based on request characteristics. +type ProviderDetector interface { + Detect(hint ProviderHint) string +} + +// ProviderDetectorFunc is a function that implements ProviderDetector. +type ProviderDetectorFunc func(hint ProviderHint) string + +func (f ProviderDetectorFunc) Detect(hint ProviderHint) string { + return f(hint) +} + +// ModelProviderLookup is a function that finds the provider for a given model name. 
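+// It should return a provider name registered with the router, or "" when the model is unknown.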
+// This can be backed by models.dev or another model registry. +type ModelProviderLookup func(model string) string + +// DefaultProviderDetector detects the provider from model name patterns and headers. +// Precedence: X-Provider header > model pattern > other header heuristics +var DefaultProviderDetector = ProviderDetectorFunc(func(hint ProviderHint) string { + // X-Provider header is always an explicit override + if hint.Headers != nil { + if provider := hint.Headers.Get("X-Provider"); provider != "" { + return provider + } + } + + // Model-based detection takes precedence over header heuristics + if hint.Model != "" { + if provider := DetectProviderFromModel(hint.Model); provider != "" { + return provider + } + } + + // Header heuristics as fallback + if hint.Headers != nil { + return detectProviderFromHeaderHeuristics(hint.Headers) + } + + return "" +}) + +func detectProviderFromHeaderHeuristics(headers http.Header) string { + if headers.Get("anthropic-version") != "" || strings.HasPrefix(headers.Get("X-API-Key"), "sk-ant-") { + return "anthropic" + } + + if strings.HasPrefix(headers.Get("Authorization"), "Bearer sk-") { + if strings.Contains(headers.Get("Authorization"), "sk-proj-") { + return "openai" + } + } + + if headers.Get("api-key") != "" { + return "azure" + } + + if strings.HasPrefix(headers.Get("Authorization"), "Bearer gsk_") { + return "groq" + } + + return "" +} + +// DetectProviderFromModel returns the provider name based on model naming patterns. +func DetectProviderFromModel(model string) string { + if model == "" { + return "" + } + + // Check for explicit provider prefix (e.g., "openai/gpt-4", "anthropic/claude-3-opus") + if idx := strings.Index(model, "/"); idx >= 0 { + prefix := model[:idx] + switch prefix { + case "openai", "anthropic", "googleai", "groq", "fireworks", "xai", "perplexity", "bedrock", "azure": + return prefix + } + } + + switch { + case strings.HasPrefix(model, "gpt-"), + strings.HasPrefix(model, "o1-"), + strings.HasPrefix(model, "o3-"), + strings.HasPrefix(model, "o4-"), + strings.HasPrefix(model, "chatgpt-"), + strings.HasPrefix(model, "text-"), + strings.HasPrefix(model, "davinci-"), + strings.HasPrefix(model, "curie-"), + strings.HasPrefix(model, "babbage-"), + strings.HasPrefix(model, "ada-"): + return "openai" + + case strings.HasPrefix(model, "claude-"), + strings.HasPrefix(model, "claude"): + return "anthropic" + + case strings.HasPrefix(model, "gemini-"), + strings.HasPrefix(model, "gemma-"), + strings.HasPrefix(model, "palm-"): + return "googleai" + + case strings.HasPrefix(model, "grok-"): + return "xai" + + case strings.HasPrefix(model, "llama-"), + strings.HasPrefix(model, "mixtral-"), + strings.HasPrefix(model, "mistral-"): + if strings.Contains(model, "groq") { + return "groq" + } + return "openai_compatible" + + case strings.HasPrefix(model, "accounts/fireworks/"), + strings.HasPrefix(model, "fireworks"): + return "fireworks" + + case strings.Contains(model, "sonar"): + return "perplexity" + + case strings.HasPrefix(model, "amazon."), + strings.HasPrefix(model, "anthropic.claude-"), + strings.HasPrefix(model, "meta."), + strings.HasPrefix(model, "cohere."): + return "bedrock" + + default: + return "" + } +} diff --git a/detection_test.go b/detection_test.go new file mode 100644 index 0000000..b4571b2 --- /dev/null +++ b/detection_test.go @@ -0,0 +1,279 @@ +package llmproxy + +import ( + "net/http" + "testing" +) + +func TestDetectAPIType(t *testing.T) { + tests := []struct { + name string + body string + expected APIType + }{ + 
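+		// Each case feeds a raw JSON body to DetectAPIType and checks the inferred API type.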
{ + name: "chat completions with messages", + body: `{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`, + expected: APITypeChatCompletions, + }, + { + name: "responses API with input", + body: `{"model":"gpt-4o","input":"hello world"}`, + expected: APITypeResponses, + }, + { + name: "responses API with instructions", + body: `{"model":"gpt-4o","input":"hello","instructions":"be helpful"}`, + expected: APITypeResponses, + }, + { + name: "legacy completions with prompt", + body: `{"model":"gpt-3.5-turbo-instruct","prompt":"hello"}`, + expected: APITypeCompletions, + }, + { + name: "invalid JSON defaults to chat completions", + body: `invalid`, + expected: APITypeChatCompletions, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DetectAPIType([]byte(tt.body)) + if result != tt.expected { + t.Errorf("DetectAPIType() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestDetectAPITypeFromPath(t *testing.T) { + tests := []struct { + name string + path string + expected APIType + }{ + { + name: "chat completions path", + path: "/v1/chat/completions", + expected: APITypeChatCompletions, + }, + { + name: "responses path", + path: "/v1/responses", + expected: APITypeResponses, + }, + { + name: "legacy completions path", + path: "/v1/completions", + expected: APITypeCompletions, + }, + { + name: "anthropic messages path", + path: "/v1/messages", + expected: APITypeMessages, + }, + { + name: "google generate content path", + path: "/v1/models/gemini-pro:generateContent", + expected: APITypeGenerateContent, + }, + { + name: "bedrock converse path", + path: "/model/model-id/converse", + expected: APITypeConverse, + }, + { + name: "unknown path returns empty", + path: "/unknown", + expected: "", + }, + { + name: "root path returns empty", + path: "/", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DetectAPITypeFromPath(tt.path) + if result != tt.expected { + t.Errorf("DetectAPITypeFromPath() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestDetectProviderFromModel(t *testing.T) { + tests := []struct { + name string + model string + expected string + }{ + {"gpt-4", "gpt-4", "openai"}, + {"gpt-3.5-turbo", "gpt-3.5-turbo", "openai"}, + {"o1-preview", "o1-preview", "openai"}, + {"o3-mini", "o3-mini", "openai"}, + {"chatgpt-4o-latest", "chatgpt-4o-latest", "openai"}, + {"claude-3-opus", "claude-3-opus", "anthropic"}, + {"claude-3-5-sonnet", "claude-3-5-sonnet", "anthropic"}, + {"gemini-pro", "gemini-pro", "googleai"}, + {"gemma-2b", "gemma-2b", "googleai"}, + {"grok-1", "grok-1", "xai"}, + {"fireworks-llama", "accounts/fireworks/models/llama-v3", "fireworks"}, + {"sonar-small", "sonar-small-online", "perplexity"}, + {"bedrock claude", "anthropic.claude-3-sonnet", "bedrock"}, + {"bedrock amazon", "amazon.titan-text-express", "bedrock"}, + {"unknown model", "unknown-model", ""}, + {"empty model", "", ""}, + // Provider prefix tests + {"openai/gpt-4 prefix", "openai/gpt-4", "openai"}, + {"anthropic/claude-3-opus prefix", "anthropic/claude-3-opus", "anthropic"}, + {"googleai/gemini-pro prefix", "googleai/gemini-pro", "googleai"}, + {"groq/llama-3 prefix", "groq/llama-3-70b", "groq"}, + {"fireworks/llama prefix", "fireworks/llama-v3", "fireworks"}, + {"xai/grok prefix", "xai/grok-1", "xai"}, + {"perplexity/sonar prefix", "perplexity/sonar-small", "perplexity"}, + {"bedrock/claude prefix", "bedrock/anthropic.claude-3", "bedrock"}, + {"azure/gpt-4 prefix", "azure/gpt-4", 
"azure"}, + {"unknown/ prefix returns unknown", "unknown/model", ""}, + {"single slash only", "/", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DetectProviderFromModel(tt.model) + if result != tt.expected { + t.Errorf("DetectProviderFromModel(%q) = %q, want %q", tt.model, result, tt.expected) + } + }) + } +} + +func TestDefaultProviderDetector(t *testing.T) { + tests := []struct { + name string + hint ProviderHint + expected string + }{ + { + name: "detect from model", + hint: ProviderHint{Model: "gpt-4"}, + expected: "openai", + }, + { + name: "detect from anthropic model", + hint: ProviderHint{Model: "claude-3-opus"}, + expected: "anthropic", + }, + { + name: "detect from X-Provider header", + hint: ProviderHint{ + Model: "unknown", + Headers: http.Header{"X-Provider": []string{"custom-provider"}}, + }, + expected: "custom-provider", + }, + { + name: "empty hint returns empty", + hint: ProviderHint{}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DefaultProviderDetector.Detect(tt.hint) + if result != tt.expected { + t.Errorf("Detect() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestDetectAPITypeFromBodyAndProvider(t *testing.T) { + tests := []struct { + name string + body string + provider string + expected APIType + }{ + { + name: "openai with messages -> chat completions", + body: `{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`, + provider: "openai", + expected: APITypeChatCompletions, + }, + { + name: "anthropic with messages -> messages API", + body: `{"model":"claude-3-opus","messages":[{"role":"user","content":"hello"}]}`, + provider: "anthropic", + expected: APITypeMessages, + }, + { + name: "anthropic with system and messages -> messages API", + body: `{"model":"claude-3-opus","system":"You are helpful","messages":[{"role":"user","content":"hello"}]}`, + provider: "anthropic", + expected: APITypeMessages, + }, + { + name: "responses API with input", + body: `{"model":"gpt-4o","input":"hello world"}`, + provider: "openai", + expected: APITypeResponses, + }, + { + name: "legacy completions with prompt", + body: `{"model":"gpt-3.5-turbo-instruct","prompt":"hello"}`, + provider: "openai", + expected: APITypeCompletions, + }, + { + name: "googleai with contents -> generateContent", + body: `{"model":"gemini-pro","contents":[{"parts":[{"text":"hello"}]}]}`, + provider: "googleai", + expected: APITypeGenerateContent, + }, + { + name: "groq with messages -> chat completions", + body: `{"model":"llama-3-70b","messages":[{"role":"user","content":"hello"}]}`, + provider: "groq", + expected: APITypeChatCompletions, + }, + { + name: "bedrock with messages -> converse", + body: `{"model":"anthropic.claude-3","messages":[{"role":"user","content":"hello"}]}`, + provider: "bedrock", + expected: APITypeConverse, + }, + { + name: "unknown provider with messages -> chat completions", + body: `{"model":"unknown-model","messages":[{"role":"user","content":"hello"}]}`, + provider: "", + expected: APITypeChatCompletions, + }, + { + name: "system without messages -> chat completions", + body: `{"model":"model","system":"be helpful"}`, + provider: "openai", + expected: APITypeChatCompletions, + }, + { + name: "invalid JSON -> chat completions", + body: `invalid`, + provider: "openai", + expected: APITypeChatCompletions, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DetectAPITypeFromBodyAndProvider([]byte(tt.body), 
tt.provider) + if result != tt.expected { + t.Errorf("DetectAPITypeFromBodyAndProvider() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/examples/basic/main.go b/examples/basic/main.go index 9505913..2d3d051 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -1,25 +1,8 @@ -// Example_basic demonstrates a basic proxy setup with multiple providers. -// Providers are configured from standard environment variables. -// -// Usage: -// -// export OPENAI_API_KEY=sk-your-key -// export MODELS_DEV_JSON=/path/to/models.json # optional, for billing -// go run main.go -// -// With OpenTelemetry tracing: -// -// otelExporter, _ := otlptracehttp.New(ctx) -// tp := tracesdk.NewTracerProvider(tracesdk.WithBatcher(otelExporter)) -// defer tp.Shutdown(ctx) -// otel.SetTracerProvider(tp) -// # Then run the example - traces will be propagated upstream package main import ( "context" "fmt" - "io" "log" "net/http" "os" @@ -40,7 +23,6 @@ import ( "go.opentelemetry.io/otel/trace" ) -// consoleLogger implements llmproxy.Logger using log.Default() type consoleLogger struct { prefix string } @@ -62,18 +44,16 @@ func (l *consoleLogger) Error(msg string, args ...interface{}) { } func main() { - ctx := context.Background() logr := &consoleLogger{prefix: "[llmproxy]"} - registry := llmproxy.NewRegistry() + var providers []llmproxy.Provider - // Register providers from environment variables if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { provider, err := openai.New(apiKey) if err != nil { log.Fatalf("failed to create openai provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: OpenAI") } @@ -82,7 +62,7 @@ func main() { if err != nil { log.Fatalf("failed to create anthropic provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: Anthropic") } @@ -91,7 +71,7 @@ func main() { if err != nil { log.Fatalf("failed to create groq provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: Groq") } @@ -100,7 +80,7 @@ func main() { if err != nil { log.Fatalf("failed to create fireworks provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: Fireworks") } @@ -109,7 +89,7 @@ func main() { if err != nil { log.Fatalf("failed to create xai provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: x.AI") } @@ -118,7 +98,7 @@ func main() { if err != nil { log.Fatalf("failed to create perplexity provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: Perplexity") } @@ -127,13 +107,12 @@ func main() { if err != nil { log.Fatalf("failed to create googleai provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: Google AI") } - // Azure OpenAI if resourceName := os.Getenv("AZURE_OPENAI_RESOURCE"); resourceName != "" { - deploymentID := os.Getenv("AZURE_OPENAI_DEPLOYMENT") // optional, uses model from request if empty + deploymentID := os.Getenv("AZURE_OPENAI_DEPLOYMENT") apiVersion := os.Getenv("AZURE_OPENAI_API_VERSION") if apiVersion == "" { apiVersion = azure.DefaultAPIVersion() @@ -149,75 +128,62 @@ func main() { if err != nil { log.Fatalf("failed to create azure provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: 
Azure OpenAI") } } - // AWS Bedrock requires multiple environment variables if region := os.Getenv("AWS_REGION"); region != "" { accessKeyID := os.Getenv("AWS_ACCESS_KEY_ID") secretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY") - sessionToken := os.Getenv("AWS_SESSION_TOKEN") // Optional + sessionToken := os.Getenv("AWS_SESSION_TOKEN") if accessKeyID != "" && secretAccessKey != "" { provider, err := bedrock.New(region, accessKeyID, secretAccessKey, sessionToken) if err != nil { log.Fatalf("failed to create bedrock provider: %v", err) } - registry.Register(provider) + providers = append(providers, provider) logr.Info("Registered: AWS Bedrock") } } - // Check we have at least one provider - openaiProvider, hasOpenAI := registry.Get("openai") - if !hasOpenAI { - openaiProvider, _ = registry.Get("groq") - } - if openaiProvider == nil { - for _, name := range []string{"anthropic", "fireworks", "xai", "perplexity", "googleai"} { - if p, ok := registry.Get(name); ok { - openaiProvider = p - break - } - } - } - - if openaiProvider == nil { + if len(providers) == 0 { log.Fatal("No providers configured. Set at least one API key environment variable.") } - // Load pricing data from models.dev (optional) var costLookup llmproxy.CostLookup + var modelProviderLookup llmproxy.ModelProviderLookup modelsFile := os.Getenv("MODELS_DEV_JSON") modelsURL := os.Getenv("MODELS_DEV_URL") + var adapter *modelsdev.Adapter + var err error if modelsFile != "" { - // Load from local file - adapter, err := modelsdev.LoadFromFile(modelsFile) + adapter, err = modelsdev.LoadFromFile(modelsFile) if err != nil { log.Printf("Warning: could not load models.dev from file: %v", err) } else { costLookup = adapter.GetCostLookup() + modelProviderLookup = adapter.FindProviderForModel logr.Info("Billing enabled from file: %s", modelsFile) } } else if modelsURL != "" { - // Load from custom URL - adapter := modelsdev.New(modelsdev.WithURL(modelsURL)) + adapter = modelsdev.New(modelsdev.WithURL(modelsURL)) if err := adapter.Load(nil); err != nil { log.Printf("Warning: could not load models.dev from URL: %v", err) } else { costLookup = adapter.GetCostLookup() + modelProviderLookup = adapter.FindProviderForModel logr.Info("Billing enabled from URL: %s", modelsURL) } } else { - // Try to fetch from models.dev directly - adapter, err := modelsdev.LoadFromURL() + adapter, err = modelsdev.LoadFromURL() if err != nil { log.Printf("Warning: could not fetch models.dev: %v (billing disabled)", err) } else { costLookup = adapter.GetCostLookup() + modelProviderLookup = adapter.FindProviderForModel logr.Info("Billing enabled from https://models.dev/api.json") } } @@ -237,173 +203,71 @@ func main() { } }) - // OpenAI-compatible endpoint - http.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - provider := openaiProvider - opts := []llmproxy.ProxyOption{ - llmproxy.WithInterceptor(interceptors.NewRetryWithRateLimitHeaders(3, time.Millisecond*250)), - llmproxy.WithInterceptor(tracingInterceptor), - llmproxy.WithInterceptor(loggingInterceptor), - llmproxy.WithInterceptor(interceptors.NewMetrics(metrics)), - llmproxy.WithInterceptor(interceptors.NewResponseHeaderBan("Openai-Organization", "Openai-Project", "Set-Cookie")), - llmproxy.WithInterceptor(interceptors.NewAddRequestHeader(interceptors.NewHeader("User-Agent", "Agentuity AI Gateway/1.0"))), - llmproxy.WithInterceptor(interceptors.NewAddResponseHeader(interceptors.NewHeader("Server", "Agentuity AI Gateway/1.0"))), - } - if costLookup != nil { - opts = append(opts, 
llmproxy.WithInterceptor(interceptors.NewBilling(costLookup, func(r llmproxy.BillingResult) { - logr.Info("Billing: model=%s tokens=%d/%d cost=$%.6f", r.Model, r.PromptTokens, r.CompletionTokens, r.TotalCost) - w.Header().Set("agentuity-gateway-cost", fmt.Sprintf("%f", r.TotalCost)) - w.Header().Set("agentuity-gateway-prompt-tokens", fmt.Sprintf("%d", r.PromptTokens)) - w.Header().Set("agentuity-gateway-completion-tokens", fmt.Sprintf("%d", r.CompletionTokens)) - }))) - } - proxy := llmproxy.NewProxy(provider, opts...) - resp, _, err := proxy.Forward(ctx, r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - for k, v := range resp.Header { - w.Header()[k] = v - } - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) - }) - - // Anthropic endpoint - anthropicProvider, hasAnthropic := registry.Get("anthropic") - if hasAnthropic { - http.HandleFunc("/v1/messages", func(w http.ResponseWriter, r *http.Request) { - opts := []llmproxy.ProxyOption{ - llmproxy.WithInterceptor(tracingInterceptor), - llmproxy.WithInterceptor(loggingInterceptor), - } - if costLookup != nil { - opts = append(opts, llmproxy.WithInterceptor(interceptors.NewBilling(costLookup, func(r llmproxy.BillingResult) { - logr.Info("Billing: model=%s tokens=%d/%d cost=$%.6f", r.Model, r.PromptTokens, r.CompletionTokens, r.TotalCost) - }))) - } - proxy := llmproxy.NewProxy(anthropicProvider, opts...) - resp, _, err := proxy.Forward(ctx, r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - for k, v := range resp.Header { - w.Header()[k] = v - } - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) - }) + opts := []llmproxy.AutoRouterOption{ + llmproxy.WithAutoRouterInterceptor(interceptors.NewRetryWithRateLimitHeaders(3, time.Millisecond*250)), + llmproxy.WithAutoRouterInterceptor(tracingInterceptor), + llmproxy.WithAutoRouterInterceptor(loggingInterceptor), + llmproxy.WithAutoRouterInterceptor(interceptors.NewMetrics(metrics)), + llmproxy.WithAutoRouterInterceptor(interceptors.NewResponseHeaderBan("Openai-Organization", "Openai-Project", "Set-Cookie")), + llmproxy.WithAutoRouterInterceptor(interceptors.NewAddRequestHeader(interceptors.NewHeader("User-Agent", "Agentuity AI Gateway/1.0"))), + llmproxy.WithAutoRouterInterceptor(interceptors.NewAddResponseHeader(interceptors.NewHeader("Server", "Agentuity AI Gateway/1.0"))), + llmproxy.WithAutoRouterFallbackProvider(providers[0]), } - // Google AI endpoint - googleaiProvider, hasGoogleAI := registry.Get("googleai") - if hasGoogleAI { - http.HandleFunc("/v1beta/models/", func(w http.ResponseWriter, r *http.Request) { - opts := []llmproxy.ProxyOption{ - llmproxy.WithInterceptor(tracingInterceptor), - llmproxy.WithInterceptor(loggingInterceptor), - } - if costLookup != nil { - opts = append(opts, llmproxy.WithInterceptor(interceptors.NewBilling(costLookup, func(r llmproxy.BillingResult) { - logr.Info("Billing: model=%s tokens=%d/%d cost=$%.6f", r.Model, r.PromptTokens, r.CompletionTokens, r.TotalCost) - }))) - } - proxy := llmproxy.NewProxy(googleaiProvider, opts...) 
-			resp, _, err := proxy.Forward(ctx, r)
-			if err != nil {
-				http.Error(w, err.Error(), http.StatusInternalServerError)
-				return
-			}
-			defer resp.Body.Close()
-
-			for k, v := range resp.Header {
-				w.Header()[k] = v
-			}
-			w.WriteHeader(resp.StatusCode)
-			io.Copy(w, resp.Body)
-		})
+	if modelProviderLookup != nil {
+		opts = append(opts, llmproxy.WithAutoRouterModelProviderLookup(modelProviderLookup))
 	}
 
-	// Azure OpenAI endpoint
-	azureProvider, hasAzure := registry.Get("azure")
-	if hasAzure {
-		http.HandleFunc("/azure/openai/deployments/", func(w http.ResponseWriter, r *http.Request) {
-			opts := []llmproxy.ProxyOption{
-				llmproxy.WithInterceptor(tracingInterceptor),
-				llmproxy.WithInterceptor(loggingInterceptor),
-			}
-			if costLookup != nil {
-				opts = append(opts, llmproxy.WithInterceptor(interceptors.NewBilling(costLookup, func(r llmproxy.BillingResult) {
-					logr.Info("Billing: model=%s tokens=%d/%d cost=$%.6f", r.Model, r.PromptTokens, r.CompletionTokens, r.TotalCost)
-				})))
-			}
-			proxy := llmproxy.NewProxy(azureProvider, opts...)
-			resp, _, err := proxy.Forward(ctx, r)
-			if err != nil {
-				http.Error(w, err.Error(), http.StatusInternalServerError)
-				return
-			}
-			defer resp.Body.Close()
-
-			for k, v := range resp.Header {
-				w.Header()[k] = v
-			}
-			w.WriteHeader(resp.StatusCode)
-			io.Copy(w, resp.Body)
-		})
+	if costLookup != nil {
+		opts = append(opts, llmproxy.WithAutoRouterInterceptor(interceptors.NewBilling(costLookup, func(r llmproxy.BillingResult) {
+			logr.Info("Billing: provider=%s model=%s tokens=%d/%d cost=$%.6f", r.Provider, r.Model, r.PromptTokens, r.CompletionTokens, r.TotalCost)
+		})))
 	}
 
-	logr.Info("Proxy listening on :8080")
-	logr.Info("Endpoints:")
-	logr.Info("  POST /v1/chat/completions -> OpenAI-compatible providers")
-	if hasAnthropic {
-		logr.Info("  POST /v1/messages -> Anthropic")
-	}
-	if hasGoogleAI {
-		logr.Info("  POST /v1beta/models/{model}:generateContent -> Google AI")
-	}
-	if hasAzure {
-		logr.Info("  POST /azure/openai/deployments/{deployment}/chat/completions -> Azure OpenAI")
+	router := llmproxy.NewAutoRouter(opts...)
+
+	for _, p := range providers {
+		router.RegisterProvider(p)
 	}
 
-	// Bedrock endpoint
-	bedrockProvider, hasBedrock := registry.Get("bedrock")
-	if hasBedrock {
-		http.HandleFunc("/model/", func(w http.ResponseWriter, r *http.Request) {
-			// Extract model ID from path: /model/{modelId}/converse or /model/{modelId}/invoke
-			opts := []llmproxy.ProxyOption{
-				llmproxy.WithInterceptor(tracingInterceptor),
-				llmproxy.WithInterceptor(loggingInterceptor),
-			}
-			if costLookup != nil {
-				opts = append(opts, llmproxy.WithInterceptor(interceptors.NewBilling(costLookup, func(r llmproxy.BillingResult) {
-					logr.Info("Billing: model=%s tokens=%d/%d cost=$%.6f", r.Model, r.PromptTokens, r.CompletionTokens, r.TotalCost)
-				})))
-			}
-			proxy := llmproxy.NewProxy(bedrockProvider, opts...)
-			resp, _, err := proxy.Forward(ctx, r)
-			if err != nil {
-				http.Error(w, err.Error(), http.StatusInternalServerError)
-				return
-			}
-			defer resp.Body.Close()
+	http.Handle("/", router)
 
-			for k, v := range resp.Header {
-				w.Header()[k] = v
-			}
-			w.WriteHeader(resp.StatusCode)
-			io.Copy(w, resp.Body)
-		})
-		if hasBedrock {
-			logr.Info("  POST /model/{modelId}/converse -> AWS Bedrock (Converse API)")
-		}
+	logr.Info("Proxy listening on :8080")
+	logr.Info("")
+	logr.Info("Auto-routing enabled - POST to any endpoint or just '/'")
+	logr.Info("Provider detected from:")
+	logr.Info("  1. X-Provider header (explicit override)")
+	logr.Info("  2. Model name pattern (gpt-* -> OpenAI, claude-* -> Anthropic, etc.)")
+	if modelProviderLookup != nil {
+		logr.Info("  3. models.dev registry (fallback for unknown models)")
 	}
+	logr.Info("")
+	logr.Info("API type detected from:")
+	logr.Info("  1. Request path (/v1/messages, /v1/responses, etc.)")
+	logr.Info("  2. Request body shape (input -> Responses, messages -> Chat/Messages)")
+	logr.Info("")
+	logr.Info("Supported endpoints (all optional - POST to / works too):")
+	logr.Info("  POST /                    (auto-detect from body)")
+	logr.Info("  POST /v1/chat/completions (OpenAI Chat Completions API)")
+	logr.Info("  POST /v1/responses        (OpenAI Responses API)")
+	logr.Info("  POST /v1/messages         (Anthropic Messages API)")
+	logr.Info("  POST /v1/completions      (Legacy OpenAI Completions API)")
+	logr.Info("")
+	logr.Info("Example requests:")
+	logr.Info("  # Auto-detect everything - POST to /")
+	logr.Info("  curl -X POST http://localhost:8080/ \\")
+	logr.Info("    -H 'Content-Type: application/json' \\")
+	logr.Info("    -d '{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}'")
+	logr.Info("")
+	logr.Info("  # Auto-detect Anthropic from model name")
+	logr.Info("  curl -X POST http://localhost:8080/ \\")
+	logr.Info("    -H 'Content-Type: application/json' \\")
+	logr.Info("    -d '{\"model\":\"claude-3-opus\",\"max_tokens\":1024,\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}'")
+	logr.Info("")
+	logr.Info("  # Use Responses API with input field")
+	logr.Info("  curl -X POST http://localhost:8080/ \\")
+	logr.Info("    -H 'Content-Type: application/json' \\")
+	logr.Info("    -d '{\"model\":\"gpt-4o\",\"input\":\"Hello\"}'")
 
 	if err := http.ListenAndServe(":8080", nil); err != nil {
 		log.Fatalf("server error: %v", err)
diff --git a/interceptors/billing.go b/interceptors/billing.go
index 56d8419..8e092e8 100644
--- a/interceptors/billing.go
+++ b/interceptors/billing.go
@@ -2,7 +2,6 @@ package interceptors
 
 import (
 	"net/http"
-	"strings"
 
 	"github.com/agentuity/llmproxy"
 )
@@ -25,8 +24,16 @@ func (i *BillingInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMeta
 		return resp, respMeta, rawRespBody, err
 	}
 
-	// Try to get provider name from model prefix
-	provider := detectProvider(meta.Model)
+	// Prefer router-resolved provider from metadata, fall back to model detection
+	var provider string
+	if meta.Custom != nil {
+		if p, ok := meta.Custom["provider"].(string); ok && p != "" {
+			provider = p
+		}
+	}
+	if provider == "" {
+		provider = llmproxy.DetectProviderFromModel(meta.Model)
+	}
 
 	// Look up pricing with provider first
 	costInfo, found := i.Lookup(provider, meta.Model)
@@ -35,7 +42,7 @@
 		costInfo, found = i.Lookup("", meta.Model)
 	}
 
-	if found && i.OnResult != nil {
+	if found {
 		// Extract cache usage from response metadata if available
 		var cacheUsage *llmproxy.CacheUsage
 		if cu, ok := respMeta.Custom["cache_usage"]; ok {
@@ -44,28 +51,18 @@
 			}
 		}
 		result := llmproxy.CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage)
-		i.OnResult(result)
+		if respMeta.Custom == nil {
+			respMeta.Custom = make(map[string]any)
+		}
+		respMeta.Custom["billing_result"] = result
+		if i.OnResult != nil {
+			i.OnResult(result)
+		}
 	}
 
 	return resp, respMeta, rawRespBody, nil
 }
 
-// detectProvider attempts to determine the provider from the model name.
-func detectProvider(model string) string {
-	modelLower := strings.ToLower(model)
-	switch {
-	case strings.Contains(modelLower, "gpt-") || strings.Contains(modelLower, "o1-") || strings.Contains(modelLower, "o3-") || strings.Contains(modelLower, "chatgpt"):
-		return "openai"
-	case strings.Contains(modelLower, "claude"):
-		return "anthropic"
-	case strings.Contains(modelLower, "gemini"):
-		return "google"
-	case strings.Contains(modelLower, "llama") || strings.Contains(modelLower, "mixtral"):
-		return "groq"
-	}
-	return ""
-}
-
 // NewBilling creates a new billing interceptor with the given lookup function.
 //
 // Example:
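Because the interceptor now stashes the computed result in `respMeta.Custom["billing_result"]`, callers can surface cost data after `Forward` returns instead of capturing the `ResponseWriter` in an `OnResult` closure. A minimal sketch (the helper name is hypothetical; the `BillingResult` fields match those used elsewhere in this patch):

```go
package gateway

import (
	"fmt"
	"net/http"

	"github.com/agentuity/llmproxy"
)

// writeBillingHeaders is a hypothetical helper: it reads the billing result
// the interceptor stored in response metadata and mirrors it as headers.
func writeBillingHeaders(w http.ResponseWriter, respMeta llmproxy.ResponseMetadata) {
	if respMeta.Custom == nil {
		return
	}
	if br, ok := respMeta.Custom["billing_result"].(llmproxy.BillingResult); ok {
		w.Header().Set("agentuity-gateway-cost", fmt.Sprintf("%f", br.TotalCost))
		w.Header().Set("agentuity-gateway-prompt-tokens", fmt.Sprintf("%d", br.PromptTokens))
		w.Header().Set("agentuity-gateway-completion-tokens", fmt.Sprintf("%d", br.CompletionTokens))
	}
}
```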
diff --git a/interceptors/billing_test.go b/interceptors/billing_test.go
index c8f6470..47bf7e0 100644
--- a/interceptors/billing_test.go
+++ b/interceptors/billing_test.go
@@ -281,18 +281,18 @@ func TestDetectProvider(t *testing.T) {
 		{"chatgpt-4o", "openai"},
 		{"claude-3-opus", "anthropic"},
 		{"claude-3-sonnet", "anthropic"},
-		{"gemini-pro", "google"},
-		{"gemini-1.5-flash", "google"},
-		{"llama-3-70b", "groq"},
-		{"mixtral-8x7b", "groq"},
+		{"gemini-pro", "googleai"},
+		{"gemini-1.5-flash", "googleai"},
+		{"llama-3-70b", "openai_compatible"},
+		{"mixtral-8x7b", "openai_compatible"},
 		{"unknown-model", ""},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.model, func(t *testing.T) {
-			got := detectProvider(tt.model)
+			got := llmproxy.DetectProviderFromModel(tt.model)
 			if got != tt.expected {
-				t.Errorf("detectProvider(%q) = %q, want %q", tt.model, got, tt.expected)
+				t.Errorf("DetectProviderFromModel(%q) = %q, want %q", tt.model, got, tt.expected)
 			}
 		})
 	}
diff --git a/pricing/modelsdev/adapter.go b/pricing/modelsdev/adapter.go
index a08e50a..5600e53 100644
--- a/pricing/modelsdev/adapter.go
+++ b/pricing/modelsdev/adapter.go
@@ -186,6 +186,37 @@ func (a *Adapter) Lookup(provider string, model string) (llmproxy.CostInfo, bool
 	return llmproxy.CostInfo{}, false
 }
 
+// FindProviderForModel searches all providers to find which one has the given model.
+// Returns the provider ID if found, or empty string if not found.
+// This is useful for provider detection when the model name is known but provider is not.
+//
+// If data is not loaded or TTL expired, it loads automatically.
+// Note: Auto-load uses context.Background() - call Load() explicitly
+// to control the context for initial load.
+func (a *Adapter) FindProviderForModel(model string) string {
+	a.mu.RLock()
+	needLoad := len(a.data) == 0 || (!a.expires.IsZero() && time.Now().After(a.expires))
+	a.mu.RUnlock()
+
+	if needLoad {
+		_ = a.Load(context.Background())
+	}
+
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	if len(a.data) == 0 {
+		return ""
+	}
+
+	for providerID, provider := range a.data {
+		if _, exists := provider.Models[model]; exists {
+			return providerID
+		}
+	}
+	return ""
+}
+
 // GetCostLookup returns a CostLookup function for use with interceptors.
 //
 // Example:
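A minimal usage sketch for the new lookup (network access to models.dev assumed; the model names are illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/agentuity/llmproxy/pricing/modelsdev"
)

func main() {
	// LoadFromURL fetches https://models.dev/api.json, so this needs network access.
	adapter, err := modelsdev.LoadFromURL()
	if err != nil {
		log.Fatal(err)
	}
	// Prints the provider ID that lists the model, e.g. "anthropic",
	// or an empty string when no provider carries it.
	fmt.Println(adapter.FindProviderForModel("claude-3-opus"))
	fmt.Println(adapter.FindProviderForModel("totally-unknown-model") == "")
}
```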
diff --git a/providers/openai/provider.go b/providers/openai/provider.go
index 3e0aea8..f163c6c 100644
--- a/providers/openai/provider.go
+++ b/providers/openai/provider.go
@@ -20,5 +20,5 @@ import (
 //
 //	provider, _ := openai.New("sk-your-openai-api-key")
 func New(apiKey string) (*openai_compatible.Provider, error) {
-	return openai_compatible.New("openai", apiKey, "https://api.openai.com")
+	return openai_compatible.NewMultiAPI("openai", apiKey, "https://api.openai.com")
 }
diff --git a/providers/openai_compatible/multiapi.go b/providers/openai_compatible/multiapi.go
new file mode 100644
index 0000000..8b65cb0
--- /dev/null
+++ b/providers/openai_compatible/multiapi.go
@@ -0,0 +1,78 @@
+package openai_compatible
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"net/http"
+
+	"github.com/agentuity/llmproxy"
+)
+
+// MultiAPIParser routes each request body to the Responses or Chat Completions
+// parser based on the detected API type.
+type MultiAPIParser struct {
+	chatCompletionsParser *Parser
+	responsesParser       *ResponsesParser
+}
+
+func NewMultiAPIParser() *MultiAPIParser {
+	return &MultiAPIParser{
+		chatCompletionsParser: &Parser{},
+		responsesParser:       &ResponsesParser{},
+	}
+}
+
+func (p *MultiAPIParser) Parse(body io.ReadCloser) (llmproxy.BodyMetadata, []byte, error) {
+	data, err := io.ReadAll(body)
+	if err != nil {
+		return llmproxy.BodyMetadata{}, nil, err
+	}
+	body.Close()
+
+	apiType := llmproxy.DetectAPIType(data)
+	switch apiType {
+	case llmproxy.APITypeResponses:
+		return p.responsesParser.Parse(io.NopCloser(bytes.NewReader(data)))
+	default:
+		return p.chatCompletionsParser.Parse(io.NopCloser(bytes.NewReader(data)))
+	}
+}
+
+// MultiAPIExtractor picks the matching response extractor by inspecting
+// fields that are specific to each response format.
+type MultiAPIExtractor struct {
+	chatCompletionsExtractor *Extractor
+	responsesExtractor       *ResponsesExtractor
+}
+
+func NewMultiAPIExtractor() *MultiAPIExtractor {
+	return &MultiAPIExtractor{
+		chatCompletionsExtractor: &Extractor{},
+		responsesExtractor:       &ResponsesExtractor{},
+	}
+}
+
+func (e *MultiAPIExtractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []byte, error) {
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return llmproxy.ResponseMetadata{}, nil, err
+	}
+	resp.Body.Close()
+
+	// Detect response type by inspecting response-specific fields:
+	// Responses API has "output" and "status", Chat Completions has "choices".
+	var raw map[string]any
+	isResponsesAPI := false
+	if err := json.Unmarshal(body, &raw); err == nil {
+		if _, hasOutput := raw["output"]; hasOutput {
+			if _, hasChoices := raw["choices"]; !hasChoices {
+				isResponsesAPI = true
+			}
+		}
+	}
+
+	// Restore body for downstream extractors
+	resp.Body = io.NopCloser(bytes.NewReader(body))
+
+	if isResponsesAPI {
+		return e.responsesExtractor.Extract(resp)
+	}
+	return e.chatCompletionsExtractor.Extract(resp)
+}
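To see the dispatch in action, the sketch below feeds both body shapes through one `MultiAPIParser`; behavior follows the `DetectAPIType` switch above (error handling kept minimal for brevity):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"

	"github.com/agentuity/llmproxy"
	"github.com/agentuity/llmproxy/providers/openai_compatible"
)

func parse(p *openai_compatible.MultiAPIParser, body string) llmproxy.BodyMetadata {
	meta, _, err := p.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
	if err != nil {
		log.Fatal(err)
	}
	return meta
}

func main() {
	p := openai_compatible.NewMultiAPIParser()

	// An "input" field routes through the Responses parser, which tags the metadata.
	meta := parse(p, `{"model":"gpt-4o","input":"Hi"}`)
	fmt.Println(meta.Custom["api_type"] == llmproxy.APITypeResponses) // true

	// A "messages" body takes the default Chat Completions path.
	meta = parse(p, `{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}`)
	fmt.Println(meta.Model) // gpt-4
}
```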
diff --git a/providers/openai_compatible/provider.go b/providers/openai_compatible/provider.go
index 30e0927..6830f04 100644
--- a/providers/openai_compatible/provider.go
+++ b/providers/openai_compatible/provider.go
@@ -36,6 +36,22 @@ func New(name, apiKey, baseURL string) (*Provider, error) {
 	}, nil
 }
 
+// NewMultiAPI creates a Provider that accepts both Chat Completions and
+// Responses requests, detecting the API type from each request body.
+func NewMultiAPI(name, apiKey, baseURL string) (*Provider, error) {
+	resolver, err := NewResolver(baseURL)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Provider{
+		BaseProvider: llmproxy.NewBaseProvider(name,
+			llmproxy.WithBodyParser(NewMultiAPIParser()),
+			llmproxy.WithRequestEnricher(NewEnricher(apiKey)),
+			llmproxy.WithResponseExtractor(NewMultiAPIExtractor()),
+			llmproxy.WithURLResolver(resolver),
+		),
+	}, nil
+}
+
 // NewWithProvider creates a Provider that wraps an existing BaseProvider.
 // Use this when you need to customize individual components before creating the provider.
 //
diff --git a/providers/openai_compatible/resolver.go b/providers/openai_compatible/resolver.go
index 0030c22..2234b2c 100644
--- a/providers/openai_compatible/resolver.go
+++ b/providers/openai_compatible/resolver.go
@@ -6,22 +6,31 @@ import (
 	"github.com/agentuity/llmproxy"
 )
 
-// Resolver implements llmproxy.URLResolver for OpenAI-compatible APIs.
-// It constructs the chat completions endpoint URL from a base URL.
+// Resolver implements llmproxy.URLResolver for OpenAI-compatible APIs,
+// selecting the endpoint that matches the request's API type.
 type Resolver struct {
-	// BaseURL is the provider's API base URL (e.g., "https://api.openai.com").
 	BaseURL *url.URL
+	// APIType, when non-empty, pins every request to a single endpoint.
+	APIType llmproxy.APIType
 }
 
-// Resolve returns the full URL for the chat completions endpoint.
-// It appends "/v1/chat/completions" to the base URL.
+// Resolve returns the endpoint URL for the API type carried in the body
+// metadata, defaulting to "/v1/chat/completions".
 func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) {
-	endpoint := r.BaseURL.JoinPath("v1", "chat", "completions")
-	return endpoint, nil
+	apiType := r.APIType
+	if apiType == "" {
+		if v, ok := meta.Custom["api_type"].(llmproxy.APIType); ok {
+			apiType = v
+		} else {
+			apiType = llmproxy.APITypeChatCompletions
+		}
+	}
+
+	switch apiType {
+	case llmproxy.APITypeResponses:
+		return r.BaseURL.JoinPath("v1", "responses"), nil
+	case llmproxy.APITypeCompletions:
+		return r.BaseURL.JoinPath("v1", "completions"), nil
+	default:
+		return r.BaseURL.JoinPath("v1", "chat", "completions"), nil
+	}
 }
 
-// NewResolver creates a new resolver with the given base URL.
-// The baseURL should be the provider's API domain (e.g., "https://api.openai.com").
+// NewResolver creates a resolver that detects the API type per request.
 func NewResolver(baseURL string) (*Resolver, error) {
 	u, err := url.Parse(baseURL)
 	if err != nil {
@@ -29,3 +38,11 @@ func NewResolver(baseURL string) (*Resolver, error) {
 	}
 	return &Resolver{BaseURL: u}, nil
 }
+
+// NewResolverWithAPIType creates a resolver pinned to one API type,
+// ignoring per-request metadata.
+func NewResolverWithAPIType(baseURL string, apiType llmproxy.APIType) (*Resolver, error) {
+	u, err := url.Parse(baseURL)
+	if err != nil {
+		return nil, err
+	}
+	return &Resolver{BaseURL: u, APIType: apiType}, nil
+}
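The two constructors trade flexibility for predictability: `NewResolver` follows the `api_type` that each parser records in body metadata, while `NewResolverWithAPIType` always targets one endpoint. A small sketch of both behaviors (URLs illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/agentuity/llmproxy"
	"github.com/agentuity/llmproxy/providers/openai_compatible"
)

func main() {
	// Per-request detection: the resolver reads api_type from body metadata.
	detecting, err := openai_compatible.NewResolver("https://api.openai.com")
	if err != nil {
		log.Fatal(err)
	}
	u, _ := detecting.Resolve(llmproxy.BodyMetadata{
		Custom: map[string]any{"api_type": llmproxy.APITypeResponses},
	})
	fmt.Println(u) // https://api.openai.com/v1/responses

	// Pinned: NewResolverWithAPIType ignores metadata and keeps one endpoint.
	pinned, err := openai_compatible.NewResolverWithAPIType("https://api.openai.com", llmproxy.APITypeCompletions)
	if err != nil {
		log.Fatal(err)
	}
	u, _ = pinned.Resolve(llmproxy.BodyMetadata{})
	fmt.Println(u) // https://api.openai.com/v1/completions
}
```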
diff --git a/providers/openai_compatible/responses_extractor.go b/providers/openai_compatible/responses_extractor.go
new file mode 100644
index 0000000..a3b17c3
--- /dev/null
+++ b/providers/openai_compatible/responses_extractor.go
@@ -0,0 +1,135 @@
+package openai_compatible
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+
+	"github.com/agentuity/llmproxy"
+)
+
+// ResponsesExtractor normalizes OpenAI Responses API payloads into
+// llmproxy.ResponseMetadata.
+type ResponsesExtractor struct{}
+
+func (e *ResponsesExtractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []byte, error) {
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return llmproxy.ResponseMetadata{}, nil, err
+	}
+
+	var responsesResp ResponsesResponse
+	if err := json.Unmarshal(body, &responsesResp); err != nil {
+		return llmproxy.ResponseMetadata{}, nil, err
+	}
+
+	meta := llmproxy.ResponseMetadata{
+		ID:     responsesResp.ID,
+		Object: responsesResp.Object,
+		Model:  responsesResp.Model,
+		Usage: llmproxy.Usage{
+			PromptTokens:     responsesResp.Usage.InputTokens,
+			CompletionTokens: responsesResp.Usage.OutputTokens,
+			TotalTokens:      responsesResp.Usage.TotalTokens,
+		},
+		Custom: make(map[string]any),
+	}
+
+	if responsesResp.Usage.InputTokensDetails != nil && responsesResp.Usage.InputTokensDetails.CachedTokens > 0 {
+		meta.Custom["cache_usage"] = llmproxy.CacheUsage{
+			CachedTokens: responsesResp.Usage.InputTokensDetails.CachedTokens,
+		}
+	}
+
+	if len(responsesResp.Output) > 0 {
+		content := extractResponsesContent(responsesResp.Output)
+		meta.Choices = []llmproxy.Choice{
+			{
+				Index:        0,
+				Message:      &llmproxy.Message{Role: "assistant", Content: content},
+				FinishReason: responsesResp.Status,
+			},
+		}
+	}
+
+	meta.Custom["status"] = responsesResp.Status
+	meta.Custom["api_type"] = llmproxy.APITypeResponses
+	if responsesResp.Error != nil {
+		meta.Custom["error"] = responsesResp.Error
+	}
+
+	return meta, body, nil
+}
+
+// extractResponsesContent concatenates the output_text segments of all
+// "message" output items.
+func extractResponsesContent(output []ResponsesOutputItem) string {
+	var texts []string
+	for _, item := range output {
+		if item.Type == "message" {
+			for _, c := range item.Content {
+				if c.Type == "output_text" && c.Text != "" {
+					texts = append(texts, c.Text)
+				}
+			}
+		}
+	}
+	if len(texts) == 0 {
+		return ""
+	}
+	if len(texts) == 1 {
+		return texts[0]
+	}
+	// Join multiple text segments with newline
+	result := texts[0]
+	for _, t := range texts[1:] {
+		result += "\n" + t
+	}
+	return result
+}
+
+// ResponsesResponse mirrors the top-level OpenAI Responses API payload.
+type ResponsesResponse struct {
+	ID      string                `json:"id"`
+	Object  string                `json:"object"`
+	Created int64                 `json:"created"`
+	Model   string                `json:"model"`
+	Status  string                `json:"status"`
+	Output  []ResponsesOutputItem `json:"output"`
+	Usage   ResponsesUsage        `json:"usage"`
+	Error   *ResponsesError       `json:"error,omitempty"`
+}
+
+type ResponsesOutputItem struct {
+	ID      string                   `json:"id"`
+	Type    string                   `json:"type"`
+	Status  string                   `json:"status"`
+	Role    string                   `json:"role,omitempty"`
+	Content []ResponsesOutputContent `json:"content,omitempty"`
+}
+
+type ResponsesOutputContent struct {
+	Type string `json:"type"`
+	Text string `json:"text,omitempty"`
+}
+
+type ResponsesUsage struct {
+	InputTokens         int                     `json:"input_tokens"`
+	OutputTokens        int                     `json:"output_tokens"`
+	TotalTokens         int                     `json:"total_tokens"`
+	InputTokensDetails  *ResponsesInputDetails  `json:"input_tokens_details,omitempty"`
+	OutputTokensDetails *ResponsesOutputDetails `json:"output_tokens_details,omitempty"`
+}
+
+type ResponsesInputDetails struct {
+	CachedTokens int `json:"cached_tokens,omitempty"`
+}
+
+type ResponsesOutputDetails struct {
+	ReasoningTokens int `json:"reasoning_tokens,omitempty"`
+}
+
+type ResponsesError struct {
+	Type    string `json:"type"`
+	Code    string `json:"code"`
+	Message string `json:"message"`
+}
+
+func NewResponsesExtractor() *ResponsesExtractor {
+	return &ResponsesExtractor{}
+}
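The point of the normalization is that a Responses payload surfaces through the same metadata shape as a chat completion, so downstream interceptors (billing, metrics) need no special casing. A sketch under that assumption, with an abbreviated illustrative payload:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"

	"github.com/agentuity/llmproxy/providers/openai_compatible"
)

func main() {
	payload := `{"id":"resp_1","object":"response","model":"gpt-4o","status":"completed",
		"output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hi!"}]}],
		"usage":{"input_tokens":3,"output_tokens":2,"total_tokens":5}}`

	resp := &http.Response{
		StatusCode: 200,
		Header:     make(http.Header),
		Body:       io.NopCloser(bytes.NewReader([]byte(payload))),
	}

	// Usage is populated and the text lands in Choices[0].Message.Content,
	// exactly where chat-completion consumers expect it.
	meta, _, err := openai_compatible.NewResponsesExtractor().Extract(resp)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(meta.Usage.TotalTokens, meta.Choices[0].Message.Content) // 5 Hi!
}
```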
diff --git a/providers/openai_compatible/responses_parser.go b/providers/openai_compatible/responses_parser.go
new file mode 100644
index 0000000..0b5fded
--- /dev/null
+++ b/providers/openai_compatible/responses_parser.go
@@ -0,0 +1,121 @@
+package openai_compatible
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+
+	"github.com/agentuity/llmproxy"
+)
+
+// ResponsesParser extracts body metadata from OpenAI Responses API requests.
+type ResponsesParser struct{}
+
+func (p *ResponsesParser) Parse(body io.ReadCloser) (llmproxy.BodyMetadata, []byte, error) {
+	data, err := io.ReadAll(body)
+	if err != nil {
+		return llmproxy.BodyMetadata{}, nil, err
+	}
+	body.Close()
+
+	var req ResponsesRequest
+	if err := json.Unmarshal(data, &req); err != nil {
+		return llmproxy.BodyMetadata{}, nil, err
+	}
+
+	meta := llmproxy.BodyMetadata{
+		Model:     req.Model,
+		MaxTokens: req.MaxOutputTokens,
+		Stream:    req.Stream,
+		Custom:    make(map[string]any),
+	}
+
+	if req.Input != nil {
+		switch v := req.Input.(type) {
+		case string:
+			meta.Messages = []llmproxy.Message{{Role: "user", Content: v}}
+		case []interface{}:
+			msgs := make([]llmproxy.Message, 0, len(v))
+			for _, item := range v {
+				if m, ok := item.(map[string]interface{}); ok {
+					role, hasRole := m["role"].(string)
+					content, hasContent := m["content"].(string)
+					// Only append if both role and content are present and non-empty
+					if hasRole && hasContent && role != "" && content != "" {
+						msgs = append(msgs, llmproxy.Message{Role: role, Content: content})
+					}
+				}
+			}
+			meta.Messages = msgs
+		}
+	}
+
+	for k, v := range req.Custom {
+		meta.Custom[k] = v
+	}
+
+	meta.Custom["api_type"] = llmproxy.APITypeResponses
+	if len(req.Instructions) > 0 {
+		meta.Custom["instructions"] = req.Instructions
+	}
+	if len(req.Tools) > 0 {
+		meta.Custom["tools"] = req.Tools
+	}
+
+	return meta, data, nil
+}
+
+// ResponsesRequest mirrors the OpenAI Responses API request shape; fields the
+// parser does not model explicitly are preserved in Custom.
+type ResponsesRequest struct {
+	Model           string                 `json:"model"`
+	Input           interface{}            `json:"input,omitempty"`
+	Instructions    string                 `json:"instructions,omitempty"`
+	MaxOutputTokens int                    `json:"max_output_tokens,omitempty"`
+	Temperature     *float64               `json:"temperature,omitempty"`
+	TopP            *float64               `json:"top_p,omitempty"`
+	Stream          bool                   `json:"stream"`
+	Tools           []interface{}          `json:"tools,omitempty"`
+	Truncation      string                 `json:"truncation,omitempty"`
+	Custom          map[string]interface{} `json:"-"`
+}
+
+func (r *ResponsesRequest) UnmarshalJSON(data []byte) error {
+	type Alias ResponsesRequest
+	aux := &struct {
+		*Alias
+	}{
+		Alias: (*Alias)(r),
+	}
+	if err := json.Unmarshal(data, &aux); err != nil {
+		return err
+	}
+
+	var raw map[string]interface{}
+	if err := json.Unmarshal(data, &raw); err != nil {
+		return err
+	}
+
+	r.Custom = make(map[string]interface{})
+	known := map[string]bool{
+		"model": true, "input": true, "instructions": true,
+		"max_output_tokens": true, "temperature": true, "top_p": true,
+		"stream": true, "tools": true, "truncation": true,
+		"user": true, "metadata": true, "parallel_tool_calls": true,
+		"previous_response_id": true, "response_format": true,
+		"seed": true, "service_tier": true, "store": true,
+	}
+	for k, v := range raw {
+		if !known[k] {
+			r.Custom[k] = v
+		}
+	}
+
+	return nil
+}
+
+func ParseResponsesRequest(body io.ReadCloser) (llmproxy.BodyMetadata, []byte, error) {
+	return (&ResponsesParser{}).Parse(body)
+}
+
+func ParseResponsesRequestBody(data []byte) (llmproxy.BodyMetadata, error) {
+	meta, _, err := (&ResponsesParser{}).Parse(io.NopCloser(bytes.NewReader(data)))
+	return meta, err
+}
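Because `UnmarshalJSON` keeps unknown fields, request extensions survive parsing and reach interceptors via `meta.Custom`. A sketch (the `reasoning` field is just an example of a key the parser does not model):

```go
package main

import (
	"fmt"
	"log"

	"github.com/agentuity/llmproxy/providers/openai_compatible"
)

func main() {
	body := []byte(`{"model":"gpt-4o","input":"Hi","reasoning":{"effort":"low"}}`)
	meta, err := openai_compatible.ParseResponsesRequestBody(body)
	if err != nil {
		log.Fatal(err)
	}
	// "reasoning" is not in the parser's known-field set, so it lands in
	// Custom alongside the api_type marker.
	fmt.Println(meta.Custom["reasoning"], meta.Custom["api_type"])
}
```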
diff --git a/providers/openai_compatible/responses_test.go b/providers/openai_compatible/responses_test.go
new file mode 100644
index 0000000..e6bb62d
--- /dev/null
+++ b/providers/openai_compatible/responses_test.go
@@ -0,0 +1,289 @@
+package openai_compatible
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"testing"
+
+	"github.com/agentuity/llmproxy"
+)
+
+func TestResponsesParser(t *testing.T) {
+	body := `{
+		"model": "gpt-4o",
+		"input": "Hello, world!",
+		"instructions": "Be helpful",
+		"max_output_tokens": 100,
+		"temperature": 0.7,
+		"tools": [{"type": "web_search_preview"}]
+	}`
+
+	parser := &ResponsesParser{}
+	meta, data, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
+	if err != nil {
+		t.Fatalf("Parse() error = %v", err)
+	}
+
+	if meta.Model != "gpt-4o" {
+		t.Errorf("Model = %q, want %q", meta.Model, "gpt-4o")
+	}
+
+	if meta.MaxTokens != 100 {
+		t.Errorf("MaxTokens = %d, want 100", meta.MaxTokens)
+	}
+
+	if len(meta.Messages) != 1 {
+		t.Errorf("Messages length = %d, want 1", len(meta.Messages))
+	} else {
+		if meta.Messages[0].Role != "user" {
+			t.Errorf("Message role = %q, want user", meta.Messages[0].Role)
+		}
+		if meta.Messages[0].Content != "Hello, world!" {
+			t.Errorf("Message content = %q, want 'Hello, world!'", meta.Messages[0].Content)
+		}
+	}
+
+	if meta.Custom["instructions"] != "Be helpful" {
+		t.Errorf("instructions = %v, want 'Be helpful'", meta.Custom["instructions"])
+	}
+
+	if meta.Custom["api_type"] != llmproxy.APITypeResponses {
+		t.Errorf("api_type = %v, want responses", meta.Custom["api_type"])
+	}
+
+	if len(data) == 0 {
+		t.Error("data is empty")
+	}
+}
+
+func TestResponsesParser_InputArray(t *testing.T) {
+	body := `{
+		"model": "gpt-4o",
+		"input": [
+			{"role": "system", "content": "You are helpful"},
+			{"role": "user", "content": "Hello"}
+		]
+	}`
+
+	parser := &ResponsesParser{}
+	meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
+	if err != nil {
+		t.Fatalf("Parse() error = %v", err)
+	}
+
+	if len(meta.Messages) != 2 {
+		t.Fatalf("Messages length = %d, want 2", len(meta.Messages))
+	}
+
+	if meta.Messages[0].Role != "system" {
+		t.Errorf("First message role = %q, want system", meta.Messages[0].Role)
+	}
+	if meta.Messages[1].Role != "user" {
+		t.Errorf("Second message role = %q, want user", meta.Messages[1].Role)
+	}
+}
+
+func TestResponsesExtractor(t *testing.T) {
+	respBody := `{
+		"id": "resp_abc123",
+		"object": "response",
+		"created": 1234567890,
+		"model": "gpt-4o",
+		"status": "completed",
+		"output": [
+			{
+				"id": "msg_123",
+				"type": "message",
+				"status": "completed",
+				"role": "assistant",
+				"content": [
+					{
+						"type": "output_text",
+						"text": "Hello! How can I help you?"
+					}
+				]
+			}
+		],
+		"usage": {
+			"input_tokens": 10,
+			"output_tokens": 20,
+			"total_tokens": 30,
+			"input_tokens_details": {
+				"cached_tokens": 5
+			}
+		}
+	}`
+
+	resp := &http.Response{
+		StatusCode: 200,
+		Header:     make(http.Header),
+		Body:       io.NopCloser(bytes.NewReader([]byte(respBody))),
+	}
+
+	extractor := &ResponsesExtractor{}
+	meta, rawBody, err := extractor.Extract(resp)
+	if err != nil {
+		t.Fatalf("Extract() error = %v", err)
+	}
+
+	if meta.ID != "resp_abc123" {
+		t.Errorf("ID = %q, want resp_abc123", meta.ID)
+	}
+
+	if meta.Model != "gpt-4o" {
+		t.Errorf("Model = %q, want gpt-4o", meta.Model)
+	}
+
+	if meta.Usage.PromptTokens != 10 {
+		t.Errorf("PromptTokens = %d, want 10", meta.Usage.PromptTokens)
+	}
+	if meta.Usage.CompletionTokens != 20 {
+		t.Errorf("CompletionTokens = %d, want 20", meta.Usage.CompletionTokens)
+	}
+	if meta.Usage.TotalTokens != 30 {
+		t.Errorf("TotalTokens = %d, want 30", meta.Usage.TotalTokens)
+	}
+
+	if len(meta.Choices) != 1 {
+		t.Fatalf("Choices length = %d, want 1", len(meta.Choices))
+	}
+
+	if meta.Choices[0].Message == nil {
+		t.Fatal("Message is nil")
+	}
+	if meta.Choices[0].Message.Content != "Hello! How can I help you?" {
+		t.Errorf("Message content = %q, want 'Hello! How can I help you?'", meta.Choices[0].Message.Content)
+	}
+
+	if meta.Custom["status"] != "completed" {
+		t.Errorf("status = %v, want completed", meta.Custom["status"])
+	}
+
+	cacheUsage, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage)
+	if !ok {
+		t.Fatal("cache_usage not found or wrong type")
+	}
+	if cacheUsage.CachedTokens != 5 {
+		t.Errorf("CachedTokens = %d, want 5", cacheUsage.CachedTokens)
+	}
+
+	if len(rawBody) == 0 {
+		t.Error("rawBody is empty")
+	}
+}
+
+func TestMultiAPIParser_ChatCompletions(t *testing.T) {
+	body := `{
+		"model": "gpt-4",
+		"messages": [{"role": "user", "content": "Hello"}],
+		"max_tokens": 100
+	}`
+
+	parser := NewMultiAPIParser()
+	meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
+	if err != nil {
+		t.Fatalf("Parse() error = %v", err)
+	}
+
+	if meta.Model != "gpt-4" {
+		t.Errorf("Model = %q, want gpt-4", meta.Model)
+	}
+
+	if len(meta.Messages) != 1 {
+		t.Errorf("Messages length = %d, want 1", len(meta.Messages))
+	}
+}
+
+func TestMultiAPIParser_Responses(t *testing.T) {
+	body := `{
+		"model": "gpt-4o",
+		"input": "Hello",
+		"instructions": "Be helpful"
+	}`
+
+	parser := NewMultiAPIParser()
+	meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
+	if err != nil {
+		t.Fatalf("Parse() error = %v", err)
+	}
+
+	if meta.Model != "gpt-4o" {
+		t.Errorf("Model = %q, want gpt-4o", meta.Model)
+	}
+
+	if meta.Custom["api_type"] != llmproxy.APITypeResponses {
+		t.Errorf("api_type = %v, want responses", meta.Custom["api_type"])
+	}
+}
+
+func TestResolver_ResponsesAPI(t *testing.T) {
+	resolver, err := NewResolver("https://api.openai.com")
+	if err != nil {
+		t.Fatalf("NewResolver() error = %v", err)
+	}
+
+	meta := llmproxy.BodyMetadata{
+		Model:  "gpt-4o",
+		Custom: map[string]any{"api_type": llmproxy.APITypeResponses},
+	}
+
+	url, err := resolver.Resolve(meta)
+	if err != nil {
+		t.Fatalf("Resolve() error = %v", err)
+	}
+
+	expected := "https://api.openai.com/v1/responses"
+	if url.String() != expected {
+		t.Errorf("URL = %q, want %q", url.String(), expected)
+	}
+}
+
+func TestResolver_ChatCompletionsAPI(t *testing.T) {
+	resolver, err := NewResolver("https://api.openai.com")
+	if err != nil {
+		t.Fatalf("NewResolver() error = %v", err)
+	}
+
+	meta := llmproxy.BodyMetadata{
+		Model:  "gpt-4",
+		Custom: map[string]any{"api_type": llmproxy.APITypeChatCompletions},
+	}
+
+	url, err := resolver.Resolve(meta)
+	if err != nil {
+		t.Fatalf("Resolve() error = %v", err)
+	}
+
+	expected := "https://api.openai.com/v1/chat/completions"
+	if url.String() != expected {
+		t.Errorf("URL = %q, want %q", url.String(), expected)
+	}
+}
+
+func TestNewMultiAPI(t *testing.T) {
+	provider, err := NewMultiAPI("test", "api-key", "https://api.example.com")
+	if err != nil {
+		t.Fatalf("NewMultiAPI() error = %v", err)
+	}
+
+	if provider.Name() != "test" {
+		t.Errorf("Name() = %q, want test", provider.Name())
+	}
+
+	if provider.BodyParser() == nil {
+		t.Error("BodyParser() is nil")
+	}
+
+	if provider.ResponseExtractor() == nil {
+		t.Error("ResponseExtractor() is nil")
+	}
+
+	if provider.RequestEnricher() == nil {
+		t.Error("RequestEnricher() is nil")
+	}
+
+	if provider.URLResolver() == nil {
+		t.Error("URLResolver() is nil")
+	}
+}