Skip to content

Commit

Permalink
Prompt content and completion as span events
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
  • Loading branch information
ThomasVitale authored and markpollack committed Aug 20, 2024
1 parent 80fe5e4 commit 3fa102e
Show file tree
Hide file tree
Showing 12 changed files with 612 additions and 53 deletions.
6 changes: 6 additions & 0 deletions spring-ai-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
<artifactId>micrometer-core</artifactId>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-tracing-bridge-otel</artifactId>
<optional>true</optional>
</dependency>

<dependency>
<groupId>com.knuddels</groupId>
<artifactId>jtokkit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationFilter;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.StringJoiner;

/**
* An {@link ObservationFilter} to include the chat completion content in the observation.
Expand All @@ -36,25 +32,11 @@ public Observation.Context map(Observation.Context context) {
return context;
}

if (chatModelObservationContext.getResponse() == null
|| chatModelObservationContext.getResponse().getResults() == null
|| CollectionUtils.isEmpty(chatModelObservationContext.getResponse().getResults())) {
return chatModelObservationContext;
}

StringJoiner completionChoicesJoiner = new StringJoiner(", ", "[", "]");
chatModelObservationContext.getResponse()
.getResults()
.stream()
.filter(generation -> generation.getOutput() != null
&& StringUtils.hasText(generation.getOutput().getContent()))
.forEach(generation -> completionChoicesJoiner.add("\"" + generation.getOutput().getContent() + "\""));
var completions = ChatModelObservationContentProcessor.completion(chatModelObservationContext);

if (StringUtils.hasText(chatModelObservationContext.getResponse().getResult().getOutput().getContent())) {
chatModelObservationContext
.addHighCardinalityKeyValue(ChatModelObservationDocumentation.HighCardinalityKeyNames.COMPLETION
.withValue(completionChoicesJoiner.toString()));
}
chatModelObservationContext
.addHighCardinalityKeyValue(ChatModelObservationDocumentation.HighCardinalityKeyNames.COMPLETION
.withValue(ChatModelObservationContentProcessor.concatenateStrings(completions)));

return chatModelObservationContext;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.observation;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.tracing.handler.TracingObservationHandler;
import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.trace.Span;
import org.springframework.ai.observation.conventions.AiObservationAttributes;
import org.springframework.ai.observation.conventions.AiObservationEventNames;

/**
* Handler for including the chat completion content in the observation as a span event.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public class ChatModelCompletionObservationHandler implements ObservationHandler<ChatModelObservationContext> {

@Override
public void onStop(ChatModelObservationContext context) {
TracingObservationHandler.TracingContext tracingContext = context
.get(TracingObservationHandler.TracingContext.class);
Span otelSpan = ChatModelObservationContentProcessor.extractOtelSpan(tracingContext);

if (otelSpan != null) {
otelSpan.addEvent(AiObservationEventNames.CONTENT_COMPLETION.value(),
Attributes.of(AttributeKey.stringArrayKey(AiObservationAttributes.COMPLETION.value()),
ChatModelObservationContentProcessor.completion(context)));
}
}

@Override
public boolean supportsContext(Observation.Context context) {
return context instanceof ChatModelObservationContext;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.observation;

import io.micrometer.tracing.handler.TracingObservationHandler;
import io.opentelemetry.api.trace.Span;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.Content;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.StringJoiner;

/**
* Utilities to process the prompt and completion content in observations for chat models.
*
* @author Thomas Vitale
*/
public final class ChatModelObservationContentProcessor {

private static final Logger logger = LoggerFactory.getLogger(ChatModelObservationContentProcessor.class);

public static List<String> prompt(ChatModelObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getInstructions())) {
return List.of();
}

return context.getRequest().getInstructions().stream().map(Content::getContent).toList();
}

public static List<String> completion(ChatModelObservationContext context) {
if (context == null || context.getResponse() == null || context.getResponse().getResults() == null
|| CollectionUtils.isEmpty(context.getResponse().getResults())) {
return List.of();
}

if (!StringUtils.hasText(context.getResponse().getResult().getOutput().getContent())) {
return List.of();
}

return context.getResponse()
.getResults()
.stream()
.filter(generation -> generation.getOutput() != null
&& StringUtils.hasText(generation.getOutput().getContent()))
.map(generation -> generation.getOutput().getContent())
.toList();
}

public static String concatenateStrings(List<String> strings) {
var promptMessagesJoiner = new StringJoiner(", ", "[", "]");
strings.forEach(string -> promptMessagesJoiner.add("\"" + string + "\""));
return promptMessagesJoiner.toString();
}

