Skip to content

Commit

Permalink
prevent arithmetic overflow and reduce allocations in time conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
magicprinc authored and Ladicek committed Nov 21, 2023
1 parent 589d3e6 commit d2fabf2
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.smallrye.faulttolerance.core.apiimpl;

import static io.smallrye.faulttolerance.core.Invocation.invocation;
import static io.smallrye.faulttolerance.core.util.Durations.timeInMillis;

import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -370,10 +370,6 @@ private <V> FaultToleranceStrategy<CompletionStage<V>> buildAsyncStrategy(Builde
return result;
}

private static long getTimeInMs(long time, ChronoUnit unit) {
return Duration.of(time, unit).toMillis();
}

private static ExceptionDecision createExceptionDecision(Collection<Class<? extends Throwable>> consideredExpected,
Collection<Class<? extends Throwable>> consideredFailure, Predicate<Throwable> whenPredicate) {
if (whenPredicate != null) {
Expand Down Expand Up @@ -516,7 +512,7 @@ public CircuitBreakerBuilder<T, R> delay(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Delay must be >= 0");
Preconditions.checkNotNull(unit, "Delay unit must be set");

this.delayInMillis = getTimeInMs(value, unit);
this.delayInMillis = timeInMillis(value, unit);
return this;
}

Expand Down Expand Up @@ -664,7 +660,7 @@ public RateLimitBuilder<T, R> window(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 1, "Time window length must be >= 1");
Preconditions.checkNotNull(unit, "Time window length unit must be set");

this.timeWindowInMillis = getTimeInMs(value, unit);
this.timeWindowInMillis = timeInMillis(value, unit);
return this;
}

Expand All @@ -673,7 +669,7 @@ public RateLimitBuilder<T, R> minSpacing(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Min spacing must be >= 0");
Preconditions.checkNotNull(unit, "Min spacing unit must be set");

this.minSpacingInMillis = getTimeInMs(value, unit);
this.minSpacingInMillis = timeInMillis(value, unit);
return this;
}

Expand Down Expand Up @@ -737,7 +733,7 @@ public RetryBuilder<T, R> delay(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Delay must be >= 0");
Preconditions.checkNotNull(unit, "Delay unit must be set");

this.delayInMillis = getTimeInMs(value, unit);
this.delayInMillis = timeInMillis(value, unit);
return this;
}

Expand All @@ -746,7 +742,7 @@ public RetryBuilder<T, R> maxDuration(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Max duration must be >= 0");
Preconditions.checkNotNull(unit, "Max duration unit must be set");

this.maxDurationInMillis = getTimeInMs(value, unit);
this.maxDurationInMillis = timeInMillis(value, unit);
return this;
}

Expand All @@ -755,7 +751,7 @@ public RetryBuilder<T, R> jitter(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Jitter must be >= 0");
Preconditions.checkNotNull(unit, "Jitter unit must be set");

this.jitterInMillis = getTimeInMs(value, unit);
this.jitterInMillis = timeInMillis(value, unit);
return this;
}

Expand Down Expand Up @@ -857,7 +853,7 @@ public ExponentialBackoffBuilder<T, R> maxDelay(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Max delay must be >= 0");
Preconditions.checkNotNull(unit, "Max delay unit must be set");

this.maxDelayInMillis = getTimeInMs(value, unit);
this.maxDelayInMillis = timeInMillis(value, unit);
return this;
}

Expand All @@ -882,7 +878,7 @@ public FibonacciBackoffBuilder<T, R> maxDelay(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Max delay must be >= 0");
Preconditions.checkNotNull(unit, "Max delay unit must be set");

this.maxDelayInMillis = getTimeInMs(value, unit);
this.maxDelayInMillis = timeInMillis(value, unit);
return this;
}

Expand Down Expand Up @@ -935,7 +931,7 @@ public TimeoutBuilder<T, R> duration(long value, ChronoUnit unit) {
Preconditions.check(value, value >= 0, "Timeout duration must be >= 0");
Preconditions.checkNotNull(unit, "Timeout duration unit must be set");

this.durationInMillis = getTimeInMs(value, unit);
this.durationInMillis = timeInMillis(value, unit);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.smallrye.faulttolerance.core.util;

import java.time.temporal.ChronoUnit;
import java.util.concurrent.TimeUnit;

public final class Durations {
private static final long SECONDS_IN_HALF_DAY = ChronoUnit.HALF_DAYS.getDuration().getSeconds();
private static final long SECONDS_IN_WEEK = ChronoUnit.WEEKS.getDuration().getSeconds();;
private static final long SECONDS_IN_MONTH = ChronoUnit.MONTHS.getDuration().getSeconds();;
private static final long SECONDS_IN_YEAR = ChronoUnit.YEARS.getDuration().getSeconds();;

private static final long MAX_HALF_DAYS = Long.MAX_VALUE / SECONDS_IN_HALF_DAY;
private static final long MAX_WEEKS = Long.MAX_VALUE / SECONDS_IN_WEEK;
private static final long MAX_MONTHS = Long.MAX_VALUE / SECONDS_IN_MONTH;
private static final long MAX_YEARS = Long.MAX_VALUE / SECONDS_IN_YEAR;

public static long timeInMillis(long value, ChronoUnit unit) {
switch (unit) {
case NANOS:
return TimeUnit.NANOSECONDS.toMillis(value);
case MICROS:
return TimeUnit.MICROSECONDS.toMillis(value);
case MILLIS:
return value;
case SECONDS:
return TimeUnit.SECONDS.toMillis(value);
case MINUTES:
return TimeUnit.MINUTES.toMillis(value);
case HOURS:
return TimeUnit.HOURS.toMillis(value);
case HALF_DAYS:
return convert(value, MAX_HALF_DAYS, SECONDS_IN_HALF_DAY);
case DAYS:
return TimeUnit.DAYS.toMillis(value);
case WEEKS:
return convert(value, MAX_WEEKS, SECONDS_IN_WEEK);
case MONTHS:
return convert(value, MAX_MONTHS, SECONDS_IN_MONTH);
case YEARS:
return convert(value, MAX_YEARS, SECONDS_IN_YEAR);
default:
throw new IllegalArgumentException("Unsupported time unit: " + unit);
}
}

private static long convert(long value, long maxInUnit, long secondsInUnit) {
if (value == Long.MIN_VALUE) {
return Long.MIN_VALUE;
}

boolean negative = value < 0;
long abs = negative ? -value : value;
if (abs > maxInUnit) {
// `value * secondsInUnit` would overflow
return negative ? Long.MIN_VALUE : Long.MAX_VALUE;
}

return TimeUnit.SECONDS.toMillis(value * secondsInUnit);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package io.smallrye.faulttolerance.core.util;

import static org.assertj.core.api.Assertions.assertThat;

import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import org.junit.jupiter.api.Test;

public class DurationsTest {
@Test
public void timeInMillis() {
assertThat(Durations.timeInMillis(5_000_000, ChronoUnit.NANOS)).isEqualTo(5);
assertThat(Durations.timeInMillis(5_000, ChronoUnit.MICROS)).isEqualTo(5);
assertThat(Durations.timeInMillis(5, ChronoUnit.MILLIS)).isEqualTo(5);
assertThat(Durations.timeInMillis(2, ChronoUnit.SECONDS)).isEqualTo(2000);
assertThat(Durations.timeInMillis(2, ChronoUnit.MINUTES)).isEqualTo(120_000);
assertThat(Durations.timeInMillis(3, ChronoUnit.HOURS)).isEqualTo(Duration.ofHours(3).toMillis());
assertThat(Durations.timeInMillis(2, ChronoUnit.HALF_DAYS)).isEqualTo(Duration.ofDays(1).toMillis());
assertThat(Durations.timeInMillis(8, ChronoUnit.HALF_DAYS)).isEqualTo(Duration.ofDays(4).toMillis());
assertThat(Durations.timeInMillis(365, ChronoUnit.DAYS)).isEqualTo(Duration.ofDays(365).toMillis());
assertThat(Durations.timeInMillis(7, ChronoUnit.WEEKS)).isEqualTo(Duration.ofDays(7 * 7).toMillis());
assertThat(Durations.timeInMillis(17, ChronoUnit.WEEKS)).isEqualTo(Duration.ofDays(17 * 7).toMillis());
assertThat(Durations.timeInMillis(12, ChronoUnit.MONTHS)).isEqualTo(ChronoUnit.YEARS.getDuration().toMillis());
assertThat(Durations.timeInMillis(24, ChronoUnit.MONTHS)).isEqualTo(2 * ChronoUnit.YEARS.getDuration().toMillis());

assertThat(Durations.timeInMillis(0, ChronoUnit.HALF_DAYS)).isEqualTo(0);
assertThat(Durations.timeInMillis(4, ChronoUnit.HALF_DAYS)).isEqualTo(Duration.ofDays(2).toMillis());
assertThat(Durations.timeInMillis(-4, ChronoUnit.HALF_DAYS)).isEqualTo(Duration.ofDays(-2).toMillis());
assertThat(Durations.timeInMillis(0, ChronoUnit.DAYS)).isEqualTo(0);
assertThat(Durations.timeInMillis(4, ChronoUnit.DAYS)).isEqualTo(Duration.ofDays(4).toMillis());
assertThat(Durations.timeInMillis(-4, ChronoUnit.DAYS)).isEqualTo(Duration.ofDays(-4).toMillis());
assertThat(Durations.timeInMillis(0, ChronoUnit.WEEKS)).isEqualTo(0);
assertThat(Durations.timeInMillis(4, ChronoUnit.WEEKS)).isEqualTo(Duration.ofDays(4 * 7).toMillis());
assertThat(Durations.timeInMillis(-4, ChronoUnit.WEEKS)).isEqualTo(Duration.ofDays(-4 * 7).toMillis());
assertThat(Durations.timeInMillis(0, ChronoUnit.MONTHS)).isEqualTo(0);
assertThat(Durations.timeInMillis(7, ChronoUnit.MONTHS)).isEqualTo(7 * ChronoUnit.MONTHS.getDuration().toMillis());
assertThat(Durations.timeInMillis(-7, ChronoUnit.MONTHS)).isEqualTo(-7 * ChronoUnit.MONTHS.getDuration().toMillis());

assertThat(Durations.timeInMillis(Long.MAX_VALUE, ChronoUnit.HOURS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE, ChronoUnit.HALF_DAYS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE, ChronoUnit.DAYS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE, ChronoUnit.WEEKS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE, ChronoUnit.MONTHS)).isEqualTo(Long.MAX_VALUE);

assertThat(Durations.timeInMillis(Long.MAX_VALUE - 1, ChronoUnit.HOURS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE - 1, ChronoUnit.HALF_DAYS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE - 1, ChronoUnit.DAYS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE - 1, ChronoUnit.WEEKS)).isEqualTo(Long.MAX_VALUE);
assertThat(Durations.timeInMillis(Long.MAX_VALUE - 1, ChronoUnit.MONTHS)).isEqualTo(Long.MAX_VALUE);

assertThat(Durations.timeInMillis(Long.MIN_VALUE + 1, ChronoUnit.HOURS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE + 1, ChronoUnit.HALF_DAYS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE + 1, ChronoUnit.DAYS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE + 1, ChronoUnit.WEEKS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE + 1, ChronoUnit.MONTHS)).isEqualTo(Long.MIN_VALUE);

assertThat(Durations.timeInMillis(Long.MIN_VALUE, ChronoUnit.HOURS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE, ChronoUnit.HALF_DAYS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE, ChronoUnit.DAYS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE, ChronoUnit.WEEKS)).isEqualTo(Long.MIN_VALUE);
assertThat(Durations.timeInMillis(Long.MIN_VALUE, ChronoUnit.MONTHS)).isEqualTo(Long.MIN_VALUE);
}

@Test
public void timeInMillis_random() {
for (ChronoUnit unit : List.of(ChronoUnit.NANOS, ChronoUnit.MICROS, ChronoUnit.MILLIS, ChronoUnit.SECONDS,
ChronoUnit.MINUTES, ChronoUnit.HOURS, ChronoUnit.HALF_DAYS, ChronoUnit.DAYS)) {
for (int i = 0; i < 1_000_000; i++) {
long value = ThreadLocalRandom.current().nextLong(Integer.MIN_VALUE * 1000L, Integer.MAX_VALUE * 1000L);
assertThat(Durations.timeInMillis(value, unit))
.isEqualTo(TimeUnit.MILLISECONDS.convert(Duration.of(value, unit)));
}
}
}
}
Loading

0 comments on commit d2fabf2

Please sign in to comment.