Skip to content

Commit

Permalink
Merge pull request #46 from CJCrafter/moderations
Browse files Browse the repository at this point in the history
Moderations
  • Loading branch information
CJCrafter authored Mar 21, 2024
2 parents 826e082 + acd6faf commit 11ed06e
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 21 deletions.
40 changes: 40 additions & 0 deletions examples/src/main/java/moderations/ModerationsExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package moderations;

import com.cjcrafter.openai.OpenAI;
import com.cjcrafter.openai.moderations.CreateModerationRequest;
import com.cjcrafter.openai.moderations.Moderation;
import io.github.cdimascio.dotenv.Dotenv;

import java.util.Comparator;
import java.util.Scanner;

public class ModerationsExample {

// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
// dependency. Then you can add a .env file in your project directory.
public static final OpenAI openai = OpenAI.builder()
.apiKey(Dotenv.load().get("OPENAI_TOKEN"))
.build();

public static final Scanner scan = new Scanner(System.in);

public static void main(String[] args) {
while (true) {
System.out.print("Input: ");
String input = scan.nextLine();
CreateModerationRequest request = CreateModerationRequest.builder()
.input(input)
.build();

Moderation moderation = openai.moderations().create(request);
Moderation.Result result = moderation.getResults().get(0);

// Finds the category with the highest score
String highest = result.getCategoryScores().keySet().stream()
.max(Comparator.comparing(a -> result.getCategoryScores().get(a)))
.orElseThrow(() -> new RuntimeException("No categories found!"));

System.out.println("Highest category: " + highest + ", with a score of " + result.getCategoryScores().get(highest));
}
}
}
25 changes: 25 additions & 0 deletions examples/src/main/kotlin/moderations/ModerationsExample.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package moderations

import com.cjcrafter.openai.moderations.create
import com.cjcrafter.openai.openAI
import io.github.cdimascio.dotenv.dotenv