@Nullable
public static Span extractOtelSpan(@Nullable TracingObservationHandler.TracingContext tracingContext) {
if (tracingContext == null) {
return null;
}

io.micrometer.tracing.Span micrometerSpan = tracingContext.getSpan();
try {
Method toOtelMethod = tracingContext.getSpan()
.getClass()
.getDeclaredMethod("toOtel", io.micrometer.tracing.Span.class);
toOtelMethod.setAccessible(true);
Object otelSpanObject = toOtelMethod.invoke(null, micrometerSpan);
if (otelSpanObject instanceof Span otelSpan) {
return otelSpan;
}
}
catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) {
logger.warn("It wasn't possible to extract the OpenTelemetry Span object from Micrometer", ex);
return null;
}

return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationFilter;
import org.springframework.util.CollectionUtils;

import java.util.StringJoiner;

/**
* An {@link ObservationFilter} to include the chat prompt content in the observation.
Expand All @@ -35,18 +32,11 @@ public Observation.Context map(Observation.Context context) {
return context;
}

if (CollectionUtils.isEmpty(chatModelObservationContext.getRequest().getInstructions())) {
return chatModelObservationContext;
}

StringJoiner promptMessagesJoiner = new StringJoiner(", ", "[", "]");
chatModelObservationContext.getRequest()
.getInstructions()
.forEach(message -> promptMessagesJoiner.add("\"" + message.getContent() + "\""));
var prompts = ChatModelObservationContentProcessor.prompt(chatModelObservationContext);

chatModelObservationContext
.addHighCardinalityKeyValue(ChatModelObservationDocumentation.HighCardinalityKeyNames.PROMPT
.withValue(promptMessagesJoiner.toString()));
.withValue(ChatModelObservationContentProcessor.concatenateStrings(prompts)));

return chatModelObservationContext;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.observation;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.tracing.handler.TracingObservationHandler;
import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.trace.Span;
import org.springframework.ai.observation.conventions.AiObservationAttributes;
import org.springframework.ai.observation.conventions.AiObservationEventNames;

/**
* Handler for including the chat prompt content in the observation as a span event.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public class ChatModelPromptContentObservationHandler implements ObservationHandler<ChatModelObservationContext> {

@Override
public void onStop(ChatModelObservationContext context) {
TracingObservationHandler.TracingContext tracingContext = context
.get(TracingObservationHandler.TracingContext.class);
Span otelSpan = ChatModelObservationContentProcessor.extractOtelSpan(tracingContext);

if (otelSpan != null) {
otelSpan.addEvent(AiObservationEventNames.CONTENT_PROMPT.value(),
Attributes.of(AttributeKey.stringArrayKey(AiObservationAttributes.PROMPT.value()),
ChatModelObservationContentProcessor.prompt(context)));
}

}

@Override
public boolean supportsContext(Observation.Context context) {
return context instanceof ChatModelObservationContext;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.observation;

import io.micrometer.tracing.handler.TracingObservationHandler;
import io.micrometer.tracing.otel.bridge.OtelCurrentTraceContext;
import io.micrometer.tracing.otel.bridge.OtelTracer;
import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.sdk.trace.ReadableSpan;
import io.opentelemetry.sdk.trace.SdkTracerProvider;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.observation.conventions.AiObservationAttributes;
import org.springframework.ai.observation.conventions.AiObservationEventNames;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Unit tests for {@link ChatModelCompletionObservationHandler}.
*
* @author Thomas Vitale
*/
class ChatModelCompletionObservationHandlerTests {

@Test
void whenCompletionWithTextThenSpanEvent() {
var observationContext = ChatModelObservationContext.builder()
.prompt(new Prompt("supercalifragilisticexpialidocious"))
.provider("mary-poppins")
.requestOptions(ChatOptionsBuilder.builder().withModel("spoonful-of-sugar").build())
.build();
observationContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")),
new Generation(new AssistantMessage("seriously, say please")))));
var sdkTracer = SdkTracerProvider.builder().build().get("test");
var otelTracer = new OtelTracer(sdkTracer, new OtelCurrentTraceContext(), null);
var span = otelTracer.nextSpan();
var tracingContext = new TracingObservationHandler.TracingContext();
tracingContext.setSpan(span);
observationContext.put(TracingObservationHandler.TracingContext.class, tracingContext);

new ChatModelCompletionObservationHandler().onStop(observationContext);

var otelSpan = ChatModelObservationContentProcessor.extractOtelSpan(tracingContext);
assertThat(otelSpan).isNotNull();
var spanData = ((ReadableSpan) otelSpan).toSpanData();
assertThat(spanData.getEvents().size()).isEqualTo(1);
assertThat(spanData.getEvents().get(0).getName()).isEqualTo(AiObservationEventNames.CONTENT_COMPLETION.value());
assertThat(spanData.getEvents()
.get(0)
.getAttributes()
.get(AttributeKey.stringArrayKey(AiObservationAttributes.COMPLETION.value())))
.containsOnly("say please", "seriously, say please");
}

}
Loading

0 comments on commit 3fa102e

Please sign in to comment.