Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions okhttp-coroutines/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ plugins {
id("org.jetbrains.dokka")
id("com.vanniktech.maven.publish.base")
id("binary-compatibility-validator")
id("app.cash.burst")
}

project.applyOsgi(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package okhttp3.coroutines

import app.cash.burst.Burst
import assertk.assertThat
import assertk.assertions.isLessThan
import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.Proxy
import java.net.ServerSocket
import java.time.Duration
import java.time.Instant
import java.util.concurrent.Executors
import kotlinx.coroutines.runBlocking
import okhttp3.Call
import okhttp3.Dns
import okhttp3.EventListener
import okhttp3.OkHttpClient
import okhttp3.Request
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Test

@Burst
class ReproduceOkHttpIssueTest {
val executorService = Executors.newFixedThreadPool(2)
val stallingServer = StallingServer()

// 1. Define Hostname
val proxyHost = "unresponsive-proxy-host"

// 2. Set Timeouts
val callTimeout = Duration.ofSeconds(3)
val connectTimeout = Duration.ofSeconds(1)
val readTimeout = Duration.ofSeconds(2)

// 3. Build the Client with the Custom Dns
var client = OkHttpClient.Builder()
.callTimeout(callTimeout)
.connectTimeout(connectTimeout)
.readTimeout(readTimeout)
.dns(Dns { hostname: String? ->
if (hostname == proxyHost)
listOf(InetAddress.getByName("127.0.0.1"), InetAddress.getByName("127.0.0.1"))
else
Dns.Companion.SYSTEM.lookup(hostname!!)
})
.eventListener(object : EventListener() {
override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy) {
println(
"connect start - %s - inetSocketAddress: %s, proxy: %s".format(
call.request().url,
inetSocketAddress,
proxy
)
)
}
})
.proxy(Proxy(Proxy.Type.HTTP, InetSocketAddress.createUnresolved(proxyHost, 8080)))
.build()

// 4. Test the Call
val request = Request.Builder().url("https://github.com/").build() // Any valid target URL which is https

@AfterEach
fun close() {
executorService.shutdownNow()
stallingServer.stop()
}

@Test
@Throws(InterruptedException::class)
fun test(execute: Boolean = true, increasedReadTimeout: Boolean = false) {
if (increasedReadTimeout) {
client = client.newBuilder()
.readTimeout(readTimeout.multipliedBy(2))
.build()
}

executorService.submit(Runnable { stallingServer.start(8080) })
Thread.sleep(2000)

val startTime = Instant.now()

try {
if (execute) {
client.newCall(request).execute().use { response ->
println("Call Succeeded unexpectedly.")
}
} else {
runBlocking { client.newCall(request).executeAsync() }.use { response ->
println("Call Succeeded unexpectedly.")
}
}
} catch (e: Exception) {
val totalTime = Duration.between(startTime, Instant.now())
println("--- TEST RESULT ---")
println("Exception: " + e.javaClass.getName() + ": " + e.message)
println("Total Time: $totalTime")
println("Expected Time (Call Timeout): $callTimeout")
println("Observed Time (2 x Read Timeout): " + readTimeout.multipliedBy(2))
assertThat(totalTime).isLessThan(readTimeout.multipliedBy(2))
}
}

class StallingServer {
private lateinit var serverSocket: ServerSocket

fun start(port: Int) {
try {
serverSocket = ServerSocket(port)
serverSocket.use { serverSocket ->
println("Java server listening on port $port")
while (true) {
val clientSocket = serverSocket.accept()
println("Connection accepted from " + clientSocket.getInetAddress())

Thread(Runnable {
try {
clientSocket.use { sock ->
// 1. Read the client's initial request (e.g., "CONNECT google.com:443 HTTP/1.1")
val `in` = sock.getInputStream()
val buffer = ByteArray(1024)
val bytesRead = `in`.read(buffer)
println("Received $bytesRead bytes. Now stalling...")

// 2. Respond with "200 OK" to open the tunnel (Critical step for TLS simulation)
val successResponse = "HTTP/1.1 200 Connection established\r\n\r\n"
sock.getOutputStream().write(successResponse.toByteArray())
sock.getOutputStream().flush()
println("Tunnel established on port $port. Now stalling (TLS Handshake)...")

// 3. Block this thread indefinitely.
// The OkHttp client will now send the TLS ClientHello and hit the Read Timeout (10s) waiting for the ServerHello.
Thread.sleep(Long.Companion.MAX_VALUE)
}
} catch (e: Exception) {
System.err.println("Connection handling error: " + e.message)
}
}).start()
}
}
} catch (e: IOException) {
throw RuntimeException(e)
}
}

fun stop() {
serverSocket.close()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import java.io.IOException
import java.io.InterruptedIOException
import java.lang.ref.WeakReference
import java.net.Socket
import java.time.LocalTime
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.ExecutorService
import java.util.concurrent.RejectedExecutionException
Expand All @@ -28,6 +29,8 @@ import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater
import kotlin.reflect.KClass
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.nanoseconds
import okhttp3.Call
import okhttp3.Callback
import okhttp3.EventListener
Expand Down Expand Up @@ -76,11 +79,19 @@ class RealCall(

private val timeout =
object : AsyncTimeout() {

override fun timedOut() {
Exception("Timeout").printStackTrace()
println("RealCall: Timed out ${LocalTime.now()}")
this@RealCall.cancel()
}

override fun toString(): String {
return "RealCall.timeout ${this.timeoutNanos().nanoseconds.inWholeSeconds}"
}
}.apply {
timeout(client.callTimeoutMillis.toLong(), MILLISECONDS)
println("RealCall: new timeout $this ${LocalTime.now()} " + client.callTimeoutMillis.milliseconds.inWholeSeconds)
}

private val executed = AtomicBoolean()
Expand Down Expand Up @@ -180,6 +191,7 @@ class RealCall(
override fun execute(): Response {
check(executed.compareAndSet(false, true)) { "Already Executed" }

println("RealCall: execute ${LocalTime.now()}")
timeout.enter()
callStart()
try {
Expand Down Expand Up @@ -465,6 +477,7 @@ class RealCall(

private fun timeoutExit(cause: IOException?): IOException? {
if (timeoutEarlyExit) return cause
println("RealCall: timeoutExit ${LocalTime.now()}")
if (!timeout.exit()) return cause

val e = InterruptedIOException("timeout")
Expand All @@ -479,6 +492,7 @@ class RealCall(
fun timeoutEarlyExit() {
check(!timeoutEarlyExit)
timeoutEarlyExit = true
println("RealCall: exit ${LocalTime.now()}")
timeout.exit()
}

Expand Down Expand Up @@ -578,6 +592,7 @@ class RealCall(
override fun run() {
threadName("OkHttp ${redactedUrl()}") {
var signalledCallback = false
println("RealCall: timeout enter ${LocalTime.now()}")
timeout.enter()
try {
val response = getResponseWithInterceptorChain()
Expand Down
Loading