fun main() {

// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
// dependency. Then you can add a .env file in your project directory.
val key = dotenv()["OPENAI_TOKEN"]
val openai = openAI { apiKey(key) }

while (true) {
print("Input: ")
val input = readln()
val moderation = openai.moderations.create {
input(input)
}

val max = moderation.results[0].categoryScores.entries.maxBy { it.value }
println("Highest category: ${max.key} with a score of ${max.value}")
}
}
14 changes: 14 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
import com.cjcrafter.openai.files.*
import com.cjcrafter.openai.moderations.ModerationHandler
import com.cjcrafter.openai.threads.ThreadHandler
import com.cjcrafter.openai.threads.message.TextAnnotation
import com.cjcrafter.openai.util.OpenAIDslMarker
Expand Down Expand Up @@ -135,6 +136,19 @@ interface OpenAI {
@Contract(pure = true)
fun files(): FileHandler = files

/**
* Returns the handler for the moderations endpoint. This handler can be used
* to create moderations.
*/
val moderations: ModerationHandler

/**
* Returns the handler for the moderations endpoint. This method is purely
* syntactic sugar for Java users.
*/
@Contract(pure = true)
fun moderations(): ModerationHandler = moderations

/**
* Returns the handler for the assistants endpoint. This handler can be used
* to create, retrieve, and delete assistants.
Expand Down
25 changes: 16 additions & 9 deletions src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
import com.cjcrafter.openai.files.*
import com.cjcrafter.openai.moderations.ModerationHandler
import com.cjcrafter.openai.moderations.ModerationHandlerImpl
import com.cjcrafter.openai.threads.ThreadHandler
import com.cjcrafter.openai.threads.ThreadHandlerImpl
import com.fasterxml.jackson.databind.JavaType
Expand Down Expand Up @@ -127,23 +129,28 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
return requestHelper.executeRequest(httpRequest, EmbeddingsResponse::class.java)
}

private var files0: FileHandlerImpl? = null
override val files: FileHandler
get() = files0 ?: FileHandlerImpl(requestHelper, FILES_ENDPOINT).also { files0 = it }
override val files: FileHandler by lazy {
FileHandlerImpl(requestHelper, FILES_ENDPOINT)
}

private var assistants0: AssistantHandlerImpl? = null
override val assistants: AssistantHandler
get() = assistants0 ?: AssistantHandlerImpl(requestHelper, ASSISTANTS_ENDPOINT).also { assistants0 = it }
override val moderations: ModerationHandler by lazy {
ModerationHandlerImpl(requestHelper, MODERATIONS_ENDPOINT)
}

private var threads0: ThreadHandlerImpl? = null
override val threads: ThreadHandler
get() = threads0 ?: ThreadHandlerImpl(requestHelper, THREADS_ENDPOINT).also { threads0 = it }
override val assistants: AssistantHandler by lazy {
AssistantHandlerImpl(requestHelper, ASSISTANTS_ENDPOINT)
}

override val threads: ThreadHandler by lazy {
ThreadHandlerImpl(requestHelper, THREADS_ENDPOINT)
}

companion object {
const val COMPLETIONS_ENDPOINT = "v1/completions"
const val CHAT_ENDPOINT = "v1/chat/completions"
const val EMBEDDINGS_ENDPOINT = "v1/embeddings"
const val FILES_ENDPOINT = "v1/files"
const val MODERATIONS_ENDPOINT = "v1/moderations"
const val ASSISTANTS_ENDPOINT = "v1/assistants"
const val THREADS_ENDPOINT = "v1/threads"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.cjcrafter.openai.moderations

import com.cjcrafter.openai.util.OpenAIDslMarker

/**
* Represents a request to create a new moderation request.
*
* @property input The input to moderate
* @property model The model to use for moderation
*/
data class CreateModerationRequest internal constructor(
var input: Any,
var model: String? = null
) {

@OpenAIDslMarker
class Builder internal constructor() {
private var input: Any? = null
private var model: String? = null

/**
* Sets the input to moderate.
*
* @param input The input to moderate
*/
fun input(input: String) = apply { this.input = input }

/**
* Sets the input to moderate.
*
* @param input The input to moderate
*/
fun input(input: List<String>) = apply { this.input = input }

/**
* Sets the model to use for moderation.
*
* @param model The model to use for moderation
*/
fun model(model: String) = apply { this.model = model }

/**
* Builds the [CreateModerationRequest] instance.
*/
fun build(): CreateModerationRequest {
return CreateModerationRequest(
input = input ?: throw IllegalStateException("input must be defined to use CreateModerationRequest"),
model = model
)
}
}

companion object {
/**
* Returns a builder to construct a [CreateModerationRequest] instance.
*/
@JvmStatic
fun builder() = Builder()
}
}
29 changes: 29 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/moderations/Moderation.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.cjcrafter.openai.moderations

import com.fasterxml.jackson.annotation.JsonProperty

/**
* A moderation object returned by the moderations api.
*
* @property id The id of the moderation request. Always starts with "modr-".
* @property model The model which was used to moderate the content.
* @property results The results of the moderation request.
*/
data class Moderation(
@JsonProperty(required = true) val id: String,
@JsonProperty(required = true) val model: String,
@JsonProperty(required = true) val results: List<Result>,
) {
/**
* The results of the moderation request.
*
* @property flagged If any categories were flagged.
* @property categories The categories that were flagged.
* @property categoryScores The scores of each category.
*/
data class Result(
@JsonProperty(required = true) val flagged: Boolean,
@JsonProperty(required = true) val categories: Map<String, Boolean>,
@JsonProperty("category_scores", required = true) val categoryScores: Map<String, Double>,
)
}
10 changes: 10 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/moderations/ModerationDsl.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.cjcrafter.openai.moderations

fun createModerationRequest(block: CreateModerationRequest.Builder.() -> Unit): CreateModerationRequest {
return CreateModerationRequest.builder().apply(block).build()
}

fun ModerationHandler.create(block: CreateModerationRequest.Builder.() -> Unit): Moderation {
val request = createModerationRequest(block)
return create(request)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.cjcrafter.openai.moderations

/**
* Handler used to interact with [Moderation] objects.
*/
interface ModerationHandler {

/**
* Creates a new moderation request with the given options.
*
* @param request The values of the moderation to create
* @return The created moderation
*/
fun create(request: CreateModerationRequest): Moderation
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.cjcrafter.openai.moderations

import com.cjcrafter.openai.RequestHelper

class ModerationHandlerImpl(
private val requestHelper: RequestHelper,
private val endpoint: String,
): ModerationHandler {
override fun create(request: CreateModerationRequest): Moderation {
val httpRequest = requestHelper.buildRequest(request, endpoint).build()
return requestHelper.executeRequest(httpRequest, Moderation::class.java)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package com.cjcrafter.openai.chat

import com.cjcrafter.openai.MockedTest
import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage
import com.cjcrafter.openai.chat.tool.FunctionToolCall
import com.cjcrafter.openai.chat.tool.Tool
import com.cjcrafter.openai.chat.tool.ToolCall
import okhttp3.mockwebserver.MockResponse
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -46,9 +49,9 @@ class MockedChatStreamTest : MockedTest() {

// Assertions
assertEquals(ChatUser.ASSISTANT, toolMessage.role, "Tool call should be from the assistant")
assertEquals(ToolType.FUNCTION, toolMessage.toolCalls?.get(0)?.type, "Tool call should be a function")
assertEquals("solve_math_problem", toolMessage.toolCalls?.get(0)?.function?.name)
assertEquals("3/2", toolMessage.toolCalls?.get(0)?.function?.tryParseArguments()?.get("equation")?.asText())
assertEquals(Tool.Type.FUNCTION, toolMessage.toolCalls?.get(0)?.type, "Tool call should be a function")
assertEquals("solve_math_problem", (toolMessage.toolCalls?.get(0) as? FunctionToolCall)?.function?.name)
assertEquals("3/2", (toolMessage.toolCalls?.get(0) as? FunctionToolCall)?.function?.tryParseArguments()?.get("equation")?.asText())

assertEquals(ChatUser.ASSISTANT, message.role, "Message should be from the assistant")
assertEquals("The result of 3 divided by 2 is 1.5.", message.content)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class FunctionCallTest {
name("enum_checker")
description("This function is used to test the enum parameter")
addEnumParameter("enum", mutableListOf("a", "b", "c"))
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"d\\\"}\"}" // d is not a valid enum
Expand All @@ -37,7 +37,7 @@ class FunctionCallTest {
name("enum_checker")
description("This function is used to test the enum parameter")
addEnumParameter("enum", mutableListOf("a", "b", "c"))
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"a\\\"}\"}" // a is a valid enum
Expand All @@ -55,7 +55,7 @@ class FunctionCallTest {
name("integer_checker")
description("This function is used to test the integer parameter")
addIntegerParameter("integer", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": \\\"not an integer\\\"}\"}" // not an integer
Expand All @@ -73,7 +73,7 @@ class FunctionCallTest {
name("integer_checker")
description("This function is used to test the integer parameter")
addIntegerParameter("integer", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": 1}\"}" // 1 is an integer
Expand All @@ -91,7 +91,7 @@ class FunctionCallTest {
name("boolean_checker")
description("This function is used to test the boolean parameter")
addBooleanParameter("is_true", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"boolean\\\": \\\"not a boolean\\\"}\"}" // not a boolean
Expand All @@ -109,7 +109,7 @@ class FunctionCallTest {
name("boolean_checker")
description("This function is used to test the boolean parameter")
addBooleanParameter("is_true", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"is_true\\\": true}\"}" // true is a boolean
Expand All @@ -128,7 +128,7 @@ class FunctionCallTest {
description("This function is used to test the required parameter")
addIntegerParameter("required", "test parameter", required = true)
addBooleanParameter("not_required", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"not_required\\\": true}\"}" // missing required parameter
Expand All @@ -147,7 +147,7 @@ class FunctionCallTest {
description("This function is used to test the required parameter")
addIntegerParameter("required", "test parameter", required = true)
addBooleanParameter("not_required", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"required\\\": 1, \\\"not_required\\\": true}\"}" // has required parameter
Expand All @@ -165,7 +165,7 @@ class FunctionCallTest {
name("function_name_checker")
description("This function is used to test the function name")
noParameters()
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"invalid_function_name\", \"arguments\": \"{}\"}" // invalid function name
Expand Down

0 comments on commit 11ed06e

Please sign in to comment.