Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -90,7 +91,17 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
// Add the known provider-specific routes which are bridged (i.e. intercepted and augmented).
for _, path := range prov.BridgedRoutes() {
handler := newInterceptionProcessor(prov, cbs, rec, mcpProxy, logger, m, tracer)
mux.Handle(path, handler)
route, err := url.JoinPath(prov.RoutePrefix(), path)
if err != nil {
logger.Error(ctx, "failed to join path",
slog.Error(err),
slog.F("provider", providerName),
slog.F("prefix", prov.RoutePrefix()),
slog.F("path", path),
)
return nil, fmt.Errorf("failed to configure provider '%v': failed to join bridged path: %w", providerName, err)
}
mux.Handle(route, handler)
}

// Any requests which passthrough to this will be reverse-proxied to the upstream.
Expand All @@ -99,9 +110,17 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
// configured, so we should just reverse-proxy known-safe routes.
ftr := newPassthroughRouter(prov, logger.Named(fmt.Sprintf("passthrough.%s", prov.Name())), m, tracer)
for _, path := range prov.PassthroughRoutes() {
prefix := fmt.Sprintf("/%s", prov.Name())
route := fmt.Sprintf("%s%s", prefix, path)
mux.HandleFunc(route, http.StripPrefix(prefix, ftr).ServeHTTP)
route, err := url.JoinPath(prov.RoutePrefix(), path)
if err != nil {
logger.Error(ctx, "failed to join path",
slog.Error(err),
slog.F("provider", providerName),
slog.F("prefix", prov.RoutePrefix()),
slog.F("path", path),
)
return nil, fmt.Errorf("failed to configure provider '%v': failed to join passed through path: %w", providerName, err)
}
mux.Handle(route, http.StripPrefix(prov.RoutePrefix(), ftr))
}
}

Expand Down
56 changes: 33 additions & 23 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -772,17 +772,21 @@ func TestFallthrough(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
providerName string
fixture []byte
basePath string
configureFunc func(string, aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge)
name string
providerName string
fixture []byte
basePath string
requestPath string
expectedUpstreamPath string
configureFunc func(string, aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge)
}{
{
name: "ant_empty_base_url_path",
providerName: config.ProviderAnthropic,
fixture: fixtures.AntFallthrough,
basePath: "",
name: "ant_empty_base_url_path",
providerName: config.ProviderAnthropic,
fixture: fixtures.AntFallthrough,
basePath: "",
requestPath: "/anthropic/v1/models",
expectedUpstreamPath: "/v1/models",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)
Expand All @@ -792,10 +796,12 @@ func TestFallthrough(t *testing.T) {
},
},
{
name: "oai_empty_base_url_path",
providerName: config.ProviderOpenAI,
fixture: fixtures.OaiChatFallthrough,
basePath: "",
name: "oai_empty_base_url_path",
providerName: config.ProviderOpenAI,
fixture: fixtures.OaiChatFallthrough,
basePath: "",
requestPath: "/openai/v1/models",
expectedUpstreamPath: "/models",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewOpenAI(openaiCfg(addr, apiKey))
Expand All @@ -805,10 +811,12 @@ func TestFallthrough(t *testing.T) {
},
},
{
name: "ant_some_base_url_path",
providerName: config.ProviderAnthropic,
fixture: fixtures.AntFallthrough,
basePath: "/api",
name: "ant_some_base_url_path",
providerName: config.ProviderAnthropic,
fixture: fixtures.AntFallthrough,
basePath: "/api",
requestPath: "/anthropic/v1/models",
expectedUpstreamPath: "/api/v1/models",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)
Expand All @@ -818,10 +826,12 @@ func TestFallthrough(t *testing.T) {
},
},
{
name: "oai_some_base_url_path",
providerName: config.ProviderOpenAI,
fixture: fixtures.OaiChatFallthrough,
basePath: "/api",
name: "oai_some_base_url_path",
providerName: config.ProviderOpenAI,
fixture: fixtures.OaiChatFallthrough,
basePath: "/api",
requestPath: "/openai/v1/models",
expectedUpstreamPath: "/api/models",
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := provider.NewOpenAI(openaiCfg(addr, apiKey))
Expand All @@ -841,7 +851,7 @@ func TestFallthrough(t *testing.T) {

files := filesMap(arc)
require.Contains(t, files, fixtureResponse)
expectedPath := tc.basePath + "/v1/models"
expectedPath := tc.expectedUpstreamPath

var receivedHeaders *http.Header
respBody := files[fixtureResponse]
Expand Down Expand Up @@ -871,7 +881,7 @@ func TestFallthrough(t *testing.T) {
bridgeSrv.Start()
t.Cleanup(bridgeSrv.Close)

req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.providerName), nil)
req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s%s", bridgeSrv.URL, tc.requestPath), nil)
require.NoError(t, err)

resp, err := http.DefaultClient.Do(req)
Expand Down
106 changes: 106 additions & 0 deletions bridge_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package aibridge

import (
"net/http"
"net/http/httptest"
"testing"

"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestPassthroughRoutesForProviders(t *testing.T) {
t.Parallel()

upstreamRespBody := "upstream response"
tests := []struct {
name string
baseURLPath string
requestPath string
provider func(string) provider.Provider
expectPath string
}{
{
name: "openAI_no_base_path",
requestPath: "/openai/v1/conversations",
provider: func(baseURL string) provider.Provider {
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
},
expectPath: "/conversations",
},
{
name: "openAI_with_base_path",
baseURLPath: "/v1",
requestPath: "/openai/v1/conversations",
provider: func(baseURL string) provider.Provider {
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
},
expectPath: "/v1/conversations",
},
{
name: "anthropic_no_base_path",
requestPath: "/anthropic/v1/models",
provider: func(baseURL string) provider.Provider {
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
},
expectPath: "/v1/models",
},
{
name: "anthropic_with_base_path",
baseURLPath: "/v1",
requestPath: "/anthropic/v1/models",
provider: func(baseURL string) provider.Provider {
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
},
expectPath: "/v1/v1/models",
},
{
name: "copilot_no_base_path",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
},
expectPath: "/models",
},
{
name: "copilot_with_base_path",
baseURLPath: "/v1",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
},
expectPath: "/v1/models",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, nil)

upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, tc.expectPath, r.URL.Path)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(upstreamRespBody))
}))
t.Cleanup(upstream.Close)

recorder := testutil.MockRecorder{}
prov := tc.provider(upstream.URL + tc.baseURLPath)
bridge, err := NewRequestBridge(t.Context(), []provider.Provider{prov}, &recorder, nil, logger, nil, testTracer)
require.NoError(t, err)

req := httptest.NewRequest("", tc.requestPath, nil)
resp := httptest.NewRecorder()
bridge.mux.ServeHTTP(resp, req)

assert.Equal(t, http.StatusOK, resp.Code)
assert.Contains(t, resp.Body.String(), upstreamRespBody)
})
}
}
Loading