Skip to content

Commit

Permalink
Added generator decorator and tests for it
Browse files Browse the repository at this point in the history
- generator decorator supports k8s schema name, field name, field type,
  paths, and priority configuration
  • Loading branch information
MarkintoshZ committed Jan 13, 2024
1 parent 4a9281c commit 1d3c829
Show file tree
Hide file tree
Showing 2 changed files with 316 additions and 0 deletions.
163 changes: 163 additions & 0 deletions acto/input/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""This module provides a decorator for generating test cases for a schema and
a function to get all test cases for a schema."""

from collections import namedtuple
from typing import Callable, Literal, Optional

from acto.input.k8s_schemas import KubernetesObjectSchema, KubernetesSchema
from acto.input.testcase import TestCase
from acto.schema import (
AnyOfSchema,
ArraySchema,
BaseSchema,
BooleanSchema,
IntegerSchema,
NumberSchema,
ObjectSchema,
OneOfSchema,
OpaqueSchema,
StringSchema,
)

TestGenerator = namedtuple(
"TestGeneratorObject",
[
"k8s_schema_name",
"field_name",
"field_type",
"paths",
"priority",
"func",
],
)


# global variable for registered test generators
test_generators: TestGenerator = []


def generator(
k8s_schema_name: Optional[str] = None,
field_name: Optional[str] = None,
field_type: Optional[
Literal[
"AnyOf",
"Array",
"Boolean",
"Integer",
"Number",
"Object",
"OneOf",
"Opaque",
"String",
]
] = None,
paths: Optional[list[str]] = None,
priority: int = 0,
):
"""Annotates a function as a test generator
Args:
k8s_schema_name (str, optional): Kubernetes schema name. Defaults to None.
field_name (str, optional): field/property name. Defaults to None.
field_type (str, optional): field/property type. Defaults to None.
paths (list[str], optional): Path suffixes. Defaults to None.
priority (int, optional): Priority. Defaults to 0."""
assert (
k8s_schema_name is not None
or field_name is not None
or field_type is not None
or paths is not None
), "One of k8s_schema_name, schema_name, schema_type, paths must be specified"

def wrapped_func(func: Callable[[BaseSchema], list[TestCase]]):
gen_obj = TestGenerator(
k8s_schema_name,
field_name,
field_type,
paths,
priority,
func,
)
test_generators.append(gen_obj)
return func

return wrapped_func


def get_testcases(
schema: BaseSchema,
matched_schemas: [tuple[BaseSchema, KubernetesSchema]],
) -> list[tuple[list[str], TestCase]]:
"""Get all test cases for a schema from registered test generators"""
matched_schemas: dict[str, KubernetesObjectSchema] = {
"/".join(s.path): m for s, m in matched_schemas
}

def get_testcases_helper(schema: BaseSchema, field_name: Optional[str]):
# print(schema_name, schema.path, type(schema))
test_cases = []
generator_candidates = []
# check paths
path_str = "/".join(schema.path)
matched_schema = matched_schemas.get(path_str)
for test_gen in test_generators:
# check paths
for path in test_gen.paths or []:
if path_str.endswith(path):
generator_candidates.append(test_gen)
continue

# check field name
if (
test_gen.field_name is not None
and test_gen.field_name == field_name
):
generator_candidates.append(test_gen)
continue

# check k8s schema name
if (
test_gen.k8s_schema_name is not None
and matched_schema is not None
and matched_schema.k8s_schema_name.endswith(
test_gen.k8s_schema_name
)
):
generator_candidates.append(test_gen)
continue

# check type
matching_types = {
"AnyOf": AnyOfSchema,
"Array": ArraySchema,
"Boolean": BooleanSchema,
"Integer": IntegerSchema,
"Number": NumberSchema,
"Object": ObjectSchema,
"OneOf": OneOfSchema,
"Opaque": OpaqueSchema,
"String": StringSchema,
}
if schema_type_obj := matching_types.get(test_gen.field_type):
if isinstance(schema, schema_type_obj):
generator_candidates.append(test_gen)

# sort by priority
generator_candidates.sort(key=lambda x: x.priority, reverse=True)
if len(generator_candidates) > 0:
test_cases.append(
(schema.path, generator_candidates[0].func(schema))
)

# check sub schemas
if isinstance(schema, ArraySchema):
test_cases.extend(
get_testcases_helper(schema.get_item_schema(), "ITEM")
)
elif isinstance(schema, ObjectSchema):
for field, sub_schema in schema.properties.items():
test_cases.extend(get_testcases_helper(sub_schema, field))
return test_cases

return get_testcases_helper(schema, None)
153 changes: 153 additions & 0 deletions test/integration_tests/test_testcase_generator_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# pylint: disable=missing-docstring, line-too-long

import os
import pathlib
import unittest

import yaml

from acto.input.generator import generator, get_testcases, test_generators
from acto.input.k8s_schemas import K8sSchemaMatcher
from acto.input.testcase import TestCase
from acto.schema import extract_schema

test_dir = pathlib.Path(__file__).parent.resolve()
test_data_dir = os.path.join(test_dir, "test_data")


def gen(_):
return [TestCase("test", lambda x: True, lambda x: None, lambda x: None)]


class TestSchema(unittest.TestCase):
"""This class tests the schema matching code for various CRDs."""

@classmethod
def setUpClass(cls):
with open(
os.path.join(test_data_dir, "rabbitmq_crd.yaml"),
"r",
encoding="utf-8",
) as operator_yaml:
rabbitmq_crd = yaml.load(operator_yaml, Loader=yaml.FullLoader)
schema_matcher = K8sSchemaMatcher.from_version("1.29")
cls.spec_schema = extract_schema(
[], rabbitmq_crd["spec"]["versions"][0]["schema"]["openAPIV3Schema"]
)
cls.matches = schema_matcher.find_matched_schemas(cls.spec_schema)

def test_path_suffix(self):
test_generators.clear()
generator(paths=["serviceAccountToken/expirationSeconds"])(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 1)

test_generators.clear()
generator(
paths=[
"serviceAccountToken/expirationSeconds",
"volumes/ITEM/quobyte/user",
]
)(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 2)

def test_k8s_schema_name(self):
test_generators.clear()
generator(k8s_schema_name="v1.NodeAffinity")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 2)

test_generators.clear()
generator(k8s_schema_name="HTTPHeader")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 15)

def test_field_name(self):
test_generators.clear()
generator(field_name="ports")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 4)

test_generators.clear()
generator(field_name="image")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 5)

def test_field_type(self):
test_generators.clear()
generator(field_type="AnyOf")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 38)

test_generators.clear()
generator(field_type="Array")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 173)

test_generators.clear()
generator(field_type="Boolean")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 73)

test_generators.clear()
generator(field_type="Integer")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 106)

test_generators.clear()
generator(field_type="Number")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 106)

test_generators.clear()
generator(field_type="Object")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 368)

# test_generators.clear()
# generator(field_type="OneOf")(gen)
# testcases = get_testcases(self.spec_schema, self.matches)
# self.assertEqual(len(testcases), 0)

# test_generators.clear()
# generator(field_type="Opaque")(gen)
# testcases = get_testcases(self.spec_schema, self.matches)
# self.assertEqual(len(testcases), 0)

test_generators.clear()
generator(field_type="String")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 550)

def test_priority(self):
test_generators.clear()

@generator(field_type="Integer", priority=0)
def gen0(_):
return [
TestCase(
"integer-test",
lambda x: True,
lambda x: None,
lambda x: None,
)
]

@generator(field_name="replicas", priority=1)
def gen1(_):
return [
TestCase(
"replicas-test",
lambda x: True,
lambda x: None,
lambda x: None,
)
]

testcases = get_testcases(self.spec_schema, self.matches)
for path, tests in testcases:
if path[-1] == "replicas":
self.assertEqual(tests[0].name, "replicas-test")
else:
self.assertEqual(tests[0].name, "integer-test")

0 comments on commit 1d3c829

Please sign in to comment.