Skip to content

Commit

Permalink
Don't throw in awaitSignal if the timeout is canceled (#1394)
Browse files Browse the repository at this point in the history
  • Loading branch information
squarejesse authored Dec 15, 2023
1 parent 9ece752 commit bf29a91
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 19 deletions.
11 changes: 10 additions & 1 deletion okio/src/jvmMain/kotlin/okio/Timeout.kt
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,19 @@ actual open class Timeout {

if (waitNanos <= 0) throw InterruptedIOException("timeout")

val cancelMarkBefore = cancelMark

// Attempt to wait that long. This will return early if the monitor is notified.
val nanosRemaining = condition.awaitNanos(waitNanos)

if (nanosRemaining <= 0) throw InterruptedIOException("timeout")
// If there's time remaining, we probably got the call we were waiting for.
if (nanosRemaining > 0) return

// Return without throwing if this timeout was canceled while we were waiting. Note that this
// return is a 'spurious wakeup' because Condition.signal() was not called.
if (cancelMark !== cancelMarkBefore) return

throw InterruptedIOException("timeout")
} catch (e: InterruptedException) {
Thread.currentThread().interrupt() // Retain interrupted status.
throw InterruptedIOException("interrupted")
Expand Down
93 changes: 75 additions & 18 deletions okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import java.util.concurrent.locks.Condition
import java.util.concurrent.locks.ReentrantLock
import okio.TestUtil.assumeNotWindows
import org.junit.After
import org.junit.Assert
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Test

class AwaitSignalTest {
Expand Down Expand Up @@ -57,9 +59,9 @@ class AwaitSignalTest {
val start = now()
try {
timeout.awaitSignal(condition)
Assert.fail()
fail()
} catch (expected: InterruptedIOException) {
Assert.assertEquals("timeout", expected.message)
assertEquals("timeout", expected.message)
}
assertElapsed(1000.0, start)
}
Expand All @@ -72,9 +74,9 @@ class AwaitSignalTest {
val start = now()
try {
timeout.awaitSignal(condition)
Assert.fail()
fail()
} catch (expected: InterruptedIOException) {
Assert.assertEquals("timeout", expected.message)
assertEquals("timeout", expected.message)
}
assertElapsed(1000.0, start)
}
Expand All @@ -88,9 +90,9 @@ class AwaitSignalTest {
val start = now()
try {
timeout.awaitSignal(condition)
Assert.fail()
fail()
} catch (expected: InterruptedIOException) {
Assert.assertEquals("timeout", expected.message)
assertEquals("timeout", expected.message)
}
assertElapsed(1000.0, start)
}
Expand All @@ -104,9 +106,9 @@ class AwaitSignalTest {
val start = now()
try {
timeout.awaitSignal(condition)
Assert.fail()
fail()
} catch (expected: InterruptedIOException) {
Assert.assertEquals("timeout", expected.message)
assertEquals("timeout", expected.message)
}
assertElapsed(1000.0, start)
}
Expand All @@ -119,9 +121,9 @@ class AwaitSignalTest {
val start = now()
try {
timeout.awaitSignal(condition)
Assert.fail()
fail()
} catch (expected: InterruptedIOException) {
Assert.assertEquals("timeout", expected.message)
assertEquals("timeout", expected.message)
}
assertElapsed(0.0, start)
}
Expand All @@ -134,10 +136,10 @@ class AwaitSignalTest {
Thread.currentThread().interrupt()
try {
timeout.awaitSignal(condition)
Assert.fail()
fail()
} catch (expected: InterruptedIOException) {
Assert.assertEquals("interrupted", expected.message)
Assert.assertTrue(Thread.interrupted())
assertEquals("interrupted", expected.message)
assertTrue(Thread.interrupted())
}
assertElapsed(0.0, start)
}
Expand All @@ -149,13 +151,60 @@ class AwaitSignalTest {
Thread.currentThread().interrupt()
try {
timeout.throwIfReached()
Assert.fail()
fail()
} catch (expected: InterruptedIOException) {
Assert.assertEquals("interrupted", expected.message)
Assert.assertTrue(Thread.interrupted())
assertEquals("interrupted", expected.message)
assertTrue(Thread.interrupted())
}
}

@Test
fun cancelBeforeWaitDoesNothing() {
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancel()
val start = now()
try {
lock.withLock {
timeout.awaitSignal(condition)
}
fail()
} catch (expected: InterruptedIOException) {
assertEquals("timeout", expected.message)
}
assertElapsed(1000.0, start)
}

@Test
fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() {
assumeNotWindows()
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancelLater(500)

val start = now()
lock.withLock {
timeout.awaitSignal(condition) // Returns early but doesn't throw.
}
assertElapsed(1000.0, start)
}

@Test
@Synchronized
fun multipleCancelsAreIdempotent() {
val timeout = Timeout()
timeout.timeout(1000, TimeUnit.MILLISECONDS)
timeout.cancelLater(250)
timeout.cancelLater(500)
timeout.cancelLater(750)

val start = now()
lock.withLock {
timeout.awaitSignal(condition) // Returns early but doesn't throw.
}
assertElapsed(1000.0, start)
}

/** Returns the nanotime in milliseconds as a double for measuring timeouts. */
private fun now(): Double {
return System.nanoTime() / 1000000.0
Expand All @@ -166,6 +215,14 @@ class AwaitSignalTest {
* -50..+450 milliseconds.
*/
private fun assertElapsed(duration: Double, start: Double) {
Assert.assertEquals(duration, now() - start - 200.0, 250.0)
assertEquals(duration, now() - start - 200.0, 250.0)
}

private fun Timeout.cancelLater(delay: Long) {
executorService.schedule(
{ cancel() },
delay,
TimeUnit.MILLISECONDS,
)
}
}

0 comments on commit bf29a91

Please sign in to comment.