From d57e85dec3008b8140474f6bd8b8355eb2fbb3bc Mon Sep 17 00:00:00 2001
From: Santiago Mola
Date: Mon, 28 Nov 2022 18:14:05 +0100
Subject: [PATCH] Fix bug in dependency resolving under high concurrency

---
Reviewer note: this patch was recovered from a whitespace-collapsed copy.
Line breaks, blank lines and indentation were rebuilt from the hunk line
counts and the diffstat; the generic type arguments on `seenDomains` were
restored from how the field is used (keys are ProtectionDomain, values are
boolean). Context-line indentation and the exact spelling of the removed
`WeakHashMap` type arguments are assumptions -- confirm against commit
d57e85dec300 before applying, or apply with `git apply --ignore-whitespace`.

 telemetry/build.gradle                        |   3 +
 .../dependency/DependencyResolverQueue.java   |   6 +-
 .../LocationsCollectingTransformer.java       |  42 ++++---
 ...ependencyResolverQueueSpecification.groovy |   7 ++
 ...sCollectingTransformerSpecification.groovy | 105 ++++++++++++++++++
 5 files changed, 135 insertions(+), 28 deletions(-)

diff --git a/telemetry/build.gradle b/telemetry/build.gradle
index de5ce7f0bea7..96e854280820 100644
--- a/telemetry/build.gradle
+++ b/telemetry/build.gradle
@@ -30,6 +30,9 @@ dependencies {
 
   implementation project(':internal-api')
 
+  compileOnly project(':dd-java-agent:agent-tooling')
+  testImplementation project(':dd-java-agent:agent-tooling')
+
   compileOnly project(':communication')
   testImplementation project(':communication')
 
diff --git a/telemetry/src/main/java/datadog/telemetry/dependency/DependencyResolverQueue.java b/telemetry/src/main/java/datadog/telemetry/dependency/DependencyResolverQueue.java
index 455bcd180889..35e38a6b2a14 100644
--- a/telemetry/src/main/java/datadog/telemetry/dependency/DependencyResolverQueue.java
+++ b/telemetry/src/main/java/datadog/telemetry/dependency/DependencyResolverQueue.java
@@ -35,7 +35,7 @@ public void queueURI(URI uri) {
 
     // ignore already processed url
     synchronized (this) {
-      if (processedUrlsSet.contains(uri)) {
+      if (!processedUrlsSet.add(uri)) {
         return;
       }
     }
@@ -64,10 +64,6 @@ public List pollDependency() {
       log.debug("dependency detected {} for {}", dep, uri);
     }
 
-    synchronized (this) {
-      processedUrlsSet.add(uri);
-    }
-
     return dep;
   }
 }
diff --git a/telemetry/src/main/java/datadog/telemetry/dependency/LocationsCollectingTransformer.java b/telemetry/src/main/java/datadog/telemetry/dependency/LocationsCollectingTransformer.java
index fb98987e85fc..c84ba3d9831c 100644
--- a/telemetry/src/main/java/datadog/telemetry/dependency/LocationsCollectingTransformer.java
+++ b/telemetry/src/main/java/datadog/telemetry/dependency/LocationsCollectingTransformer.java
@@ -1,25 +1,25 @@
 package datadog.telemetry.dependency;
 
+import datadog.trace.agent.tooling.WeakCaches;
+import datadog.trace.bootstrap.WeakCache;
 import java.lang.instrument.ClassFileTransformer;
 import java.net.URL;
 import java.security.CodeSource;
 import java.security.ProtectionDomain;
-import java.util.Collections;
-import java.util.Set;
-import java.util.WeakHashMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 class LocationsCollectingTransformer implements ClassFileTransformer {
   private static final Logger log = LoggerFactory.getLogger(LocationsCollectingTransformer.class);
+  private static final int MAX_CACHED_JARS = 4096;
 
   private final DependencyServiceImpl dependencyService;
-  private final Set<ProtectionDomain> seenDomains =
-      Collections.newSetFromMap(new WeakHashMap<ProtectionDomain, Boolean>());
+  private final WeakCache<ProtectionDomain, Boolean> seenDomains =
+      WeakCaches.newWeakCache(MAX_CACHED_JARS);
 
   public LocationsCollectingTransformer(DependencyServiceImpl dependencyService) {
     this.dependencyService = dependencyService;
-    seenDomains.add(LocationsCollectingTransformer.class.getProtectionDomain());
+    seenDomains.put(LocationsCollectingTransformer.class.getProtectionDomain(), true);
   }
 
   @Override
@@ -32,24 +32,20 @@ public byte[] transform(
     if (protectionDomain == null) {
       return null;
     }
-    if (!seenDomains.add(protectionDomain)) {
-      return null;
-    }
-    log.debug("Saw new protection domain: {}", protectionDomain);
-
-    CodeSource codeSource = protectionDomain.getCodeSource();
-    if (codeSource == null) {
-      return null;
-    }
-
-    URL location = codeSource.getLocation();
-    if (location == null) {
-      return null;
-    }
-
-    dependencyService.addURL(location);
-
+    seenDomains.computeIfAbsent(protectionDomain, this::addDependency);
     // returning 'null' is the best way to indicate that no transformation has been done.
     return null;
   }
+
+  private boolean addDependency(final ProtectionDomain domain) {
+    log.debug("Saw new protection domain: {}", domain);
+    final CodeSource codeSource = domain.getCodeSource();
+    if (null != codeSource) {
+      final URL location = codeSource.getLocation();
+      if (null != location) {
+        dependencyService.addURL(location);
+      }
+    }
+    return true;
+  }
 }
diff --git a/telemetry/src/test/groovy/datadog/telemetry/dependency/DependencyResolverQueueSpecification.groovy b/telemetry/src/test/groovy/datadog/telemetry/dependency/DependencyResolverQueueSpecification.groovy
index e6605ea4544b..62fa1fed216b 100644
--- a/telemetry/src/test/groovy/datadog/telemetry/dependency/DependencyResolverQueueSpecification.groovy
+++ b/telemetry/src/test/groovy/datadog/telemetry/dependency/DependencyResolverQueueSpecification.groovy
@@ -45,5 +45,12 @@ class DependencyResolverQueueSpecification extends DepSpecification {
 
     then:
     assert deps.isEmpty()
+
+    when: 'a repeated dependency is added'
+    resolverQueue.queueURI(getJar('junit-4.12.jar').toURI())
+    deps = resolverQueue.pollDependency()
+
+    then: 'it has no effect'
+    assert deps.isEmpty()
   }
 }
