@@ -3,10 +3,16 @@ package aiproxyd_test
33import (
44 "crypto/rand"
55 "crypto/rsa"
6+ "crypto/tls"
67 "crypto/x509"
78 "crypto/x509/pkix"
89 "encoding/pem"
10+ "io"
911 "math/big"
12+ "net"
13+ "net/http"
14+ "net/http/httptest"
15+ "net/url"
1016 "os"
1117 "path/filepath"
1218 "testing"
@@ -17,6 +23,7 @@ import (
1723 "cdr.dev/slog/sloggers/slogtest"
1824
1925 "github.com/coder/coder/v2/enterprise/aiproxyd"
26+ "github.com/coder/coder/v2/testutil"
2027)
2128
2229// generateTestCA creates a temporary CA certificate and key for testing.
@@ -76,6 +83,7 @@ func TestNew(t *testing.T) {
7683
7784 t .Run ("MissingCertFile" , func (t * testing.T ) {
7885 t .Parallel ()
86+
7987 logger := slogtest .Make (t , nil )
8088
8189 _ , err := aiproxyd .New (t .Context (), logger , aiproxyd.Options {
@@ -88,6 +96,7 @@ func TestNew(t *testing.T) {
8896
8997 t .Run ("MissingKeyFile" , func (t * testing.T ) {
9098 t .Parallel ()
99+
91100 logger := slogtest .Make (t , nil )
92101
93102 _ , err := aiproxyd .New (t .Context (), logger , aiproxyd.Options {
@@ -100,6 +109,7 @@ func TestNew(t *testing.T) {
100109
101110 t .Run ("InvalidCertFile" , func (t * testing.T ) {
102111 t .Parallel ()
112+
103113 logger := slogtest .Make (t , nil )
104114
105115 _ , err := aiproxyd .New (t .Context (), logger , aiproxyd.Options {
@@ -150,3 +160,69 @@ func TestClose(t *testing.T) {
150160 err = srv .Close ()
151161 require .NoError (t , err )
152162}
163+
164+ func TestProxy_MITM (t * testing.T ) {
165+ t .Parallel ()
166+
167+ // Create a mock HTTPS server that will be the target of our proxied request.
168+ targetServer := httptest .NewTLSServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
169+ w .WriteHeader (http .StatusOK )
170+ _ , _ = w .Write ([]byte ("hello from target" ))
171+ }))
172+ defer targetServer .Close ()
173+
174+ certFile , keyFile := generateTestCA (t )
175+ logger := slogtest .Make (t , nil )
176+
177+ // Start the proxy server.
178+ srv , err := aiproxyd .New (t .Context (), logger , aiproxyd.Options {
179+ ListenAddr : "127.0.0.1:8888" ,
180+ CertFile : certFile ,
181+ KeyFile : keyFile ,
182+ })
183+ require .NoError (t , err )
184+ t .Cleanup (func () { _ = srv .Close () })
185+
186+ // Wait for the proxy server to be ready.
187+ require .Eventually (t , func () bool {
188+ conn , err := net .Dial ("tcp" , "127.0.0.1:8888" )
189+ if err != nil {
190+ return false
191+ }
192+ _ = conn .Close ()
193+ return true
194+ }, testutil .WaitShort , testutil .IntervalFast )
195+
196+ // Load the CA certificate so the client trusts the proxy's MITM certificate.
197+ certPEM , err := os .ReadFile (certFile )
198+ require .NoError (t , err )
199+ certPool := x509 .NewCertPool ()
200+ certPool .AppendCertsFromPEM (certPEM )
201+
202+ // Create an HTTP client configured to use the proxy.
203+ proxyURL , err := url .Parse ("http://127.0.0.1:8888" )
204+ require .NoError (t , err )
205+
206+ client := & http.Client {
207+ Transport : & http.Transport {
208+ Proxy : http .ProxyURL (proxyURL ),
209+ TLSClientConfig : & tls.Config {
210+ MinVersion : tls .VersionTLS12 ,
211+ RootCAs : certPool ,
212+ },
213+ },
214+ }
215+
216+ // Make a request through the proxy to the target server.
217+ req , err := http .NewRequestWithContext (t .Context (), http .MethodGet , targetServer .URL , nil )
218+ require .NoError (t , err )
219+ resp , err := client .Do (req )
220+ require .NoError (t , err )
221+ defer resp .Body .Close ()
222+
223+ // Verify the response was successfully proxied.
224+ body , err := io .ReadAll (resp .Body )
225+ require .NoError (t , err )
226+ require .Equal (t , http .StatusOK , resp .StatusCode )
227+ require .Equal (t , "hello from target" , string (body ))
228+ }
0 commit comments