Skip to content

Commit

Permalink
Java: Align Example59_OpenAIFunctionCalling with .NET (#5914)
Browse files Browse the repository at this point in the history
### Motivation and Context
Make Example59_OpenAIFunctionCalling mirror .NET as much as possible.
Want parity with .NET and Python for MSFT Build in May, 2024.
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description
Changed the example code to match .NET. 
Fixed some issues in OpenAIChatCompletion to make it work with manual
function calling.

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
dsgrieve authored Apr 19, 2024
1 parent 1b2ac89 commit 6f0561e
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.core.util.BinaryData;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
Expand Down Expand Up @@ -432,12 +433,11 @@ private static void configureToolCallBehaviorOptions(
@Nullable List<OpenAIFunction> functions,
int autoInvokeAttempts) {

if (functions == null || functions.isEmpty()) {
if (toolCallBehavior == null) {
return;
}

if (toolCallBehavior == null || autoInvokeAttempts == 0) {
// if auto-invoked is not enabled, then we don't need to send any tool definitions
if (functions == null || functions.isEmpty()) {
return;
}

Expand Down Expand Up @@ -496,17 +496,54 @@ private static List<ChatRequestMessage> getChatRequestMessages(ChatHistory chatH
return new ArrayList<>();
}
return messages.stream()
.map(message -> {
AuthorRole authorRole = message.getAuthorRole();
String content = message.getContent();
return getChatRequestMessage(authorRole, content);
})
.map(message -> getChatRequestMessage(message))
.collect(Collectors.toList());
}

private static ChatRequestMessage getChatRequestMessage(
ChatMessageContent<?> message) {

AuthorRole authorRole = message.getAuthorRole();
String content = message.getContent();
switch (authorRole) {
case ASSISTANT:
// TODO: break this out into a separate method and handle tools other than function calls
ChatRequestAssistantMessage asstMessage = new ChatRequestAssistantMessage(content);
List<OpenAIFunctionToolCall> toolCalls = ((OpenAIChatMessageContent<?>) message).getToolCall();
if (toolCalls != null) {
asstMessage.setToolCalls(
toolCalls.stream()
.map(toolCall -> {
KernelFunctionArguments arguments = toolCall.getArguments();
String args = arguments != null && !arguments.isEmpty()
? arguments.entrySet().stream()
.map(entry -> String.format("\"%s\": \"%s\"", entry.getKey(), entry.getValue()))
.collect(Collectors.joining("{","}",","))
: "{}";
FunctionCall fnCall = new FunctionCall(toolCall.getFunctionName(), args);
return new ChatCompletionsFunctionToolCall(toolCall.getId(), fnCall);
})
.collect(Collectors.toList())
);
return asstMessage;
}
case SYSTEM:
return new ChatRequestSystemMessage(content);
case USER:
return new ChatRequestUserMessage(content);
case TOOL:
String id = message.getMetadata().getId();
return new ChatRequestToolMessage(content, id);
default:
LOGGER.debug("Unexpected author role: " + authorRole);
throw new SKException("Unexpected author role: " + authorRole);
}

}

static ChatRequestMessage getChatRequestMessage(
AuthorRole authorRole,
@Nullable String content) {
AuthorRole authorRole,
String content) {

switch (authorRole) {
case ASSISTANT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,28 @@
import com.azure.core.credential.KeyCredential;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion;
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatMessageContent;
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIFunctionToolCall;
import com.microsoft.semantickernel.contextvariables.CaseInsensitiveMap;
import com.microsoft.semantickernel.contextvariables.ContextVariable;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypes;
import com.microsoft.semantickernel.orchestration.FunctionResult;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
import com.microsoft.semantickernel.plugin.KernelPlugin;
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionFromPrompt;
import com.microsoft.semantickernel.semanticfunctions.annotations.DefineKernelFunction;
import com.microsoft.semantickernel.semanticfunctions.annotations.KernelFunctionParameter;
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import com.microsoft.semantickernel.services.chatcompletion.ChatHistory;
import java.nio.charset.StandardCharsets;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.stream.Collectors;

public class Example59_OpenAIFunctionCalling {

Expand All @@ -26,7 +38,33 @@ public class Example59_OpenAIFunctionCalling {
// Only required if AZURE_CLIENT_KEY is set
private static final String CLIENT_ENDPOINT = System.getenv("CLIENT_ENDPOINT");
private static final String MODEL_ID = System.getenv()
.getOrDefault("MODEL_ID", "gpt-35-turbo-2");
.getOrDefault("MODEL_ID", "gpt-3.5-turbo-1106");


// Define functions that can be called by the model
public static class HelperFunctions {

@DefineKernelFunction(name = "currentUtcTime", description = "Retrieves the current time in UTC.")
public String currentUtcTime() {
return ZonedDateTime.now().format(DateTimeFormatter.RFC_1123_DATE_TIME);
}

@DefineKernelFunction(name = "getWeatherForCity", description = "Gets the current weather for the specified city")
public String getWeatherForCity(
@KernelFunctionParameter(name = "cityName", description = "Name of the city") String cityName) {
switch (cityName) {
case "Thrapston": return "80 and sunny";
case "Boston": return "61 and rainy";
case "London": return "55 and cloudy";
case "Miami": return "80 and sunny";
case "Paris": return "60 and rainy";
case "Tokyo": return "50 and sunny";
case "Sydney": return "75 and sunny";
case "Tel Aviv": return "80 and sunny";
default: return "31 and snowing";
}
}
}

public static void main(String[] args) throws NoSuchMethodException {
System.out.println("======== Open AI - Function calling ========");
Expand All @@ -50,16 +88,19 @@ public static void main(String[] args) throws NoSuchMethodException {
.withOpenAIAsyncClient(client)
.build();

KernelPlugin plugin = KernelPluginFactory.createFromObject(new Plugin(), "plugin");
var plugin = KernelPluginFactory.createFromObject(new HelperFunctions(), "HelperFunctions");

var kernel = Kernel.builder()
.withAIService(ChatCompletionService.class, chat)
.withPlugin(plugin)
.build();


System.out.println("======== Example 1: Use automated function calling ========");

var function = KernelFunctionFromPrompt.builder()
.withTemplate(
"What is the probable current color of the sky in Thrapston?")
"Given the current time of day and weather, what is the likely color of the sky in Boston?")
.withDefaultExecutionSettings(
PromptExecutionSettings.builder()
.withTemperature(0.4)
Expand All @@ -68,88 +109,70 @@ public static void main(String[] args) throws NoSuchMethodException {
.build())
.build();

// Example 1: All kernel functions are enabled to be called by the model
kernelFunctions(kernel, function);
// Example 2: A set of functions available to be called by the model
enableFunctions(kernel, plugin, function);
// Example 3: A specific function to be called by the model
requireFunction(kernel, plugin, function);
}

public static void kernelFunctions(Kernel kernel, KernelFunction<?> function) {
System.out.println("======== Kernel functions ========");

var toolCallBehavior = ToolCallBehavior.allowAllKernelFunctions(true);

var result = kernel
.invokeAsync(function)
.withToolCallBehavior(toolCallBehavior)
.withResultType(ContextVariableTypes.getGlobalVariableTypeForClass(String.class))
.block();

System.out.println(result.getResult());
}

public static void enableFunctions(Kernel kernel, KernelPlugin plugin,
KernelFunction<?> function) {
System.out.println("======== Enable functions ========");

// Based on coordinates
var toolCallBehavior = ToolCallBehavior.allowOnlyKernelFunctions(true,
plugin.get("getLatitudeOfCity"),
plugin.get("getLongitudeOfCity"),
plugin.get("getsTheWeatherAtAGivenLocation"));

var result = kernel
.invokeAsync(function)
.withToolCallBehavior(toolCallBehavior)
.withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true))
.withResultType(ContextVariableTypes.getGlobalVariableTypeForClass(String.class))
.block();

System.out.println(result.getResult());
}

public static void requireFunction(Kernel kernel, KernelPlugin plugin,
KernelFunction<?> function) {
System.out.println("======== Require a function ========");

var toolCallBehavior = ToolCallBehavior
.requireKernelFunction(plugin.get("getsTheWeatherForCity"));

var result = kernel
.invokeAsync(function)
.withToolCallBehavior(toolCallBehavior)
.withResultType(ContextVariableTypes.getGlobalVariableTypeForClass(String.class))
.block();

System.out.println(result.getResult());
}

public static class Plugin {

@DefineKernelFunction(name = "getLatitudeOfCity", description = "Gets the latitude of a given city")
public String getLatitudeOfCity(
@KernelFunctionParameter(name = "cityName", description = "City name") String cityName) {
return "1.0";
}

@DefineKernelFunction(name = "getLongitudeOfCity", description = "Gets the longitude of a given city")
public String getLongitudeOfCity(
@KernelFunctionParameter(name = "cityName", description = "City name") String cityName) {
return "2.0";
}

@DefineKernelFunction(name = "getsTheWeatherAtAGivenLocation", description = "Gets the current weather at a given longitude and latitude")
public String getWeatherForCityAtTime(
@KernelFunctionParameter(name = "latitude", description = "latitude of the location") String latitude,
@KernelFunctionParameter(name = "longitude", description = "longitude of the location") String longitude) {
return "61 and rainy";
}

@DefineKernelFunction(name = "getsTheWeatherForCity", description = "Gets the current weather at a city name")
public String getsTheWeatherForCity(
@KernelFunctionParameter(name = "cityName", description = "Name of the city") String cityName) {
return "80 and sunny";

System.out.println("======== Example 2: Use manual function calling ========");

var chatHistory = new ChatHistory();
chatHistory.addUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?");

while(true) {
var messages = chat.getChatMessageContentsAsync(
chatHistory,
kernel,
InvocationContext.builder().withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(false)).build()
)
.block();

messages.stream()
.filter(it -> it.getContent() != null)
.forEach(it -> System.out.println(it.getContent()));

List<OpenAIFunctionToolCall> toolCalls =
messages.stream()
.filter(it -> it instanceof OpenAIChatMessageContent)
.map(it -> (OpenAIChatMessageContent<?>)it)
.map(OpenAIChatMessageContent::getToolCall)
.flatMap(List::stream)
.collect(Collectors.toList());

if (toolCalls.isEmpty()) {
break;
}

messages.stream()
.forEach(it -> chatHistory.addMessage(it));

for(var toolCall : toolCalls) {

String content = null;
try {
// getFunction will throw an exception if the function is not found
var fn = kernel.getFunction(toolCall.getPluginName(), toolCall.getFunctionName());
FunctionResult<?> fnResult = fn.invokeAsync(kernel, toolCall.getArguments(), null, null).block();
content = (String)fnResult.getResult();
} catch (IllegalArgumentException e) {
content = "Unable to find function. Please try again!";
}

chatHistory.addMessage(
AuthorRole.TOOL,
content,
StandardCharsets.UTF_8,
new FunctionResultMetadata(new CaseInsensitiveMap<>() {
{
put(FunctionResultMetadata.ID, ContextVariable.of(toolCall.getId()));
}
})
);
}
}
}

}

0 comments on commit 6f0561e

Please sign in to comment.