@@ -21,6 +21,7 @@ import (
21
21
"runtime"
22
22
"strconv"
23
23
"strings"
24
+ "sync/atomic"
24
25
"testing"
25
26
"time"
26
27
@@ -240,20 +241,64 @@ func TestServer(t *testing.T) {
240
241
err := root .ExecuteContext (ctx )
241
242
require .Error (t ,err )
242
243
})
243
- t .Run ("TLSNoCertFile " ,func (t * testing.T ) {
244
+ t .Run ("TLSInvalid " ,func (t * testing.T ) {
244
245
t .Parallel ()
245
- ctx ,cancelFunc := context .WithCancel (context .Background ())
246
- defer cancelFunc ()
247
246
248
- root ,_ := clitest .New (t ,
249
- "server" ,
250
- "--in-memory" ,
251
- "--address" ,":0" ,
252
- "--tls-enable" ,
253
- "--cache-dir" ,t .TempDir (),
254
- )
255
- err := root .ExecuteContext (ctx )
256
- require .Error (t ,err )
247
+ cert1Path ,key1Path := generateTLSCertificate (t )
248
+ cert2Path ,key2Path := generateTLSCertificate (t )
249
+
250
+ cases := []struct {
251
+ name string
252
+ args []string
253
+ errContains string
254
+ }{
255
+ {
256
+ name :"NoCertAndKey" ,
257
+ args : []string {"--tls-enable" },
258
+ errContains :"--tls-cert-file is required when tls is enabled" ,
259
+ },
260
+ {
261
+ name :"NoCert" ,
262
+ args : []string {"--tls-enable" ,"--tls-key-file" ,key1Path },
263
+ errContains :"--tls-cert-file and --tls-key-file must be used the same amount of times" ,
264
+ },
265
+ {
266
+ name :"NoKey" ,
267
+ args : []string {"--tls-enable" ,"--tls-cert-file" ,cert1Path },
268
+ errContains :"--tls-cert-file and --tls-key-file must be used the same amount of times" ,
269
+ },
270
+ {
271
+ name :"MismatchedCount" ,
272
+ args : []string {"--tls-enable" ,"--tls-cert-file" ,cert1Path ,"--tls-key-file" ,key1Path ,"--tls-cert-file" ,cert2Path },
273
+ errContains :"--tls-cert-file and --tls-key-file must be used the same amount of times" ,
274
+ },
275
+ {
276
+ name :"MismatchedCertAndKey" ,
277
+ args : []string {"--tls-enable" ,"--tls-cert-file" ,cert1Path ,"--tls-key-file" ,key2Path },
278
+ errContains :"load TLS key pair" ,
279
+ },
280
+ }
281
+
282
+ for _ ,c := range cases {
283
+ c := c
284
+ t .Run (c .name ,func (t * testing.T ) {
285
+ t .Parallel ()
286
+ ctx ,cancelFunc := context .WithCancel (context .Background ())
287
+ defer cancelFunc ()
288
+
289
+ args := []string {
290
+ "server" ,
291
+ "--in-memory" ,
292
+ "--address" ,":0" ,
293
+ "--cache-dir" ,t .TempDir (),
294
+ }
295
+ args = append (args ,c .args ... )
296
+ root ,_ := clitest .New (t ,args ... )
297
+ err := root .ExecuteContext (ctx )
298
+ require .Error (t ,err )
299
+ require .ErrorContains (t ,err ,c .errContains )
300
+ })
301
+ }
257
302
})
258
303
t .Run ("TLSValid" ,func (t * testing.T ) {
259
304
t .Parallel ()
@@ -293,6 +338,86 @@ func TestServer(t *testing.T) {
293
338
cancelFunc ()
294
339
require .NoError (t ,<- errC )
295
340
})
341
+ t .Run ("TLSValidMultiple" ,func (t * testing.T ) {
342
+ t .Parallel ()
343
+ ctx ,cancelFunc := context .WithCancel (context .Background ())
344
+ defer cancelFunc ()
345
+
346
+ cert1Path ,key1Path := generateTLSCertificate (t ,"alpaca.com" )
347
+ cert2Path ,key2Path := generateTLSCertificate (t ,"*.llama.com" )
348
+ root ,cfg := clitest .New (t ,
349
+ "server" ,
350
+ "--in-memory" ,
351
+ "--address" ,":0" ,
352
+ "--tls-enable" ,
353
+ "--tls-cert-file" ,cert1Path ,
354
+ "--tls-key-file" ,key1Path ,
355
+ "--tls-cert-file" ,cert2Path ,
356
+ "--tls-key-file" ,key2Path ,
357
+ "--cache-dir" ,t .TempDir (),
358
+ )
359
+ errC := make (chan error ,1 )
360
+ go func () {
361
+ errC <- root .ExecuteContext (ctx )
362
+ }()
363
+ accessURL := waitAccessURL (t ,cfg )
364
+ require .Equal (t ,"https" ,accessURL .Scheme )
365
+ originalHost := accessURL .Host
366
+
367
+ var (
368
+ expectAddr string
369
+ dials int64
370
+ )
371
+ client := codersdk .New (accessURL )
372
+ client .HTTPClient = & http.Client {
373
+ Transport :& http.Transport {
374
+ DialTLSContext :func (ctx context.Context ,network ,addr string ) (net.Conn ,error ) {
375
+ atomic .AddInt64 (& dials ,1 )
376
+ assert .Equal (t ,expectAddr ,addr )
377
+
378
+ host ,_ ,err := net .SplitHostPort (addr )
379
+ require .NoError (t ,err )
380
+
381
+ // Always connect to the accessURL ip:port regardless of
382
+ // hostname.
383
+ conn ,err := tls .Dial (network ,originalHost ,& tls.Config {
384
+ MinVersion :tls .VersionTLS12 ,
385
+ //nolint:gosec
386
+ InsecureSkipVerify :true ,
387
+ ServerName :host ,
388
+ })
389
+ if err != nil {
390
+ return nil ,err
391
+ }
392
+
393
+ // We can't call conn.VerifyHostname because it requires
394
+ // that the certificates are valid, so we call
395
+ // VerifyHostname on the first certificate instead.
396
+ require .Len (t ,conn .ConnectionState ().PeerCertificates ,1 )
397
+ err = conn .ConnectionState ().PeerCertificates [0 ].VerifyHostname (host )
398
+ assert .NoError (t ,err ,"invalid cert common name" )
399
+ return conn ,nil
400
+ },
401
+ },
402
+ }
403
+
404
+ // Use the first certificate and hostname.
405
+ client .URL .Host = "alpaca.com:443"
406
+ expectAddr = "alpaca.com:443"
407
+ _ ,err := client .HasFirstUser (ctx )
408
+ require .NoError (t ,err )
409
+ require .EqualValues (t ,1 ,atomic .LoadInt64 (& dials ))
410
+
411
+ // Use the second certificate (wildcard) and hostname.
412
+ client .URL .Host = "hi.llama.com:443"
413
+ expectAddr = "hi.llama.com:443"
414
+ _ ,err = client .HasFirstUser (ctx )
415
+ require .NoError (t ,err )
416
+ require .EqualValues (t ,2 ,atomic .LoadInt64 (& dials ))
417
+
418
+ cancelFunc ()
419
+ require .NoError (t ,<- errC )
420
+ })
296
421
// This cannot be ran in parallel because it uses a signal.
297
422
//nolint:paralleltest
298
423
t .Run ("Shutdown" ,func (t * testing.T ) {
@@ -480,16 +605,22 @@ func TestServer(t *testing.T) {
480
605
})
481
606
}
482
607
483
- func generateTLSCertificate (t testing.TB ) (certPath ,keyPath string ) {
608
+ func generateTLSCertificate (t testing.TB , commonName ... string ) (certPath ,keyPath string ) {
484
609
dir := t .TempDir ()
485
610
611
+ commonNameStr := "localhost"
612
+ if len (commonName )> 0 {
613
+ commonNameStr = commonName [0 ]
614
+ }
486
615
privateKey ,err := ecdsa .GenerateKey (elliptic .P256 (),rand .Reader )
487
616
require .NoError (t ,err )
488
617
template := x509.Certificate {
489
618
SerialNumber :big .NewInt (1 ),
490
619
Subject : pkix.Name {
491
620
Organization : []string {"Acme Co" },
621
+ CommonName :commonNameStr ,
492
622
},
623
+ DNSNames : []string {commonNameStr },
493
624
NotBefore :time .Now (),
494
625
NotAfter :time .Now ().Add (time .Hour * 24 * 180 ),
495
626
@@ -498,6 +629,7 @@ func generateTLSCertificate(t testing.TB) (certPath, keyPath string) {
498
629
BasicConstraintsValid :true ,
499
630
IPAddresses : []net.IP {net .ParseIP ("127.0.0.1" )},
500
631
}
632
+
501
633
derBytes ,err := x509 .CreateCertificate (rand .Reader ,& template ,& template ,& privateKey .PublicKey ,privateKey )
502
634
require .NoError (t ,err )
503
635
certFile ,err := os .CreateTemp (dir ,"" )