Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor header schemas #1008

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ public Map<String, SchemaObject> getSchemas() {

@Override
public ComponentSchema resolvePayloadSchema(Type type, String contentType) {
SwaggerSchemaService.Payload payload = schemaService.resolvePayloadSchema(type, contentType);
SwaggerSchemaService.ExtractedSchemas payload = schemaService.resolveSchema(type, contentType);
payload.referencedSchemas().forEach(this.schemas::putIfAbsent);
return payload.payloadSchema();
return payload.rootSchema();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ public SchemaObject extractHeader(Method method, PayloadSchemaObject payload) {
Header headerAnnotation = argument.getAnnotation(Header.class);
String headerName = getHeaderAnnotationName(headerAnnotation);

SchemaObject schema =
schemaService.extractSchema(argument.getType()).getRootSchema();
SchemaObject schema = schemaService
.extractSchema(argument.getType())
.rootSchema()
.getSchema();

headers.getProperties().put(headerName, schema);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
import io.swagger.v3.core.util.Json;
import io.swagger.v3.core.util.PrimitiveType;
import io.swagger.v3.core.util.RefUtils;
import io.swagger.v3.oas.models.media.BooleanSchema;
import io.swagger.v3.oas.models.media.NumberSchema;
import io.swagger.v3.oas.models.media.ObjectSchema;
import io.swagger.v3.oas.models.media.Schema;
import io.swagger.v3.oas.models.media.StringSchema;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

Expand All @@ -30,7 +27,6 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -55,13 +51,7 @@ public SwaggerSchemaService(
this.properties = properties;
}

public record ExtractedSchemas(String rootSchemaName, Map<String, SchemaObject> schemas) {
public SchemaObject getRootSchema() {
return schemas.get(rootSchemaName);
}
}

public record Payload(ComponentSchema payloadSchema, Map<String, SchemaObject> referencedSchemas) {}
public record ExtractedSchemas(ComponentSchema rootSchema, Map<String, SchemaObject> referencedSchemas) {}

public SchemaObject extractSchema(SchemaObject headers) {
String schemaName = headers.getTitle();
Expand All @@ -76,35 +66,19 @@ public SchemaObject extractSchema(SchemaObject headers) {
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
headerSchema.setProperties(properties);

postProcessSchema(headerSchema, new HashMap<>(Map.of(schemaName, headerSchema)), DEFAULT_CONTENT_TYPE);
postProcessSchema(
new HashMap<>(Map.of(schemaName, headerSchema)),
new HashMap<>(Map.of(schemaName, headerSchema)),
DEFAULT_CONTENT_TYPE);

return swaggerSchemaUtil.mapSchema(headerSchema);
}

public ExtractedSchemas extractSchema(Class<?> type) {
return this.extractSchema(type, "");
return this.resolveSchema(type, "");
}

public ExtractedSchemas extractSchema(Class<?> type, String contentType) {
String actualContentType =
StringUtils.isBlank(contentType) ? properties.getDocket().getDefaultContentType() : contentType;

Map<String, Schema> swaggerSchemas =
new LinkedHashMap<>(runWithFqnSetting((unused) -> converter.readAll(type)));

String schemaName = getSchemaName(type, swaggerSchemas);
preProcessSchemas(swaggerSchemas, schemaName, type);

Map<String, Schema> postProcessedSchemas = new HashMap<>(swaggerSchemas);
for (Schema schema : swaggerSchemas.values()) {
postProcessSchema(schema, postProcessedSchemas, actualContentType);
}

Map<String, SchemaObject> schemas = swaggerSchemaUtil.mapSchemasMap(postProcessedSchemas);
return new ExtractedSchemas(schemaName, schemas);
}

public Payload resolvePayloadSchema(Type type, String contentType) {
public ExtractedSchemas resolveSchema(Type type, String contentType) {
String actualContentType =
StringUtils.isBlank(contentType) ? properties.getDocket().getDefaultContentType() : contentType;
ResolvedSchema resolvedSchema = runWithFqnSetting(
Expand All @@ -114,89 +88,26 @@ public Payload resolvePayloadSchema(Type type, String contentType) {
// defaulting to stringSchema when resolvedSchema is null
SchemaObject payloadSchema = swaggerSchemaUtil.mapSchema(
PrimitiveType.fromType(String.class).createProperty());
return new Payload(ComponentSchema.of(payloadSchema), Map.of());
return new ExtractedSchemas(ComponentSchema.of(payloadSchema), Map.of());
} else {
Map<String, Schema> preProcessSchemas = new LinkedHashMap<>(resolvedSchema.referencedSchemas);
Schema payloadSchema = resolvedSchema.schema;
preProcessSchemas.putIfAbsent(getNameFromType(type), payloadSchema);
preProcessSchemas(payloadSchema, preProcessSchemas, type);
HashMap<String, Schema> postProcessSchemas = new HashMap<>(preProcessSchemas);
postProcessSchema(preProcessSchemas, postProcessSchemas, actualContentType);
return new Payload(
return new ExtractedSchemas(
swaggerSchemaUtil.mapSchemaOrRef(payloadSchema),
swaggerSchemaUtil.mapSchemasMap(postProcessSchemas));
}
}

private String getSchemaName(Class<?> type, Map<String, Schema> schemas) {
if (schemas.isEmpty()) {
// swagger-parser does not create schemas for primitives
if (type.equals(String.class) || type.equals(Character.class) || type.equals(Byte.class)) {
return registerPrimitive(String.class, new StringSchema(), schemas);
}
if (Boolean.class.isAssignableFrom(type)) {
return registerPrimitive(Boolean.class, new BooleanSchema(), schemas);
}
if (Number.class.isAssignableFrom(type)) {
return registerPrimitive(Number.class, new NumberSchema(), schemas);
}
if (Object.class.isAssignableFrom(type)) {
return registerPrimitive(Object.class, new ObjectSchema(), schemas);
}
}

if (schemas.size() == 1) {
return schemas.keySet().stream().findFirst().get();
}

Set<String> resolvedPayloadModelName =
runWithFqnSetting((unused) -> converter.read(type).keySet());
if (!resolvedPayloadModelName.isEmpty()) {
return resolvedPayloadModelName.stream().findFirst().get();
}

return getNameFromClass(type);
}

private String registerPrimitive(Class<?> type, Schema schema, Map<String, Schema> schemas) {
String schemaName = getNameFromClass(type);
schema.setName(schemaName);

schemas.put(schemaName, schema);
postProcessSchema(schema, schemas, DEFAULT_CONTENT_TYPE);

return schemaName;
}

private void preProcessSchemas(Map<String, Schema> schemas, String schemaName, Class<?> type) {
processCommonModelConverters(schemas);
processAsyncApiPayloadAnnotation(schemas, schemaName, type);
processSchemaAnnotation(schemas, schemaName, type);
}

private void preProcessSchemas(Schema payloadSchema, Map<String, Schema> schemas, Type type) {
processCommonModelConverters(payloadSchema, schemas);
processAsyncApiPayloadAnnotation(schemas, type);
processSchemaAnnotation(payloadSchema, type);
}

private void processCommonModelConverters(Map<String, Schema> schemas) {
schemas.values().stream()
.filter(schema -> schema.getType() == null)
.filter(schema -> schema.get$ref() != null)
.forEach(schema -> {
String targetSchemaName = schema.getName();
String sourceSchemaName = StringUtils.substringAfterLast(schema.get$ref(), "/");

Schema<?> actualSchema = schemas.get(sourceSchemaName);

if (actualSchema != null) {
schemas.put(targetSchemaName, actualSchema);
schemas.remove(sourceSchemaName);
}
});
}

private void processCommonModelConverters(Schema payloadSchema, Map<String, Schema> schemas) {
schemas.values().stream()
.filter(schema -> schema.getType() == null)
Expand Down Expand Up @@ -225,19 +136,6 @@ private void adaptPayloadSchema(Schema schema, String targetSchemaName, String s
}
}

private void processSchemaAnnotation(Map<String, Schema> schemas, String schemaName, Class<?> type) {
Schema schemaForType = schemas.get(schemaName);
if (schemaForType != null) {
var schemaAnnotation = type.getAnnotation(io.swagger.v3.oas.annotations.media.Schema.class);
if (schemaAnnotation != null) {
String description = schemaAnnotation.description();
if (StringUtils.isNotBlank(description)) {
schemaForType.setDescription(description);
}
}
}
}

private void processSchemaAnnotation(Schema payloadSchema, Type type) {
JavaType javaType = Json.mapper().constructType(type);
Class<?> clazz = javaType.getRawClass();
Expand All @@ -249,29 +147,6 @@ private void processSchemaAnnotation(Schema payloadSchema, Type type) {
}
}

private void processAsyncApiPayloadAnnotation(Map<String, Schema> schemas, String schemaName, Class<?> type) {
List<Field> withPayloadAnnotatedFields = Arrays.stream(type.getDeclaredFields())
.filter(field -> field.isAnnotationPresent(AsyncApiPayload.class))
.toList();

if (withPayloadAnnotatedFields.size() == 1) {
Schema envelopSchema = schemas.get(schemaName);
if (envelopSchema != null && envelopSchema.getProperties() != null) {
String fieldName = withPayloadAnnotatedFields.get(0).getName();
Schema actualSchema = (Schema) envelopSchema.getProperties().get(fieldName);
if (actualSchema != null) {
schemas.put(schemaName, actualSchema);
}
}

} else if (withPayloadAnnotatedFields.size() > 1) {
log.warn(
("Found more than one field with @AsyncApiPayload annotation in class {}. "
+ "Falling back and ignoring annotation."),
type.getName());
}
}

private void processAsyncApiPayloadAnnotation(Map<String, Schema> schemas, Type type) {
JavaType javaType = Json.mapper().constructType(type);
Class<?> clazz = javaType.getRawClass();
Expand Down Expand Up @@ -306,13 +181,6 @@ private <R> R runWithFqnSetting(Function<Void, R> callable) {
return result;
}

private String getNameFromClass(Class<?> type) {
if (properties.isUseFqn()) {
return type.getName();
}
return type.getSimpleName();
}

public String getNameFromType(Type type) {
PrimitiveType primitiveType = PrimitiveType.fromType(type);
if (primitiveType != null && properties.isUseFqn()) {
Expand All @@ -330,18 +198,6 @@ public String getSimpleNameFromType(Type type) {
return name;
}

private void postProcessSchema(Schema schema, Map<String, Schema> schemas, String contentType) {
boolean schemasHadEntries = !schemas.isEmpty();
for (SchemasPostProcessor processor : schemaPostProcessors) {
processor.process(schema, schemas, contentType);

if (schemasHadEntries && !schemas.containsValue(schema)) {
// If the post-processor removed the schema, we can stop processing
break;
}
}
}

private void postProcessSchema(
Map<String, Schema> preProcess, Map<String, Schema> postProcess, String contentType) {
boolean schemasHadEntries = !postProcess.isEmpty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ class HeaderClassExtractorTest {
"payloadSchemaName", String.class.getSimpleName(), ComponentSchema.of(new SchemaObject()));
private final SchemaObject stringSchema =
SchemaObject.builder().type(SchemaType.STRING).build();
private final SchemaObject stringSwaggerSchema =
SchemaObject.builder().type(SchemaType.STRING).build();
private final ComponentSchema stringSwaggerSchema =
ComponentSchema.of(SchemaObject.builder().type(SchemaType.STRING).build());

@Test
void getNoDocumentedHeaders() throws NoSuchMethodException {
// given
when(schemaService.extractSchema(String.class))
.thenReturn(new SwaggerSchemaService.ExtractedSchemas("String", Map.of("String", stringSwaggerSchema)));
.thenReturn(new SwaggerSchemaService.ExtractedSchemas(stringSwaggerSchema, Map.of()));

// when
Method m = TestClass.class.getDeclaredMethod("consumeWithoutHeadersAnnotation", String.class);
Expand All @@ -48,7 +48,7 @@ void getNoDocumentedHeaders() throws NoSuchMethodException {
void getHeaderWithSingleHeaderAnnotation() throws NoSuchMethodException {
// given
when(schemaService.extractSchema(String.class))
.thenReturn(new SwaggerSchemaService.ExtractedSchemas("String", Map.of("String", stringSwaggerSchema)));
.thenReturn(new SwaggerSchemaService.ExtractedSchemas(stringSwaggerSchema, Map.of()));

// when
Method m = TestClass.class.getDeclaredMethod("consumeWithSingleHeaderAnnotation", String.class);
Expand All @@ -69,7 +69,7 @@ void getHeaderWithSingleHeaderAnnotation() throws NoSuchMethodException {
void getHeaderWithMultipleHeaderAnnotation() throws NoSuchMethodException {
// given
when(schemaService.extractSchema(String.class))
.thenReturn(new SwaggerSchemaService.ExtractedSchemas("String", Map.of("String", stringSwaggerSchema)));
.thenReturn(new SwaggerSchemaService.ExtractedSchemas(stringSwaggerSchema, Map.of()));

// when
Method m = TestClass.class.getDeclaredMethod("consumeWithMultipleHeaderAnnotation", String.class, String.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.fasterxml.jackson.core.util.DefaultPrettyPrinter;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import io.github.springwolf.asyncapi.v3.model.components.ComponentSchema;
import io.github.springwolf.asyncapi.v3.model.schema.SchemaObject;
import io.github.springwolf.core.asyncapi.components.postprocessors.SchemasPostProcessor;
import io.github.springwolf.core.configuration.properties.SpringwolfConfigProperties;
Expand Down Expand Up @@ -57,11 +58,11 @@ void setUp() {

@Test
void classWithSchemaAnnotation() {
String modelName = schemaService
.extractSchema(ClassWithSchemaAnnotation.class, "content-type-not-relevant")
.rootSchemaName();
ComponentSchema schema = schemaService
.resolveSchema(ClassWithSchemaAnnotation.class, "content-type-not-relevant")
.rootSchema();

assertThat(modelName).isEqualTo("DifferentName");
assertThat(schema.getReference().getRef()).isEqualTo("#/components/schemas/DifferentName");
}

@Test
Expand All @@ -77,8 +78,8 @@ void getDefinitionWithoutFqnClassName() throws IOException {
Class<?> clazz =
OneFieldFooWithoutFqn.class; // swagger seems to cache results. Therefore, a new class must be used.
Map<String, SchemaObject> schemas = schemaServiceWithFqn
.extractSchema(clazz, "content-type-not-relevant")
.schemas();
.resolveSchema(clazz, "content-type-not-relevant")
.referencedSchemas();
String actualDefinitions = objectMapper.writer(printer).writeValueAsString(schemas);

// then
Expand All @@ -89,7 +90,7 @@ void getDefinitionWithoutFqnClassName() throws IOException {

@Test
void postProcessorsAreCalled() {
schemaService.extractSchema(ClassWithSchemaAnnotation.class, "some-content-type");
schemaService.resolveSchema(ClassWithSchemaAnnotation.class, "some-content-type");

verify(schemasPostProcessor).process(any(), any(), eq("some-content-type"));
verify(schemasPostProcessor2).process(any(), any(), eq("some-content-type"));
Expand All @@ -105,7 +106,7 @@ void postProcessorIsSkippedWhenSchemaWasRemoved() {
.when(schemasPostProcessor)
.process(any(), any(), any());

schemaService.extractSchema(ClassWithSchemaAnnotation.class, "content-type-not-relevant");
schemaService.resolveSchema(ClassWithSchemaAnnotation.class, "content-type-not-relevant");

verifyNoInteractions(schemasPostProcessor2);
}
Expand Down
Loading
Loading