@@ -14,21 +14,47 @@ import com.coder.gateway.sdk.v2.models.Workspace
14
14
import com.coder.gateway.sdk.v2.models.WorkspaceBuild
15
15
import com.coder.gateway.sdk.v2.models.WorkspaceTransition
16
16
import com.coder.gateway.sdk.v2.models.toAgentModels
17
+ import com.coder.gateway.services.CoderSettingsState
17
18
import com.google.gson.Gson
18
19
import com.google.gson.GsonBuilder
19
20
import com.intellij.ide.plugins.PluginManagerCore
20
21
import com.intellij.openapi.components.Service
22
+ import com.intellij.openapi.components.service
21
23
import com.intellij.openapi.extensions.PluginId
22
24
import com.intellij.openapi.util.SystemInfo
23
25
import okhttp3.OkHttpClient
26
+ import okhttp3.internal.tls.OkHostnameVerifier
24
27
import okhttp3.logging.HttpLoggingInterceptor
25
28
import org.zeroturnaround.exec.ProcessExecutor
26
29
import retrofit2.Retrofit
27
30
import retrofit2.converter.gson.GsonConverterFactory
31
+ import java.io.File
32
+ import java.io.FileInputStream
28
33
import java.net.HttpURLConnection.HTTP_CREATED
34
+ import java.net.InetAddress
35
+ import java.net.Socket
29
36
import java.net.URL
37
+ import java.security.KeyFactory
38
+ import java.security.KeyStore
39
+ import java.security.PrivateKey
40
+ import java.security.cert.CertificateFactory
41
+ import java.security.cert.X509Certificate
42
+ import java.security.spec.InvalidKeySpecException
43
+ import java.security.spec.PKCS8EncodedKeySpec
30
44
import java.time.Instant
45
+ import java.util.Base64
46
+ import java.util.Locale
31
47
import java.util.UUID
48
+ import javax.net.ssl.HostnameVerifier
49
+ import javax.net.ssl.KeyManagerFactory
50
+ import javax.net.ssl.SNIHostName
51
+ import javax.net.ssl.SSLContext
52
+ import javax.net.ssl.SSLSession
53
+ import javax.net.ssl.SSLSocket
54
+ import javax.net.ssl.SSLSocketFactory
55
+ import javax.net.ssl.TrustManagerFactory
56
+ import javax.net.ssl.TrustManager
57
+ import javax.net.ssl.X509TrustManager
32
58
33
59
@Service(Service .Level .APP )
34
60
class CoderRestClientService {
@@ -66,7 +92,11 @@ class CoderRestClient(var url: URL, var token: String,
66
92
pluginVersion= PluginManagerCore .getPlugin(PluginId .getId(" com.coder.gateway" ))!! .version// this is the id from the plugin.xml
67
93
}
68
94
95
+ val socketFactory= coderSocketFactory()
96
+ val trustManagers= coderTrustManagers()
69
97
httpClient= OkHttpClient .Builder ()
98
+ .sslSocketFactory(socketFactory, trustManagers[0 ]as X509TrustManager )
99
+ .hostnameVerifier(CoderHostnameVerifier ())
70
100
.addInterceptor { it.proceed(it.request().newBuilder().addHeader(" Coder-Session-Token" , token).build()) }
71
101
.addInterceptor { it.proceed(it.request().newBuilder().addHeader(" User-Agent" ," Coder Gateway/${pluginVersion} (${SystemInfo .getOsNameAndVersion()} ;${SystemInfo .OS_ARCH } )" ).build()) }
72
102
.addInterceptor {
@@ -218,3 +248,168 @@ class CoderRestClient(var url: URL, var token: String,
218
248
}
219
249
}
220
250
}
251
+
252
+ fun coderSocketFactory () :SSLSocketFactory {
253
+ val state: CoderSettingsState = service()
254
+
255
+ if (state.tlsCertPath.isBlank()|| state.tlsKeyPath.isBlank()) {
256
+ return SSLSocketFactory .getDefault()as SSLSocketFactory
257
+ }
258
+
259
+ val certificateFactory= CertificateFactory .getInstance(" X.509" )
260
+ val certInputStream= FileInputStream (state.tlsCertPath)
261
+ val certChain= certificateFactory.generateCertificates(certInputStream)
262
+ certInputStream.close()
263
+
264
+ // ideally we would use something like PemReader from BouncyCastle, but
265
+ // BC is used by the IDE. This makes using BC very impractical since
266
+ // type casting will mismatch due to the different class loaders.
267
+ val privateKeyPem= File (state.tlsKeyPath).readText()
268
+ val start: Int = privateKeyPem.indexOf(" -----BEGIN PRIVATE KEY-----" )
269
+ val end: Int = privateKeyPem.indexOf(" -----END PRIVATE KEY-----" , start)
270
+ val pemBytes: ByteArray = Base64 .getDecoder().decode(
271
+ privateKeyPem.substring(start+ " -----BEGIN PRIVATE KEY-----" .length, end)
272
+ .replace(" \\ s+" .toRegex()," " )
273
+ )
274
+
275
+ var privateKey: PrivateKey
276
+ try {
277
+ val kf= KeyFactory .getInstance(" RSA" )
278
+ val keySpec= PKCS8EncodedKeySpec (pemBytes)
279
+ privateKey= kf.generatePrivate(keySpec)
280
+ }catch (e: InvalidKeySpecException ) {
281
+ val kf= KeyFactory .getInstance(" EC" )
282
+ val keySpec= PKCS8EncodedKeySpec (pemBytes)
283
+ privateKey= kf.generatePrivate(keySpec)
284
+ }
285
+
286
+ val keyStore= KeyStore .getInstance(KeyStore .getDefaultType())
287
+ keyStore.load(null )
288
+ certChain.withIndex().forEach {
289
+ keyStore.setCertificateEntry(" cert${it.index} " , it.valueas X509Certificate )
290
+ }
291
+ keyStore.setKeyEntry(" key" , privateKey,null , certChain.toTypedArray())
292
+
293
+ val keyManagerFactory= KeyManagerFactory .getInstance(KeyManagerFactory .getDefaultAlgorithm())
294
+ keyManagerFactory.init (keyStore,null )
295
+
296
+ val sslContext= SSLContext .getInstance(" TLS" )
297
+
298
+ val trustManagers= coderTrustManagers()
299
+ sslContext.init (keyManagerFactory.keyManagers, trustManagers,null )
300
+
301
+ if (state.tlsAlternateHostname.isBlank()) {
302
+ return sslContext.socketFactory
303
+ }
304
+
305
+ return AlternateNameSSLSocketFactory (sslContext.socketFactory, state.tlsAlternateHostname)
306
+ }
307
+
308
+ fun coderTrustManagers () :Array <TrustManager > {
309
+ val state: CoderSettingsState = service()
310
+
311
+ val trustManagerFactory= TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())
312
+ if (state.tlsCAPath.isBlank()) {
313
+ // return default trust managers
314
+ trustManagerFactory.init (null as KeyStore ? )
315
+ return trustManagerFactory.trustManagers
316
+ }
317
+
318
+
319
+ val certificateFactory= CertificateFactory .getInstance(" X.509" )
320
+ val caInputStream= FileInputStream (state.tlsCAPath)
321
+ val certChain= certificateFactory.generateCertificates(caInputStream)
322
+
323
+ val truststore= KeyStore .getInstance(KeyStore .getDefaultType())
324
+ truststore.load(null )
325
+ certChain.withIndex().forEach {
326
+ truststore.setCertificateEntry(" cert${it.index} " , it.valueas X509Certificate )
327
+ }
328
+ trustManagerFactory.init (truststore)
329
+ return trustManagerFactory.trustManagers
330
+ }
331
+
332
+ class AlternateNameSSLSocketFactory (private val delegate : SSLSocketFactory , privateval alternateName : String ) : SSLSocketFactory() {
333
+ override fun getDefaultCipherSuites ():Array <String > {
334
+ return delegate.defaultCipherSuites
335
+ }
336
+
337
+ override fun getSupportedCipherSuites ():Array <String > {
338
+ return delegate.supportedCipherSuites
339
+ }
340
+
341
+ override fun createSocket ():Socket {
342
+ val socket= delegate.createSocket()as SSLSocket
343
+ customizeSocket(socket)
344
+ return socket
345
+ }
346
+
347
+ override fun createSocket (host : String? ,port : Int ):Socket {
348
+ val socket= delegate.createSocket(host, port)as SSLSocket
349
+ customizeSocket(socket)
350
+ return socket
351
+ }
352
+
353
+ override fun createSocket (host : String? ,port : Int ,localHost : InetAddress ? ,localPort : Int ):Socket {
354
+ val socket= delegate.createSocket(host, port, localHost, localPort)as SSLSocket
355
+ customizeSocket(socket)
356
+ return socket
357
+ }
358
+
359
+ override fun createSocket (host : InetAddress ? ,port : Int ):Socket {
360
+ val socket= delegate.createSocket(host, port)as SSLSocket
361
+ customizeSocket(socket)
362
+ return socket
363
+ }
364
+
365
+ override fun createSocket (address : InetAddress ? ,port : Int ,localAddress : InetAddress ? ,localPort : Int ):Socket {
366
+ val socket= delegate.createSocket(address, port, localAddress, localPort)as SSLSocket
367
+ customizeSocket(socket)
368
+ return socket
369
+ }
370
+
371
+ override fun createSocket (s : Socket ? ,host : String? ,port : Int ,autoClose : Boolean ):Socket {
372
+ val socket= delegate.createSocket(s, host, port, autoClose)as SSLSocket
373
+ customizeSocket(socket)
374
+ return socket
375
+ }
376
+
377
+ private fun customizeSocket (socket : SSLSocket ) {
378
+ val params= socket.sslParameters
379
+ params.serverNames= listOf (SNIHostName (alternateName))
380
+ socket.sslParameters= params
381
+ }
382
+ }
383
+
384
+ class CoderHostnameVerifier () : HostnameVerifier {
385
+ private val alternateName: String
386
+
387
+ init {
388
+ val state: CoderSettingsState = service()
389
+ this .alternateName= state.tlsAlternateHostname.lowercase(Locale .getDefault())
390
+ }
391
+
392
+ override fun verify (host : String ,session : SSLSession ):Boolean {
393
+ if (alternateName.isEmpty()) {
394
+ return OkHostnameVerifier .verify(host, session)
395
+ }
396
+ val certs= session.peerCertificates? : return false
397
+ for (certin certs) {
398
+ if (cert!is X509Certificate ) {
399
+ continue
400
+ }
401
+ val entries= cert.subjectAlternativeNames? : continue
402
+ for (entryin entries) {
403
+ val kind= entry[0 ]as Int
404
+ if (kind!= 2 ) {// DNS Name
405
+ continue
406
+ }
407
+ val hostname= entry[1 ]as String
408
+ if (hostname.lowercase(Locale .getDefault())== alternateName) {
409
+ return true
410
+ }
411
+ }
412
+ }
413
+ return false
414
+ }
415
+ }