diff --git a/telemetry/src/test/groovy/datadog/telemetry/dependency/LocationsCollectingTransformerSpecification.groovy b/telemetry/src/test/groovy/datadog/telemetry/dependency/LocationsCollectingTransformerSpecification.groovy
index 3e7ade538dd0..c809b74c1bb6 100644
--- a/telemetry/src/test/groovy/datadog/telemetry/dependency/LocationsCollectingTransformerSpecification.groovy
+++ b/telemetry/src/test/groovy/datadog/telemetry/dependency/LocationsCollectingTransformerSpecification.groovy
@@ -1,7 +1,12 @@
 package datadog.telemetry.dependency
 
+import spock.lang.Timeout
+
 import java.security.CodeSource
 import java.security.ProtectionDomain
+import java.util.concurrent.ArrayBlockingQueue
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.Executors
 
 class LocationsCollectingTransformerSpecification extends DepSpecification {
 
@@ -76,4 +81,104 @@ class LocationsCollectingTransformerSpecification extends DepSpecification {
     def dependencies = depService.drainDeterminedDependencies()
     dependencies.size()==1
   }
+
+  void 'multiple dependencies'() {
+    given:
+    def nDomains = 1000
+    def domains = new ArrayList()
+    (1..nDomains).each {
+      CodeSource source = new CodeSource(new URL("file:///bson-${it}.jar"), (java.security.cert.Certificate[])null)
+      ProtectionDomain domain = new ProtectionDomain(source, null)
+      domains.add(domain)
+    }
+
+    and:
+    def depService = Mock(DependencyServiceImpl)
+    def transformer = new LocationsCollectingTransformer(depService)
+
+    when:
+    domains.each {
+      transformer.transform(null, null, null, it, null)
+    }
+
+    then:
+    nDomains * depService.addURL(_)
+  }
+
+  @Timeout(10)
+  void 'multiple dependencies with concurrency'() {
+    given:
+    def threads = 16
+    def executor = Executors.newFixedThreadPool(threads)
+    def latch = new CountDownLatch(threads)
+
+    and:
+    def nDomains = 3000
+    def domains = new ArrayBlockingQueue(nDomains)
+    (1..nDomains).each {
+      CodeSource source = new CodeSource(new URL("file:///bson-${it}.jar"), (java.security.cert.Certificate[])null)
+      ProtectionDomain domain = new ProtectionDomain(source, null)
+      domains.add(domain)
+    }
+
+    and:
+    def depService = Mock(DependencyServiceImpl)
+    def transformer = new LocationsCollectingTransformer(depService)
+
+    when:
+    def futures = (1..threads).collect {
+      executor.submit {
+        latch.countDown()
+        latch.await()
+        ProtectionDomain domain = null
+        while ((domain = domains.poll()) != null) {
+          transformer.transform(null, null, null, domain, null)
+        }
+      }
+    }
+    futures.each { it.get() }
+
+    then:
+    nDomains * depService.addURL(_)
+    0 * _
+
+    cleanup:
+    executor?.shutdownNow()
+  }
+
+  @Timeout(10)
+  void 'single dependencies with concurrency'() {
+    given:
+    def threads = 16
+    def executor = Executors.newFixedThreadPool(threads)
+    def latch = new CountDownLatch(threads)
+
+    and:
+    def nDomains = 3000
+    CodeSource source = new CodeSource(new URL("file:///bson-1.jar"), (java.security.cert.Certificate[])null)
+    ProtectionDomain domain = new ProtectionDomain(source, null)
+
+    and:
+    def depService = Mock(DependencyServiceImpl)
+    def transformer = new LocationsCollectingTransformer(depService)
+
+    when:
+    def futures = (1..threads).collect {
+      executor.submit {
+        latch.countDown()
+        latch.await()
+        for (int i = 0; i < nDomains; i++) {
+          transformer.transform(null, null, null, domain, null)
+        }
+      }
+    }
+    futures.each { it.get() }
+
+    then:
+    (1..threads) * depService.addURL(_)
+    0 * _
+
+    cleanup:
+    executor?.shutdownNow()
+  }
 }