- Notifications
You must be signed in to change notification settings - Fork16
feat: add configuration options to support mtls#315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -29,4 +29,4 @@ gradleVersion=7.4 | ||
# Opt-out flag for bundling Kotlin standard library. | ||
# See https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library for details. | ||
# suppress inspection "UnusedProperty" | ||
kotlin.stdlib.default.dependency=true | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. It seems like we need the kotlin stdlib, without this I was getting:
I suspect this is required since the kotlin 9.x upgrade? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Where do you see this error? I tried building with this set to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. I see when running |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -140,7 +140,7 @@ class CoderGatewayConnectionProvider : GatewayConnectionProvider { | ||||||
if (token == null) { // User aborted. | ||||||
throw IllegalArgumentException("Unable to connect to $deploymentURL, $TOKEN is missing") | ||||||
} | ||||||
val client = CoderRestClient(deploymentURL, token.first,null, settings) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Suggested change
| ||||||
return try { | ||||||
Pair(client, client.me().username) | ||||||
} catch (ex: AuthenticationResponseException) { | ||||||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -14,21 +14,48 @@ import com.coder.gateway.sdk.v2.models.Workspace | ||
import com.coder.gateway.sdk.v2.models.WorkspaceBuild | ||
import com.coder.gateway.sdk.v2.models.WorkspaceTransition | ||
import com.coder.gateway.sdk.v2.models.toAgentModels | ||
import com.coder.gateway.services.CoderSettingsState | ||
import com.google.gson.Gson | ||
import com.google.gson.GsonBuilder | ||
import com.intellij.ide.plugins.PluginManagerCore | ||
import com.intellij.openapi.components.Service | ||
import com.intellij.openapi.extensions.PluginId | ||
import com.intellij.openapi.util.SystemInfo | ||
import okhttp3.OkHttpClient | ||
import okhttp3.internal.tls.OkHostnameVerifier | ||
import okhttp3.logging.HttpLoggingInterceptor | ||
import org.zeroturnaround.exec.ProcessExecutor | ||
import retrofit2.Retrofit | ||
import retrofit2.converter.gson.GsonConverterFactory | ||
import java.io.File | ||
import java.io.FileInputStream | ||
import java.net.HttpURLConnection.HTTP_CREATED | ||
import java.net.InetAddress | ||
import java.net.Socket | ||
import java.net.URL | ||
import java.nio.file.Path | ||
import java.security.KeyFactory | ||
import java.security.KeyStore | ||
import java.security.PrivateKey | ||
import java.security.cert.CertificateException | ||
import java.security.cert.CertificateFactory | ||
import java.security.cert.X509Certificate | ||
import java.security.spec.InvalidKeySpecException | ||
import java.security.spec.PKCS8EncodedKeySpec | ||
import java.time.Instant | ||
import java.util.Base64 | ||
import java.util.Locale | ||
import java.util.UUID | ||
import javax.net.ssl.HostnameVerifier | ||
import javax.net.ssl.KeyManagerFactory | ||
import javax.net.ssl.SNIHostName | ||
import javax.net.ssl.SSLContext | ||
import javax.net.ssl.SSLSession | ||
import javax.net.ssl.SSLSocket | ||
import javax.net.ssl.SSLSocketFactory | ||
import javax.net.ssl.TrustManagerFactory | ||
import javax.net.ssl.TrustManager | ||
import javax.net.ssl.X509TrustManager | ||
@Service(Service.Level.APP) | ||
class CoderRestClientService { | ||
@@ -44,18 +71,19 @@ class CoderRestClientService { | ||
* | ||
* @throws [AuthenticationResponseException] if authentication failed. | ||
*/ | ||
fun initClientSession(url: URL, token: String,settings: CoderSettingsState): User { | ||
client = CoderRestClient(url, token,null, settings) | ||
me = client.me() | ||
buildVersion = client.buildInfo().version | ||
isReady = true | ||
return me | ||
} | ||
} | ||
class CoderRestClient( | ||
var url: URL,vartoken: String, | ||
private var pluginVersion: String?, | ||
private var settings: CoderSettingsState, | ||
) { | ||
private var httpClient: OkHttpClient | ||
private var retroRestClient: CoderV2RestFacade | ||
@@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String, | ||
pluginVersion = PluginManagerCore.getPlugin(PluginId.getId("com.coder.gateway"))!!.version // this is the id from the plugin.xml | ||
} | ||
val socketFactory = coderSocketFactory(settings) | ||
val trustManagers = coderTrustManagers(settings.tlsCAPath) | ||
httpClient = OkHttpClient.Builder() | ||
.sslSocketFactory(socketFactory, trustManagers[0] as X509TrustManager) | ||
.hostnameVerifier(CoderHostnameVerifier(settings.tlsAlternateHostname)) | ||
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("Coder-Session-Token", token).build()) } | ||
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("User-Agent", "Coder Gateway/${pluginVersion} (${SystemInfo.getOsNameAndVersion()}; ${SystemInfo.OS_ARCH})").build()) } | ||
.addInterceptor { | ||
var request = it.request() | ||
val headers = getHeaders(url,settings.headerCommand) | ||
if (headers.size > 0) { | ||
val builder = request.newBuilder() | ||
headers.forEach { h -> builder.addHeader(h.key, h.value) } | ||
@@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String, | ||
} | ||
} | ||
} | ||
fun coderSocketFactory(settings: CoderSettingsState) : SSLSocketFactory { | ||
if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) { | ||
return SSLSocketFactory.getDefault() as SSLSocketFactory | ||
} | ||
val certificateFactory = CertificateFactory.getInstance("X.509") | ||
val certInputStream = FileInputStream(expandPath(settings.tlsCertPath)) | ||
val certChain = certificateFactory.generateCertificates(certInputStream) | ||
certInputStream.close() | ||
// ideally we would use something like PemReader from BouncyCastle, but | ||
// BC is used by the IDE. This makes using BC very impractical since | ||
// type casting will mismatch due to the different class loaders. | ||
val privateKeyPem = File(expandPath(settings.tlsKeyPath)).readText() | ||
val start: Int = privateKeyPem.indexOf("-----BEGIN PRIVATE KEY-----") | ||
val end: Int = privateKeyPem.indexOf("-----END PRIVATE KEY-----", start) | ||
val pemBytes: ByteArray = Base64.getDecoder().decode( | ||
privateKeyPem.substring(start + "-----BEGIN PRIVATE KEY-----".length, end) | ||
.replace("\\s+".toRegex(), "") | ||
) | ||
var privateKey : PrivateKey | ||
try { | ||
Member
| ||
val kf = KeyFactory.getInstance("RSA") | ||
val keySpec = PKCS8EncodedKeySpec(pemBytes) | ||
privateKey = kf.generatePrivate(keySpec) | ||
} catch (e: InvalidKeySpecException) { | ||
val kf = KeyFactory.getInstance("EC") | ||
val keySpec = PKCS8EncodedKeySpec(pemBytes) | ||
privateKey = kf.generatePrivate(keySpec) | ||
} | ||
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType()) | ||
keyStore.load(null) | ||
certChain.withIndex().forEach { | ||
keyStore.setCertificateEntry("cert${it.index}", it.value as X509Certificate) | ||
} | ||
keyStore.setKeyEntry("key", privateKey, null, certChain.toTypedArray()) | ||
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) | ||
keyManagerFactory.init(keyStore, null) | ||
val sslContext = SSLContext.getInstance("TLS") | ||
val trustManagers = coderTrustManagers(settings.tlsCAPath) | ||
Member
| ||
sslContext.init(keyManagerFactory.keyManagers, trustManagers, null) | ||
if (settings.tlsAlternateHostname.isBlank()) { | ||
return sslContext.socketFactory | ||
} | ||
return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.tlsAlternateHostname) | ||
} | ||
fun coderTrustManagers(tlsCAPath: String) : Array<TrustManager> { | ||
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) | ||
if (tlsCAPath.isBlank()) { | ||
// return default trust managers | ||
trustManagerFactory.init(null as KeyStore?) | ||
return trustManagerFactory.trustManagers | ||
} | ||
val certificateFactory = CertificateFactory.getInstance("X.509") | ||
val caInputStream = FileInputStream(expandPath(tlsCAPath)) | ||
val certChain = certificateFactory.generateCertificates(caInputStream) | ||
val truststore = KeyStore.getInstance(KeyStore.getDefaultType()) | ||
truststore.load(null) | ||
certChain.withIndex().forEach { | ||
truststore.setCertificateEntry("cert${it.index}", it.value as X509Certificate) | ||
} | ||
trustManagerFactory.init(truststore) | ||
return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray() | ||
} | ||
fun expandPath(path: String): String { | ||
if (path.startsWith("~/")) { | ||
return Path.of(System.getProperty("user.home"), path.substring(1)).toString() | ||
} | ||
if (path.startsWith("\$HOME/")) { | ||
return Path.of(System.getProperty("user.home"), path.substring(5)).toString() | ||
} | ||
if (path.startsWith("\${user.home}/")) { | ||
return Path.of(System.getProperty("user.home"), path.substring(12)).toString() | ||
} | ||
return path | ||
} | ||
class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : SSLSocketFactory() { | ||
override fun getDefaultCipherSuites(): Array<String> { | ||
return delegate.defaultCipherSuites | ||
} | ||
override fun getSupportedCipherSuites(): Array<String> { | ||
return delegate.supportedCipherSuites | ||
} | ||
override fun createSocket(): Socket { | ||
val socket = delegate.createSocket() as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
override fun createSocket(host: String?, port: Int): Socket { | ||
val socket = delegate.createSocket(host, port) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket { | ||
val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
override fun createSocket(host: InetAddress?, port: Int): Socket { | ||
val socket = delegate.createSocket(host, port) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket { | ||
val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket { | ||
val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket | ||
customizeSocket(socket) | ||
return socket | ||
} | ||
private fun customizeSocket(socket: SSLSocket) { | ||
val params = socket.sslParameters | ||
params.serverNames = listOf(SNIHostName(alternateName)) | ||
socket.sslParameters = params | ||
} | ||
} | ||
class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier { | ||
override fun verify(host: String, session: SSLSession): Boolean { | ||
if (alternateName.isEmpty()) { | ||
println("using default hostname verifier, alternateName is empty") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. We should copy the logger pattern used elsewhere instead of using | ||
return OkHostnameVerifier.verify(host, session) | ||
} | ||
println("Looking for alternate hostname: $alternateName") | ||
val certs = session.peerCertificates ?: return false | ||
for (cert in certs) { | ||
if (cert !is X509Certificate) { | ||
continue | ||
} | ||
val entries = cert.subjectAlternativeNames ?: continue | ||
for (entry in entries) { | ||
val kind = entry[0] as Int | ||
if (kind != 2) { // DNS Name | ||
continue | ||
} | ||
val hostname = entry[1] as String | ||
println("Found cert hostname: $hostname") | ||
if (hostname.lowercase(Locale.getDefault()) == alternateName) { | ||
return true | ||
} | ||
} | ||
} | ||
println("No matching hostname found") | ||
return false | ||
} | ||
} | ||
class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : X509TrustManager { | ||
private val systemTrustManager : X509TrustManager | ||
init { | ||
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) | ||
trustManagerFactory.init(null as KeyStore?) | ||
systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager | ||
} | ||
override fun checkClientTrusted(chain: Array<out X509Certificate>, authType: String?) { | ||
try { | ||
otherTrustManager.checkClientTrusted(chain, authType) | ||
} catch (e: CertificateException) { | ||
systemTrustManager.checkClientTrusted(chain, authType) | ||
} | ||
} | ||
override fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String?) { | ||
try { | ||
otherTrustManager.checkServerTrusted(chain, authType) | ||
} catch (e: CertificateException) { | ||
systemTrustManager.checkServerTrusted(chain, authType) | ||
} | ||
} | ||
override fun getAcceptedIssuers(): Array<X509Certificate> { | ||
return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers | ||
} | ||
} |
Uh oh!
There was an error while loading.Please reload this page.