Skip to content

Commit feca199

Browse files
committed
chore: address comments
1 parent a6abd82 commit feca199

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

enterprise/aiproxyd/aiproxyd_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@ package aiproxyd_test
33
import (
44
"crypto/rand"
55
"crypto/rsa"
6+
"crypto/tls"
67
"crypto/x509"
78
"crypto/x509/pkix"
89
"encoding/pem"
10+
"io"
911
"math/big"
12+
"net"
13+
"net/http"
14+
"net/http/httptest"
15+
"net/url"
1016
"os"
1117
"path/filepath"
1218
"testing"
@@ -17,6 +23,7 @@ import (
1723
"cdr.dev/slog/sloggers/slogtest"
1824

1925
"github.com/coder/coder/v2/enterprise/aiproxyd"
26+
"github.com/coder/coder/v2/testutil"
2027
)
2128

2229
// generateTestCA creates a temporary CA certificate and key for testing.
@@ -76,6 +83,7 @@ func TestNew(t *testing.T) {
7683

7784
t.Run("MissingCertFile", func(t *testing.T) {
7885
t.Parallel()
86+
7987
logger := slogtest.Make(t, nil)
8088

8189
_, err := aiproxyd.New(t.Context(), logger, aiproxyd.Options{
@@ -88,6 +96,7 @@ func TestNew(t *testing.T) {
8896

8997
t.Run("MissingKeyFile", func(t *testing.T) {
9098
t.Parallel()
99+
91100
logger := slogtest.Make(t, nil)
92101

93102
_, err := aiproxyd.New(t.Context(), logger, aiproxyd.Options{
@@ -100,6 +109,7 @@ func TestNew(t *testing.T) {
100109

101110
t.Run("InvalidCertFile", func(t *testing.T) {
102111
t.Parallel()
112+
103113
logger := slogtest.Make(t, nil)
104114

105115
_, err := aiproxyd.New(t.Context(), logger, aiproxyd.Options{
@@ -150,3 +160,69 @@ func TestClose(t *testing.T) {
150160
err = srv.Close()
151161
require.NoError(t, err)
152162
}
163+
164+
func TestProxy_MITM(t *testing.T) {
165+
t.Parallel()
166+
167+
// Create a mock HTTPS server that will be the target of our proxied request.
168+
targetServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
169+
w.WriteHeader(http.StatusOK)
170+
_, _ = w.Write([]byte("hello from target"))
171+
}))
172+
defer targetServer.Close()
173+
174+
certFile, keyFile := generateTestCA(t)
175+
logger := slogtest.Make(t, nil)
176+
177+
// Start the proxy server.
178+
srv, err := aiproxyd.New(t.Context(), logger, aiproxyd.Options{
179+
ListenAddr: "127.0.0.1:8888",
180+
CertFile: certFile,
181+
KeyFile: keyFile,
182+
})
183+
require.NoError(t, err)
184+
t.Cleanup(func() { _ = srv.Close() })
185+
186+
// Wait for the proxy server to be ready.
187+
require.Eventually(t, func() bool {
188+
conn, err := net.Dial("tcp", "127.0.0.1:8888")
189+
if err != nil {
190+
return false
191+
}
192+
_ = conn.Close()
193+
return true
194+
}, testutil.WaitShort, testutil.IntervalFast)
195+
196+
// Load the CA certificate so the client trusts the proxy's MITM certificate.
197+
certPEM, err := os.ReadFile(certFile)
198+
require.NoError(t, err)
199+
certPool := x509.NewCertPool()
200+
certPool.AppendCertsFromPEM(certPEM)
201+
202+
// Create an HTTP client configured to use the proxy.
203+
proxyURL, err := url.Parse("http://127.0.0.1:8888")
204+
require.NoError(t, err)
205+
206+
client := &http.Client{
207+
Transport: &http.Transport{
208+
Proxy: http.ProxyURL(proxyURL),
209+
TLSClientConfig: &tls.Config{
210+
MinVersion: tls.VersionTLS12,
211+
RootCAs: certPool,
212+
},
213+
},
214+
}
215+
216+
// Make a request through the proxy to the target server.
217+
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetServer.URL, nil)
218+
require.NoError(t, err)
219+
resp, err := client.Do(req)
220+
require.NoError(t, err)
221+
defer resp.Body.Close()
222+
223+
// Verify the response was successfully proxied.
224+
body, err := io.ReadAll(resp.Body)
225+
require.NoError(t, err)
226+
require.Equal(t, http.StatusOK, resp.StatusCode)
227+
require.Equal(t, "hello from target", string(body))
228+
}

enterprise/cli/aiproxyd.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package cli
55
import (
66
"context"
77

8+
"golang.org/x/xerrors"
9+
810
"github.com/coder/coder/v2/enterprise/aiproxyd"
911
"github.com/coder/coder/v2/enterprise/coderd"
1012
)
@@ -21,7 +23,7 @@ func newAIProxyDaemon(coderAPI *coderd.API) (*aiproxyd.Server, error) {
2123
KeyFile: coderAPI.DeploymentValues.AI.ProxyConfig.KeyFile.String(),
2224
})
2325
if err != nil {
24-
return nil, err
26+
return nil, xerrors.Errorf("failed to start in-memory aiproxy daemon: %w", err)
2527
}
2628

2729
return srv, nil

0 commit comments

Comments
 (0)