From 1b90220de5c24570cab91fcb34b7a7eb7ffa634d Mon Sep 17 00:00:00 2001
From: Owais Lone <owais@lone.pw>
Date: Tue, 13 Apr 2021 11:06:59 +0530
Subject: [PATCH] Added trace response headers to Flask

---
 .github/workflows/test.yml                    |  2 +-
 .../instrumentation/flask/__init__.py         | 10 +++++++
 .../tests/test_programmatic.py                | 30 +++++++++++++++++++
 .../instrumentation/wsgi/__init__.py          |  8 +++++
 4 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9b5e1e5438..5cfaa1eba5 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -6,7 +6,7 @@ on:
     - 'release/*'
   pull_request:
 env:
-  CORE_REPO_SHA: cad261e5dae1fe986c87e6965664b45cc9ab73c3
+  CORE_REPO_SHA: 7b11971c504387341df0c38f5a34d7d1293c7e4f
 
 jobs:
   build:
diff --git a/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py b/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py
index f657a1f3bd..2939186029 100644
--- a/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py
+++ b/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py
@@ -55,6 +55,9 @@ def hello():
 from opentelemetry import context, trace
 from opentelemetry.instrumentation.flask.version import __version__
 from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
+from opentelemetry.instrumentation.propagators import (
+    get_global_back_propagator,
+)
 from opentelemetry.propagate import extract
 from opentelemetry.util._time import _time_ns
 from opentelemetry.util.http import get_excluded_urls
@@ -91,6 +94,13 @@ def _start_response(status, response_headers, *args, **kwargs):
             if not _excluded_urls.url_disabled(flask.request.url):
                 span = flask.request.environ.get(_ENVIRON_SPAN_KEY)
 
+                propagator = get_global_back_propagator()
+                if propagator:
+                    propagator.inject(
+                        response_headers,
+                        setter=otel_wsgi.default_back_propagation_setter,
+                    )
+
                 if span:
                     otel_wsgi.add_response_attributes(
                         span, status, response_headers
diff --git a/instrumentation/opentelemetry-instrumentation-flask/tests/test_programmatic.py b/instrumentation/opentelemetry-instrumentation-flask/tests/test_programmatic.py
index 3c62dd751d..cdf9d5f1ce 100644
--- a/instrumentation/opentelemetry-instrumentation-flask/tests/test_programmatic.py
+++ b/instrumentation/opentelemetry-instrumentation-flask/tests/test_programmatic.py
@@ -18,6 +18,11 @@
 
 from opentelemetry import trace
 from opentelemetry.instrumentation.flask import FlaskInstrumentor
+from opentelemetry.instrumentation.propagators import (
+    TraceResponsePropagator,
+    get_global_back_propagator,
+    set_global_back_propagator,
+)
 from opentelemetry.test.test_base import TestBase
 from opentelemetry.test.wsgitestutil import WsgiTestBase
 from opentelemetry.util.http import get_excluded_urls
@@ -119,6 +124,31 @@ def test_simple(self):
         self.assertEqual(span_list[0].kind, trace.SpanKind.SERVER)
         self.assertEqual(span_list[0].attributes, expected_attrs)
 
+    def test_trace_response(self):
+        orig = get_global_back_propagator()
+
+        set_global_back_propagator(TraceResponsePropagator())
+        response = self.client.get("/hello/123")
+        headers = response.headers
+
+        span_list = self.memory_exporter.get_finished_spans()
+        self.assertEqual(len(span_list), 1)
+        span = span_list[0]
+
+        self.assertIn("traceresponse", headers)
+        self.assertEqual(
+            headers["access-control-expose-headers"], "traceresponse",
+        )
+        self.assertEqual(
+            headers["traceresponse"],
+            "00-{0}-{1}-01".format(
+                trace.format_trace_id(span.get_span_context().trace_id),
+                trace.format_span_id(span.get_span_context().span_id),
+            ),
+        )
+
+        set_global_back_propagator(orig)
+
     def test_not_recording(self):
         mock_tracer = Mock()
         mock_span = Mock()
diff --git a/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py b/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py
index c57aca11b5..0d12598d38 100644
--- a/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py
+++ b/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py
@@ -253,3 +253,11 @@ def _end_span_after_iterating(iterable, span, tracer, token):
             close()
         span.end()
         context.detach(token)
+
+
+class BackPropagationSetter:
+    def set(self, carrier, key, value):  # pylint: disable=no-self-use
+        carrier.append((key, value))
+
+
+default_back_propagation_setter = BackPropagationSetter()