Skip to content

Commit

Permalink
Extract APIGatewayProxyRequestEvent headers for context propagation (
Browse files Browse the repository at this point in the history
  • Loading branch information
cleverchuk authored Oct 17, 2024
1 parent 18a277f commit c5d6b4a
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
import static net.bytebuddy.matcher.ElementMatchers.takesArgument;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.SQSEvent;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.AwsLambdaRequest;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.MapUtils;
import io.opentelemetry.javaagent.bootstrap.OpenTelemetrySdkAccess;
import io.opentelemetry.javaagent.extension.instrumentation.TypeInstrumentation;
import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
Expand Down Expand Up @@ -60,7 +63,11 @@ public static void onEnter(
@Advice.Local("otelFunctionScope") Scope functionScope,
@Advice.Local("otelMessageContext") io.opentelemetry.context.Context messageContext,
@Advice.Local("otelMessageScope") Scope messageScope) {
input = AwsLambdaRequest.create(context, arg, Collections.emptyMap());
Map<String, String> headers = Collections.emptyMap();
if (arg instanceof APIGatewayProxyRequestEvent) {
headers = MapUtils.lowercaseMap(((APIGatewayProxyRequestEvent) arg).getHeaders());
}
input = AwsLambdaRequest.create(context, arg, headers);
io.opentelemetry.context.Context parentContext =
AwsLambdaInstrumentationHelper.functionInstrumenter().extract(input);

Expand All @@ -87,6 +94,7 @@ public static void onEnter(
@Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class)
public static void stopSpan(
@Advice.Argument(value = 0, typing = Typing.DYNAMIC) Object arg,
@Advice.Return Object result,
@Advice.Thrown Throwable throwable,
@Advice.Local("otelInput") AwsLambdaRequest input,
@Advice.Local("otelFunctionContext") io.opentelemetry.context.Context functionContext,
Expand All @@ -103,7 +111,7 @@ public static void stopSpan(
if (functionScope != null) {
functionScope.close();
AwsLambdaInstrumentationHelper.functionInstrumenter()
.end(functionContext, input, null, throwable);
.end(functionContext, input, result, throwable);
}

OpenTelemetrySdkAccess.forceFlush(1, TimeUnit.SECONDS);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.awslambdaevents.v2_2;

import static io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions.equalTo;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.when;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.instrumentation.testing.junit.AgentInstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
import io.opentelemetry.semconv.HttpAttributes;
import io.opentelemetry.semconv.UrlAttributes;
import io.opentelemetry.semconv.UserAgentAttributes;
import io.opentelemetry.semconv.incubating.FaasIncubatingAttributes;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public class AwsLambdaApiGatewayHandlerTest {

@RegisterExtension
public static final InstrumentationExtension testing = AgentInstrumentationExtension.create();

@Mock private Context context;

@BeforeEach
void setUp() {
when(context.getFunctionName()).thenReturn("test_function");
when(context.getAwsRequestId()).thenReturn("1-22-2024");
}

@AfterEach
void tearDown() {
assertThat(testing.forceFlushCalled()).isTrue();
}

@Test
void tracedWithHttpPropagation() {
Map<String, String> headers = new HashMap<>();
headers.put("traceparent", "00-ee13e7026227ebf4c74278ae29691d7a-0000000000000456-01");
headers.put("User-Agent", "Clever Client");
headers.put("host", "localhost:2024");
headers.put("X-FORWARDED-PROTO", "http");

APIGatewayProxyRequestEvent input =
new APIGatewayProxyRequestEvent()
.withHttpMethod("PUT")
.withResource("/hello/{param}")
.withPath("/hello/world")
.withBody("hello")
.withHeaders(headers);

APIGatewayProxyResponseEvent result =
new TestRequestHandlerApiGateway().handleRequest(input, context);
assertThat(result.getBody()).isEqualTo("hello world");
assertThat(result.getStatusCode()).isEqualTo(201);

testing.waitAndAssertTraces(
trace ->
trace.hasSpansSatisfyingExactly(
span ->
span.hasName("PUT /hello/{param}")
.hasKind(SpanKind.SERVER)
.hasTraceId("ee13e7026227ebf4c74278ae29691d7a")
.hasParentSpanId("0000000000000456")
.hasAttributesSatisfyingExactly(
equalTo(FaasIncubatingAttributes.FAAS_INVOCATION_ID, "1-22-2024"),
equalTo(FaasIncubatingAttributes.FAAS_TRIGGER, "http"),
equalTo(HttpAttributes.HTTP_REQUEST_METHOD, "PUT"),
equalTo(UserAgentAttributes.USER_AGENT_ORIGINAL, "Clever Client"),
equalTo(UrlAttributes.URL_FULL, "http://localhost:2024/hello/world"),
equalTo(HttpAttributes.HTTP_RESPONSE_STATUS_CODE, 201L))));
}

public static class TestRequestHandlerApiGateway
implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {

@Override
public APIGatewayProxyResponseEvent handleRequest(
APIGatewayProxyRequestEvent input, Context context) {
return new APIGatewayProxyResponseEvent().withStatusCode(201).withBody("hello world");
}
}
}

0 comments on commit c5d6b4a

Please sign in to comment.