Skip to content

Commit

Permalink
Add request sampling back in http sources (#6810)
Browse files Browse the repository at this point in the history
Add request sampling back in http sources
  • Loading branch information
manuel-alvarez-alvarez authored Mar 14, 2024
1 parent f19053f commit 8d4f9c2
Show file tree
Hide file tree
Showing 108 changed files with 1,492 additions and 566 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.SourceTypes;
Expand Down Expand Up @@ -31,7 +32,7 @@ public class GrpcRequestMessageHandler implements BiFunction<RequestContext, Obj
public Flow<Void> apply(final RequestContext ctx, final Object o) {
final PropagationModule module = InstrumentationBridge.PROPAGATION;
if (module != null && o != null) {
final IastContext iastCtx = IastContext.Provider.get(ctx);
final IastContext iastCtx = ctx.getData(RequestContextSlot.IAST);
final byte source = SourceTypes.GRPC_BODY;
final int tainted =
module.taintDeeply(iastCtx, o, source, GrpcRequestMessageHandler::isProtobufArtifact);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.IGSpanInfo;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.sink.HttpRequestEndModule;
Expand All @@ -27,7 +28,7 @@ public RequestEndedHandler(@Nonnull final Dependencies dependencies) {
@Override
public Flow<Void> apply(final RequestContext requestContext, final IGSpanInfo igSpanInfo) {
final TraceSegment traceSegment = requestContext.getTraceSegment();
final IastContext iastCtx = IastContext.Provider.get(requestContext);
final IastContext iastCtx = requestContext.getData(RequestContextSlot.IAST);
if (iastCtx != null) {
for (HttpRequestEndModule module : requestEndModules()) {
if (module != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datadog.trace.api.gateway.Flow;
import datadog.trace.api.gateway.IGSpanInfo;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.telemetry.IastMetric;
import datadog.trace.api.iast.telemetry.IastMetricCollector;
Expand Down Expand Up @@ -33,7 +34,7 @@ public Flow<Void> apply(final RequestContext context, final IGSpanInfo igSpanInf
}

private static void onRequestEnded(final RequestContext context) {
final IastContext iastCtx = IastContext.Provider.get(context);
final IastContext iastCtx = context.getData(RequestContextSlot.IAST);
if (!(iastCtx instanceof HasMetricCollector)) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import datadog.trace.agent.test.base.WithHttpServer
import datadog.trace.agent.tooling.bytebuddy.iast.TaintableVisitor
import datadog.trace.api.gateway.IGSpanInfo
import datadog.trace.api.gateway.RequestContext
import datadog.trace.api.iast.IastContext
import datadog.trace.api.gateway.RequestContextSlot
import groovy.json.JsonBuilder
import groovy.transform.CompileStatic

Expand All @@ -29,11 +29,11 @@ abstract class IastHttpServerTest<SERVER> extends WithHttpServer<SERVER> impleme
protected Closure getRequestEndAction() {
{ RequestContext requestContext, IGSpanInfo igSpanInfo ->
// request end action
IastRequestContext iastRequestContext = IastContext.Provider.get(requestContext)
IastRequestContext iastRequestContext = requestContext.getData(RequestContextSlot.IAST)
if (iastRequestContext) {
TaintedObjects taintedObjects = iastRequestContext.getTaintedObjects()
TAINTED_OBJECTS.offer(new TaintedObjectCollection(taintedObjects))
List<Vulnerability> vulns = iastRequestContext.getVulnerabilityBatch().getVulnerabilities()
List<Vulnerability> vulns = iastRequestContext.getVulnerabilityBatch().getVulnerabilities() ?: []
VULNERABILITIES.offer(vulns)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.datadog.iast.taint.TaintedObjects
import datadog.trace.agent.test.utils.OkHttpUtils
import datadog.trace.api.gateway.IGSpanInfo
import datadog.trace.api.gateway.RequestContext
import datadog.trace.api.iast.IastContext
import datadog.trace.api.gateway.RequestContextSlot
import okhttp3.OkHttpClient

import java.util.concurrent.LinkedBlockingQueue
Expand All @@ -21,7 +21,7 @@ class IastRequestTestRunner extends IastAgentTestRunner implements IastRequestCo
protected Closure getRequestEndAction() {
{ RequestContext requestContext, IGSpanInfo igSpanInfo ->
// request end action
IastRequestContext iastRequestContext = IastContext.Provider.get(requestContext)
IastRequestContext iastRequestContext = requestContext.getData(RequestContextSlot.IAST)
if (iastRequestContext) {
TaintedObjects taintedObjects = iastRequestContext.getTaintedObjects()
TAINTED_OBJECTS.offer(new TaintedObjectCollection(taintedObjects))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ public class TaintableEnumeration implements Enumeration<String> {

private static final String CLASS_NAME = TaintableEnumeration.class.getName();

private volatile IastContext context;
private volatile boolean contextFetched;
private final IastContext context;

private final PropagationModule module;

Expand All @@ -25,11 +24,13 @@ public class TaintableEnumeration implements Enumeration<String> {
private final Enumeration<String> delegate;

private TaintableEnumeration(
final IastContext ctx,
@NonNull final Enumeration<String> delegate,
@NonNull final PropagationModule module,
final byte origin,
@Nullable final CharSequence name,
final boolean useValueAsName) {
this.context = ctx;
this.delegate = delegate;
this.module = module;
this.origin = origin;
Expand Down Expand Up @@ -57,21 +58,13 @@ public String nextElement() {
throw e;
}
try {
module.taint(context(), next, origin, name(next));
module.taint(context, next, origin, name(next));
} catch (final Throwable e) {
module.onUnexpectedException("Failed to taint enumeration", e);
}
return next;
}

private IastContext context() {
if (!contextFetched) {
contextFetched = true;
context = IastContext.Provider.get();
}
return context;
}

private CharSequence name(final String value) {
if (name != null) {
return name;
Expand All @@ -84,18 +77,20 @@ private static boolean nonTaintableEnumerationStack(final StackTraceElement elem
}

public static Enumeration<String> wrap(
final IastContext ctx,
@NonNull final Enumeration<String> delegate,
@NonNull final PropagationModule module,
final byte origin,
@Nullable final CharSequence name) {
return new TaintableEnumeration(delegate, module, origin, name, false);
return new TaintableEnumeration(ctx, delegate, module, origin, name, false);
}

public static Enumeration<String> wrap(
final IastContext ctx,
@NonNull final Enumeration<String> delegate,
@NonNull final PropagationModule module,
final byte origin,
boolean useValueAsName) {
return new TaintableEnumeration(delegate, module, origin, null, useValueAsName);
return new TaintableEnumeration(ctx, delegate, module, origin, null, useValueAsName);
}
}
Original file line number Diff line number Diff line change
@@ -1,44 +1,26 @@
package datadog.trace.agent.tooling.iast

import datadog.trace.api.gateway.RequestContext
import datadog.trace.api.gateway.RequestContextSlot

import datadog.trace.api.iast.IastContext
import datadog.trace.api.iast.InstrumentationBridge
import datadog.trace.api.iast.SourceTypes
import datadog.trace.api.iast.propagation.PropagationModule
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
import datadog.trace.test.util.DDSpecification
import spock.lang.Shared

class TaintableEnumerationTest extends DDSpecification {

@Shared
protected static final AgentTracer.TracerAPI ORIGINAL_TRACER = AgentTracer.get()

protected AgentTracer.TracerAPI tracer = Mock(AgentTracer.TracerAPI)

protected IastContext iastCtx = Mock(IastContext)

protected RequestContext reqCtx = Mock(RequestContext) {
getData(RequestContextSlot.IAST) >> iastCtx
}

protected AgentSpan span = Mock(AgentSpan) {
getRequestContext() >> reqCtx
}
protected IastContext iastCtx = Stub(IastContext)

protected PropagationModule module


void setup() {
AgentTracer.forceRegister(tracer)
module = Mock(PropagationModule)
InstrumentationBridge.registerIastModule(module)
}

void cleanup() {
AgentTracer.forceRegister(ORIGINAL_TRACER)
InstrumentationBridge.clearIastModules()
}

Expand All @@ -47,35 +29,34 @@ class TaintableEnumerationTest extends DDSpecification {
final values = (1..10).collect { "value$it".toString() }
final origin = SourceTypes.REQUEST_PARAMETER_NAME
final name = 'test'
final enumeration = TaintableEnumeration.wrap(Collections.enumeration(values), module, origin, name)
final enumeration = TaintableEnumeration.wrap(iastCtx, Collections.enumeration(values), module, origin, name)

when:
final result = enumeration.collect()

then:
result == values
values.each { 1 * module.taint(_, it, origin, name) }
1 * tracer.activeSpan() >> span // only one access to the active context
values.each { 1 * module.taint(iastCtx, it, origin, name) }
}

void 'underlying enumerated values are tainted with the value as a name'() {
given:
final values = (1..10).collect { "value$it".toString() }
final origin = SourceTypes.REQUEST_PARAMETER_NAME
final enumeration = TaintableEnumeration.wrap(Collections.enumeration(values), module, origin, true)
final enumeration = TaintableEnumeration.wrap(iastCtx, Collections.enumeration(values), module, origin, true)

when:
final result = enumeration.collect()

then:
result == values
values.each { 1 * module.taint(_, it, origin, it) }
values.each { 1 * module.taint(iastCtx, it, origin, it) }
}

void 'taintable enumeration leaves no trace in case of error'() {
given:
final origin = SourceTypes.REQUEST_PARAMETER_NAME
final enumeration = TaintableEnumeration.wrap(new BadEnumeration(), module, origin, true)
final enumeration = TaintableEnumeration.wrap(iastCtx, new BadEnumeration(), module, origin, true)

when:
enumeration.hasMoreElements()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
import akka.http.scaladsl.model.headers.Cookie;
import akka.http.scaladsl.model.headers.HttpCookiePair;
import com.google.auto.service.AutoService;
import datadog.trace.advice.ActiveRequestContext;
import datadog.trace.advice.RequiresRequestContext;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.Source;
Expand Down Expand Up @@ -50,20 +54,23 @@ public void methodAdvice(MethodTransformer transformer) {
CookieHeaderInstrumentation.class.getName() + "$TaintAllCookiesAdvice");
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class TaintAllCookiesAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.REQUEST_COOKIE_VALUE)
static void after(
@Advice.This HttpHeader cookie, @Advice.Return Seq<HttpCookiePair> cookiePairs) {
@Advice.This HttpHeader cookie,
@Advice.Return Seq<HttpCookiePair> cookiePairs,
@ActiveRequestContext RequestContext reqCtx) {
PropagationModule prop = InstrumentationBridge.PROPAGATION;
if (prop == null || cookiePairs == null || cookiePairs.isEmpty()) {
return;
}
if (!prop.isTainted(cookie)) {
final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST);
if (!prop.isTainted(ctx, cookie)) {
return;
}

final IastContext ctx = IastContext.Provider.get();
Iterator<HttpCookiePair> iterator = cookiePairs.iterator();
while (iterator.hasNext()) {
HttpCookiePair pair = iterator.next();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import akka.http.javadsl.model.HttpHeader;
import datadog.trace.agent.tooling.csi.CallSite;
import datadog.trace.api.iast.IastCallSites;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.Source;
import datadog.trace.api.iast.SourceTypes;
import datadog.trace.api.iast.propagation.PropagationModule;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;

/**
* Detects when a header name is directly called from user code. This uses call site instrumentation
Expand All @@ -26,7 +28,11 @@ public static String after(@CallSite.This HttpHeader header, @CallSite.Return St
return result;
}
try {
module.taintIfTainted(result, header, SourceTypes.REQUEST_HEADER_NAME, result);
final IastContext ctx = IastContext.Provider.get(AgentTracer.activeSpan());
if (ctx == null) {
return result;
}
module.taintIfTainted(ctx, result, header, SourceTypes.REQUEST_HEADER_NAME, result);
} catch (final Throwable e) {
module.onUnexpectedException("onHeaderNames threw", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@
import akka.http.scaladsl.model.HttpHeader;
import akka.http.scaladsl.model.HttpRequest;
import com.google.auto.service.AutoService;
import datadog.trace.advice.ActiveRequestContext;
import datadog.trace.advice.RequiresRequestContext;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.Propagation;
import datadog.trace.api.iast.propagation.PropagationModule;
Expand Down Expand Up @@ -53,17 +58,21 @@ public void methodAdvice(MethodTransformer transformer) {
HttpHeaderSubclassesInstrumentation.class.getName() + "$HttpHeaderSubclassesAdvice");
}

@RequiresRequestContext(RequestContextSlot.IAST)
static class HttpHeaderSubclassesAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Propagation
static void onExit(@Advice.This HttpHeader h, @Advice.Return String retVal) {
static void onExit(
@Advice.This HttpHeader h,
@Advice.Return String retVal,
@ActiveRequestContext RequestContext reqCtx) {

PropagationModule propagation = InstrumentationBridge.PROPAGATION;
if (propagation == null) {
return;
}

propagation.taintIfTainted(retVal, h);
IastContext ctx = reqCtx.getData(RequestContextSlot.IAST);
propagation.taintIfTainted(ctx, retVal, h);
}
}
}
Loading

0 comments on commit 8d4f9c2

Please sign in to comment.