diff --git a/packages/safe-ds-lang/src/language/helpers/nodeProperties.ts b/packages/safe-ds-lang/src/language/helpers/nodeProperties.ts index ea0705ff3..8f1498a8c 100644 --- a/packages/safe-ds-lang/src/language/helpers/nodeProperties.ts +++ b/packages/safe-ds-lang/src/language/helpers/nodeProperties.ts @@ -1,3 +1,4 @@ +import { AstNode, getContainerOfType, Stream, stream } from 'langium'; import { isSdsAnnotation, isSdsArgumentList, @@ -57,7 +58,6 @@ import { SdsTypeParameter, SdsTypeParameterList, } from '../generated/ast.js'; -import { AstNode, getContainerOfType, Stream, stream } from 'langium'; // ------------------------------------------------------------------------------------------------- // Checks diff --git a/packages/safe-ds-lang/src/language/typing/model.ts b/packages/safe-ds-lang/src/language/typing/model.ts index 12bcbcf0f..a95687ffd 100644 --- a/packages/safe-ds-lang/src/language/typing/model.ts +++ b/packages/safe-ds-lang/src/language/typing/model.ts @@ -149,11 +149,9 @@ export class NamedTupleType extends Type { /** * The length of this tuple. */ - /* c8 ignore start */ get length(): number { return this.entries.length; } - /* c8 ignore stop */ /** * Returns the type of the entry at the given index. If the index is out of bounds, returns `undefined`. diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts index 660e66cfd..d3dbe3ca8 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-checker.ts @@ -1,4 +1,5 @@ import { getContainerOfType } from 'langium'; +import type { SafeDsClasses } from '../builtins/safe-ds-classes.js'; import { isSdsEnum, SdsDeclaration } from '../generated/ast.js'; import { BooleanConstant, @@ -19,16 +20,18 @@ import { StaticType, Type, UnionType, + UnknownType, } from './model.js'; import { SafeDsClassHierarchy } from './safe-ds-class-hierarchy.js'; import { SafeDsCoreTypes } from './safe-ds-core-types.js'; -/* c8 ignore start */ export class SafeDsTypeChecker { + private readonly builtinClasses: SafeDsClasses; private readonly classHierarchy: SafeDsClassHierarchy; private readonly coreTypes: SafeDsCoreTypes; constructor(services: SafeDsServices) { + this.builtinClasses = services.builtins.Classes; this.classHierarchy = services.types.ClassHierarchy; this.coreTypes = services.types.CoreTypes; } @@ -37,6 +40,12 @@ export class SafeDsTypeChecker { * Checks whether {@link type} is assignable {@link other}. */ isAssignableTo(type: Type, other: Type): boolean { + if (type === UnknownType || other === UnknownType) { + return false; + } else if (other instanceof UnionType) { + return other.possibleTypes.some((it) => this.isAssignableTo(type, it)); + } + if (type instanceof CallableType) { return this.callableTypeIsAssignableTo(type, other); } else if (type instanceof ClassType) { @@ -53,48 +62,66 @@ export class SafeDsTypeChecker { return this.staticTypeIsAssignableTo(type, other); } else if (type instanceof UnionType) { return this.unionTypeIsAssignableTo(type, other); - } else { - return false; - } + } /* c8 ignore start */ else { + throw new Error(`Unexpected type: ${type.constructor.name}`); + } /* c8 ignore stop */ } private callableTypeIsAssignableTo(type: CallableType, other: Type): boolean { - // return when (val unwrappedOther = unwrapVariadicType(other)) { - // is CallableType -> { - // // TODO: We need to compare names of parameters & results and can allow additional optional parameters - // - // // Sizes must match (too strict requirement -> should be loosened later) - // if (this.parameters.size != unwrappedOther.parameters.size || this.results.size != this.results.size) { - // return false - // } - // - // // Actual parameters must be supertypes of expected parameters (contravariance) - // this.parameters.zip(unwrappedOther.parameters).forEach { (thisParameter, otherParameter) -> - // if (!otherParameter.isSubstitutableFor(thisParameter)) { - // return false - // } - // } - // - // // Expected results must be subtypes of expected results (covariance) - // this.results.zip(unwrappedOther.results).forEach { (thisResult, otherResult) -> - // if (!thisResult.isSubstitutableFor(otherResult)) { - // return false - // } - // } - // - // true - // } - // is ClassType -> { - // unwrappedOther.sdsClass.qualifiedNameOrNull() == StdlibClasses.Any - // } - // is UnionType -> { - // unwrappedOther.possibleTypes.any { this.isSubstitutableFor(it) } - // } - // else -> false - // } - // } + if (other instanceof ClassType) { + return other.declaration === this.builtinClasses.Any; + } else if (other instanceof CallableType) { + // Must accept at least as many parameters and produce at least as many results + if (type.inputType.length < other.inputType.length || type.outputType.length < other.outputType.length) { + return false; + } - return type.equals(other); + // Check expected parameters + for (let i = 0; i < other.inputType.length; i++) { + const typeEntry = type.inputType.entries[i]; + const otherEntry = other.inputType.entries[i]; + + // Names must match + if (typeEntry.name !== otherEntry.name) { + return false; + } + + // Types must be contravariant + if (!this.isAssignableTo(otherEntry.type, typeEntry.type)) { + return false; + } + } + + // Additional parameters must be optional + for (let i = other.inputType.length; i < type.inputType.length; i++) { + const typeEntry = type.inputType.entries[i]; + if (!typeEntry.declaration?.defaultValue) { + return false; + } + } + + // Check expected results + for (let i = 0; i < other.outputType.length; i++) { + const typeEntry = type.outputType.entries[i]; + const otherEntry = other.outputType.entries[i]; + + // Names must match + if (typeEntry.name !== otherEntry.name) { + return false; + } + + // Types must be covariant + if (!this.isAssignableTo(typeEntry.type, otherEntry.type)) { + return false; + } + } + + // Additional results are OK + + return true; + } else { + return false; + } } private classTypeIsAssignableTo(type: ClassType, other: Type): boolean { @@ -104,8 +131,6 @@ export class SafeDsTypeChecker { if (other instanceof ClassType) { return this.classHierarchy.isEqualToOrSubclassOf(type.declaration, other.declaration); - } else if (other instanceof UnionType) { - return other.possibleTypes.some((it) => this.isAssignableTo(type, it)); } else { return false; } @@ -116,22 +141,13 @@ export class SafeDsTypeChecker { return false; } - if (other instanceof EnumType) { + if (other instanceof ClassType) { + return other.declaration === this.builtinClasses.Any; + } else if (other instanceof EnumType) { return type.declaration === other.declaration; + } else { + return false; } - - // return when (val unwrappedOther = unwrapVariadicType(other)) { - // is ClassType -> { - // (!this.isNullable || unwrappedOther.isNullable) && - // unwrappedOther.sdsClass.qualifiedNameOrNull() == StdlibClasses.Any - // } - // is UnionType -> { - // unwrappedOther.possibleTypes.any { this.isSubstitutableFor(it) } - // } - // else -> false - // } - - return type.equals(other); } private enumVariantTypeIsAssignableTo(type: EnumVariantType, other: Type): boolean { @@ -139,22 +155,16 @@ export class SafeDsTypeChecker { return false; } - if (other instanceof EnumType) { + if (other instanceof ClassType) { + return other.declaration === this.builtinClasses.Any; + } else if (other instanceof EnumType) { const containingEnum = getContainerOfType(type.declaration, isSdsEnum); return containingEnum === other.declaration; } else if (other instanceof EnumVariantType) { return type.declaration === other.declaration; + } else { + return false; } - // return when (val unwrappedOther = unwrapVariadicType(other)) { - // is ClassType -> { - // (!this.isNullable || unwrappedOther.isNullable) && - // unwrappedOther.sdsClass.qualifiedNameOrNull() == StdlibClasses.Any - // } - // is UnionType -> unwrappedOther.possibleTypes.any { this.isSubstitutableFor(it) } - // else -> false - // } - - return type.equals(other); } private literalTypeIsAssignableTo(type: LiteralType, other: Type): boolean { @@ -173,7 +183,6 @@ export class SafeDsTypeChecker { other.constants.some((otherConstant) => constant.equals(otherConstant)), ); } else { - // TODO: union type return false; } } @@ -197,9 +206,19 @@ export class SafeDsTypeChecker { return this.isAssignableTo(classType, other); } - /* c8 ignore start */ private namedTupleTypeIsAssignableTo(type: NamedTupleType, other: Type): boolean { - return type.equals(other); + if (other instanceof NamedTupleType) { + return ( + type.length === other.length && + type.entries.every((typeEntry, index) => { + const otherEntry = other.entries[index]; + // We deliberately ignore the declarations here + return typeEntry.name === otherEntry.name && this.isAssignableTo(typeEntry.type, otherEntry.type); + }) + ); + } else { + return false; + } } private staticTypeIsAssignableTo(type: Type, other: Type): boolean { @@ -207,8 +226,6 @@ export class SafeDsTypeChecker { } private unionTypeIsAssignableTo(type: UnionType, other: Type): boolean { - // return this.possibleTypes.all { it.isSubstitutableFor(other) } - return type.equals(other); + return type.possibleTypes.every((it) => this.isAssignableTo(it, other)); } } -/* c8 ignore stop */ diff --git a/packages/safe-ds-lang/tests/language/typing/safe-ds-type-checker.test.ts b/packages/safe-ds-lang/tests/language/typing/safe-ds-type-checker.test.ts new file mode 100644 index 000000000..bf7e44c59 --- /dev/null +++ b/packages/safe-ds-lang/tests/language/typing/safe-ds-type-checker.test.ts @@ -0,0 +1,605 @@ +import { streamAllContents } from 'langium'; +import { NodeFileSystem } from 'langium/node'; +import { describe, expect, it } from 'vitest'; +import { + isSdsClass, + isSdsEnum, + isSdsEnumVariant, + isSdsFunction, + isSdsModule, +} from '../../../src/language/generated/ast.js'; +import { getModuleMembers } from '../../../src/language/helpers/nodeProperties.js'; +import { createSafeDsServicesWithBuiltins } from '../../../src/language/index.js'; +import { + BooleanConstant, + FloatConstant, + IntConstant, + NullConstant, + StringConstant, +} from '../../../src/language/partialEvaluation/model.js'; +import { + ClassType, + LiteralType, + NamedTupleEntry, + NamedTupleType, + StaticType, + Type, + UnionType, + UnknownType, +} from '../../../src/language/typing/model.js'; +import { getNodeOfType } from '../../helpers/nodeFinder.js'; + +const services = (await createSafeDsServicesWithBuiltins(NodeFileSystem)).SafeDs; +const coreTypes = services.types.CoreTypes; +const typeChecker = services.types.TypeChecker; +const typeComputer = services.types.TypeComputer; + +const code = ` + fun func1() -> () + fun func2(p: Int = 0) -> () + fun func3(p: Int) -> () + fun func4(r: Int) -> () + fun func5(p: Any) -> () + fun func6(p: String) -> () + fun func7() -> (r: Int) + fun func8() -> (s: Int) + fun func9() -> (r: Any) + fun func10() -> (r: String) + + class Class1 + class Class2 sub Class1 + class Class3 + + enum Enum1 { + Variant1 + Variant2 + } + enum Enum2 +`; +const module = await getNodeOfType(services, code, isSdsModule); +const functions = getModuleMembers(module).filter(isSdsFunction); +const callableType1 = typeComputer.computeType(functions[0]); +const callableType2 = typeComputer.computeType(functions[1]); +const callableType3 = typeComputer.computeType(functions[2]); +const callableType4 = typeComputer.computeType(functions[3]); +const callableType5 = typeComputer.computeType(functions[4]); +const callableType6 = typeComputer.computeType(functions[5]); +const callableType7 = typeComputer.computeType(functions[6]); +const callableType8 = typeComputer.computeType(functions[7]); +const callableType9 = typeComputer.computeType(functions[8]); +const callableType10 = typeComputer.computeType(functions[9]); + +const classes = getModuleMembers(module).filter(isSdsClass); +const class1 = classes[0]; +const class2 = classes[1]; +const class3 = classes[2]; +const classType1 = typeComputer.computeType(class1) as ClassType; +const classType2 = typeComputer.computeType(class2) as ClassType; +const classType3 = typeComputer.computeType(class3) as ClassType; + +const enums = getModuleMembers(module).filter(isSdsEnum); +const enum1 = enums[0]; +const enum2 = enums[1]; +const enumType1 = typeComputer.computeType(enum1); +const enumType2 = typeComputer.computeType(enum2); + +const enumVariants = streamAllContents(module).filter(isSdsEnumVariant).toArray(); +const enumVariant1 = enumVariants[0]; +const enumVariant2 = enumVariants[1]; +const enumVariantType1 = typeComputer.computeType(enumVariant1); +const enumVariantType2 = typeComputer.computeType(enumVariant2); + +describe('SafeDsTypeChecker', async () => { + const testCases: IsAssignableToTest[] = [ + { + type1: callableType1, + type2: callableType1, + expected: true, + }, + { + type1: callableType2, + type2: callableType1, + expected: true, + }, + { + type1: callableType1, + type2: callableType2, + expected: false, + }, + { + type1: callableType3, + type2: callableType1, + expected: false, + }, + { + type1: callableType3, + type2: callableType4, + expected: false, + }, + { + type1: callableType3, + type2: callableType5, + expected: false, + }, + { + type1: callableType5, + type2: callableType3, + expected: true, + }, + { + type1: callableType6, + type2: callableType3, + expected: false, + }, + { + type1: callableType7, + type2: callableType1, + expected: true, + }, + { + type1: callableType1, + type2: callableType7, + expected: false, + }, + { + type1: callableType8, + type2: callableType7, + expected: false, + }, + { + type1: callableType9, + type2: callableType7, + expected: false, + }, + { + type1: callableType7, + type2: callableType9, + expected: true, + }, + { + type1: callableType10, + type2: callableType7, + expected: false, + }, + // Callable type to class type + { + type1: callableType1, + type2: coreTypes.Any, + expected: true, + }, + { + type1: callableType1, + type2: coreTypes.AnyOrNull, + expected: true, + }, + // Callable type to other + { + type1: callableType1, + type2: enumType1, + expected: false, + }, + // Class type to class type + { + type1: classType1, + type2: classType1, + expected: true, + }, + { + type1: classType2, + type2: classType1, + expected: true, + }, + { + type1: classType1, + type2: classType3, + expected: false, + }, + { + type1: classType1, + type2: coreTypes.Any, + expected: true, + }, + { + type1: classType2.updateNullability(true), + type2: classType1, + expected: false, + }, + { + type1: classType2.updateNullability(true), + type2: classType1.updateNullability(true), + expected: true, + }, + // Class type to union type + { + type1: classType1, + type2: new UnionType(classType1), + expected: true, + }, + { + type1: classType1, + type2: new UnionType(classType3), + expected: false, + }, + // Class type to other + { + type1: classType1, + type2: enumType1, + expected: false, + }, + // Enum type to class type + { + type1: enumType1, + type2: classType1, + expected: false, + }, + { + type1: enumType1, + type2: coreTypes.Any, + expected: true, + }, + { + type1: enumType1.updateNullability(true), + type2: coreTypes.Any, + expected: false, + }, + { + type1: enumType1.updateNullability(true), + type2: coreTypes.AnyOrNull, + expected: true, + }, + // Enum type to enum type + { + type1: enumType1, + type2: enumType1, + expected: true, + }, + { + type1: enumType1, + type2: enumType2, + expected: false, + }, + { + type1: enumType1.updateNullability(true), + type2: enumType1, + expected: false, + }, + { + type1: enumType1.updateNullability(true), + type2: enumType1.updateNullability(true), + expected: true, + }, + // Enum type to union type + { + type1: enumType1, + type2: new UnionType(enumType1), + expected: true, + }, + { + type1: enumType1, + type2: new UnionType(enumType2), + expected: false, + }, + // Enum type to other + { + type1: enumType1, + type2: new LiteralType(), + expected: false, + }, + // Enum variant type to class type + { + type1: enumVariantType1, + type2: classType1, + expected: false, + }, + { + type1: enumVariantType1, + type2: coreTypes.Any, + expected: true, + }, + { + type1: enumVariantType1.updateNullability(true), + type2: coreTypes.Any, + expected: false, + }, + { + type1: enumVariantType1.updateNullability(true), + type2: coreTypes.AnyOrNull, + expected: true, + }, + // Enum variant type to enum type + { + type1: enumVariantType1, + type2: enumType1, + expected: true, + }, + { + type1: enumVariantType1, + type2: enumType2, + expected: false, + }, + { + type1: enumVariantType1.updateNullability(true), + type2: enumType1, + expected: false, + }, + { + type1: enumVariantType1.updateNullability(true), + type2: enumType1.updateNullability(true), + expected: true, + }, + // Enum variant type to enum variant type + { + type1: enumVariantType1, + type2: enumVariantType1, + expected: true, + }, + { + type1: enumVariantType1, + type2: enumVariantType2, + expected: false, + }, + { + type1: enumVariantType1.updateNullability(true), + type2: enumVariantType1, + expected: false, + }, + { + type1: enumVariantType1.updateNullability(true), + type2: enumVariantType1.updateNullability(true), + expected: true, + }, + // Enum variant type to union type + { + type1: enumVariantType1, + type2: new UnionType(enumType1), + expected: true, + }, + { + type1: enumVariantType1, + type2: new UnionType(enumType2), + expected: false, + }, + // Enum variant type to other + { + type1: enumVariantType1, + type2: new LiteralType(), + expected: false, + }, + // Literal type to class type + { + type1: new LiteralType(), + type2: classType1, + expected: true, + }, + { + type1: new LiteralType(new BooleanConstant(true)), + type2: coreTypes.Boolean, + expected: true, + }, + { + type1: new LiteralType(new FloatConstant(1.5)), + type2: coreTypes.Float, + expected: true, + }, + { + type1: new LiteralType(new IntConstant(1n)), + type2: coreTypes.Int, + expected: true, + }, + { + type1: new LiteralType(NullConstant), + type2: coreTypes.NothingOrNull, + expected: true, + }, + { + type1: new LiteralType(new StringConstant('')), + type2: coreTypes.String, + expected: true, + }, + { + type1: new LiteralType(new IntConstant(1n)), + type2: coreTypes.Any, + expected: true, + }, + { + type1: new LiteralType(new IntConstant(1n)), + type2: coreTypes.String, + expected: false, + }, + { + type1: new LiteralType(new IntConstant(1n), NullConstant), + type2: coreTypes.Int.updateNullability(true), + expected: true, + }, + { + type1: new LiteralType(new IntConstant(1n), NullConstant), + type2: coreTypes.Int, + expected: false, + }, + { + type1: new LiteralType(new IntConstant(1n), new StringConstant('')), + type2: coreTypes.Int, + expected: false, + }, + { + type1: new LiteralType(new IntConstant(1n), new StringConstant('')), + type2: coreTypes.String, + expected: false, + }, + { + type1: new LiteralType(new IntConstant(1n), new StringConstant('')), + type2: coreTypes.Any, + expected: true, + }, + { + type1: new LiteralType(new IntConstant(1n), new StringConstant(''), NullConstant), + type2: coreTypes.AnyOrNull, + expected: true, + }, + // Literal type to literal type + { + type1: new LiteralType(), + type2: new LiteralType(), + expected: true, + }, + { + type1: new LiteralType(new BooleanConstant(true)), + type2: new LiteralType(new BooleanConstant(true)), + expected: true, + }, + { + type1: new LiteralType(new BooleanConstant(true)), + type2: new LiteralType(new BooleanConstant(false)), + expected: false, + }, + { + type1: new LiteralType(new BooleanConstant(true)), + type2: new LiteralType(new FloatConstant(1.5)), + expected: false, + }, + { + type1: new LiteralType(new BooleanConstant(true), NullConstant), + type2: new LiteralType(new BooleanConstant(true), NullConstant), + expected: true, + }, + { + type1: new LiteralType(new BooleanConstant(true), NullConstant), + type2: new LiteralType(new BooleanConstant(true)), + expected: false, + }, + // Literal type to union type + { + type1: new LiteralType(new IntConstant(1n)), + type2: new UnionType(coreTypes.Any), + expected: true, + }, + { + type1: new LiteralType(new IntConstant(1n)), + type2: new UnionType(coreTypes.String), + expected: false, + }, + // Literal type to other + { + type1: new LiteralType(new IntConstant(1n)), + type2: enumType1, + expected: false, + }, + // Named tuple type to named tuple type + { + type1: new NamedTupleType(), + type2: new NamedTupleType(), + expected: true, + }, + { + type1: new NamedTupleType(new NamedTupleEntry(undefined, 'a', coreTypes.Int)), + type2: new NamedTupleType(new NamedTupleEntry(undefined, 'a', coreTypes.Int)), + expected: true, + }, + { + type1: new NamedTupleType(new NamedTupleEntry(class1, 'a', coreTypes.Int)), + type2: new NamedTupleType(new NamedTupleEntry(class2, 'a', coreTypes.Int)), + expected: true, + }, + { + type1: new NamedTupleType(new NamedTupleEntry(undefined, 'a', coreTypes.Int)), + type2: new NamedTupleType(new NamedTupleEntry(undefined, 'a', coreTypes.Any)), + expected: true, + }, + { + type1: new NamedTupleType(new NamedTupleEntry(undefined, 'a', coreTypes.Any)), + type2: new NamedTupleType(new NamedTupleEntry(undefined, 'a', coreTypes.Int)), + expected: false, + }, + { + type1: new NamedTupleType(new NamedTupleEntry(undefined, 'a', coreTypes.Int)), + type2: new NamedTupleType(new NamedTupleEntry(undefined, 'b', coreTypes.Int)), + expected: false, + }, + // Named tuple type to other + { + type1: new NamedTupleType(), + type2: enumType1, + expected: false, + }, + // Static type to static type + { + type1: new StaticType(classType1), + type2: new StaticType(classType1), + expected: true, + }, + { + type1: new StaticType(classType1), + type2: new StaticType(classType2), + expected: false, + }, + // Static type to other + { + type1: new StaticType(classType1), + type2: enumType1, + expected: false, + }, + // Union type to X + { + type1: new UnionType(), + type2: classType1, + expected: true, + }, + { + type1: new UnionType(classType1), + type2: classType1, + expected: true, + }, + { + type1: new UnionType(classType1, classType2), + type2: classType1, + expected: true, + }, + { + type1: new UnionType(classType1, classType3), + type2: classType1, + expected: false, + }, + { + type1: new UnionType(classType1.updateNullability(true)), + type2: classType1, + expected: false, + }, + { + type1: new UnionType(classType1.updateNullability(true)), + type2: classType1.updateNullability(true), + expected: true, + }, + // Unknown to X + { + type1: UnknownType, + type2: UnknownType, + expected: false, + }, + ]; + + describe.each(testCases)('isAssignableTo', ({ type1, type2, expected }) => { + it(`should check whether ${type1} is assignable to ${type2}`, () => { + expect(typeChecker.isAssignableTo(type1, type2)).toBe(expected); + }); + }); +}); + +/** + * A test case for {@link SafeDsTypeChecker.isAssignableTo}. + */ +interface IsAssignableToTest { + /** + * The first type to check. + */ + type1: Type; + + /** + * The second type to check. + */ + type2: Type; + + /** + * Whether {@link type1} is expected to be assignable to {@link type2}. + */ + expected: boolean; +}