Skip to content

Commit

Permalink
Fix bug in dependency resolving under high concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
smola committed Nov 28, 2022
1 parent 945fc40 commit d57e85d
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 28 deletions.
3 changes: 3 additions & 0 deletions telemetry/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ dependencies {

implementation project(':internal-api')

compileOnly project(':dd-java-agent:agent-tooling')
testImplementation project(':dd-java-agent:agent-tooling')

compileOnly project(':communication')
testImplementation project(':communication')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void queueURI(URI uri) {

// ignore already processed url
synchronized (this) {
if (processedUrlsSet.contains(uri)) {
if (!processedUrlsSet.add(uri)) {
return;
}
}
Expand Down Expand Up @@ -64,10 +64,6 @@ public List<Dependency> pollDependency() {
log.debug("dependency detected {} for {}", dep, uri);
}

synchronized (this) {
processedUrlsSet.add(uri);
}

return dep;
}
}
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
package datadog.telemetry.dependency;

import datadog.trace.agent.tooling.WeakCaches;
import datadog.trace.bootstrap.WeakCache;
import java.lang.instrument.ClassFileTransformer;
import java.net.URL;
import java.security.CodeSource;
import java.security.ProtectionDomain;
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class LocationsCollectingTransformer implements ClassFileTransformer {
private static final Logger log = LoggerFactory.getLogger(LocationsCollectingTransformer.class);

private static final int MAX_CACHED_JARS = 4096;
private final DependencyServiceImpl dependencyService;
private final Set<ProtectionDomain> seenDomains =
Collections.newSetFromMap(new WeakHashMap<ProtectionDomain, Boolean>());
private final WeakCache<ProtectionDomain, Boolean> seenDomains =
WeakCaches.newWeakCache(MAX_CACHED_JARS);

public LocationsCollectingTransformer(DependencyServiceImpl dependencyService) {
this.dependencyService = dependencyService;
seenDomains.add(LocationsCollectingTransformer.class.getProtectionDomain());
seenDomains.put(LocationsCollectingTransformer.class.getProtectionDomain(), true);
}

@Override
Expand All @@ -32,24 +32,20 @@ public byte[] transform(
if (protectionDomain == null) {
return null;
}
if (!seenDomains.add(protectionDomain)) {
return null;
}
log.debug("Saw new protection domain: {}", protectionDomain);

CodeSource codeSource = protectionDomain.getCodeSource();
if (codeSource == null) {
return null;
}

URL location = codeSource.getLocation();
if (location == null) {
return null;
}

dependencyService.addURL(location);

seenDomains.computeIfAbsent(protectionDomain, this::addDependency);
// returning 'null' is the best way to indicate that no transformation has been done.
return null;
}

private boolean addDependency(final ProtectionDomain domain) {
log.debug("Saw new protection domain: {}", domain);
final CodeSource codeSource = domain.getCodeSource();
if (null != codeSource) {
final URL location = codeSource.getLocation();
if (null != location) {
dependencyService.addURL(location);
}
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,12 @@ class DependencyResolverQueueSpecification extends DepSpecification {

then:
assert deps.isEmpty()

when: 'a repeated dependency is added'
resolverQueue.queueURI(getJar('junit-4.12.jar').toURI())
deps = resolverQueue.pollDependency()

then: 'it has no effect'
assert deps.isEmpty()
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package datadog.telemetry.dependency

import spock.lang.Timeout

import java.security.CodeSource
import java.security.ProtectionDomain
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors

class LocationsCollectingTransformerSpecification extends DepSpecification {

Expand Down Expand Up @@ -76,4 +81,104 @@ class LocationsCollectingTransformerSpecification extends DepSpecification {
def dependencies = depService.drainDeterminedDependencies()
dependencies.size()==1
}

void 'multiple dependencies'() {
given:
def nDomains = 1000
def domains = new ArrayList<ProtectionDomain>()
(1..nDomains).each {
CodeSource source = new CodeSource(new URL("file:///bson-${it}.jar"), (java.security.cert.Certificate[])null)
ProtectionDomain domain = new ProtectionDomain(source, null)
domains.add(domain)
}

and:
def depService = Mock(DependencyServiceImpl)
def transformer = new LocationsCollectingTransformer(depService)

when:
domains.each {
transformer.transform(null, null, null, it, null)
}

then:
nDomains * depService.addURL(_)
}

@Timeout(10)
void 'multiple dependencies with concurrency'() {
given:
def threads = 16
def executor = Executors.newFixedThreadPool(threads)
def latch = new CountDownLatch(threads)

and:
def nDomains = 3000
def domains = new ArrayBlockingQueue<ProtectionDomain>(nDomains)
(1..nDomains).each {
CodeSource source = new CodeSource(new URL("file:///bson-${it}.jar"), (java.security.cert.Certificate[])null)
ProtectionDomain domain = new ProtectionDomain(source, null)
domains.add(domain)
}

and:
def depService = Mock(DependencyServiceImpl)
def transformer = new LocationsCollectingTransformer(depService)

when:
def futures = (1..threads).collect {
executor.submit {
latch.countDown()
latch.await()
ProtectionDomain domain = null
while ((domain = domains.poll()) != null) {
transformer.transform(null, null, null, domain, null)
}
}
}
futures.each { it.get() }

then:
nDomains * depService.addURL(_)
0 * _

cleanup:
executor?.shutdownNow()
}

@Timeout(10)
void 'single dependencies with concurrency'() {
given:
def threads = 16
def executor = Executors.newFixedThreadPool(threads)
def latch = new CountDownLatch(threads)

and:
def nDomains = 3000
CodeSource source = new CodeSource(new URL("file:///bson-1.jar"), (java.security.cert.Certificate[])null)
ProtectionDomain domain = new ProtectionDomain(source, null)

and:
def depService = Mock(DependencyServiceImpl)
def transformer = new LocationsCollectingTransformer(depService)

when:
def futures = (1..threads).collect {
executor.submit {
latch.countDown()
latch.await()
for (int i = 0; i < nDomains; i++) {
transformer.transform(null, null, null, domain, null)
}
}
}
futures.each { it.get() }

then:
(1..threads) * depService.addURL(_)
0 * _

cleanup:
executor?.shutdownNow()
}
}

0 comments on commit d57e85d

Please sign in to comment.