diff --git a/base/src/main/java/org/aya/normalize/Normalizer.java b/base/src/main/java/org/aya/normalize/Normalizer.java index e1c736f6d..aa470a718 100644 --- a/base/src/main/java/org/aya/normalize/Normalizer.java +++ b/base/src/main/java/org/aya/normalize/Normalizer.java @@ -53,10 +53,13 @@ public final class Normalizer implements UnaryOperator { */ @SuppressWarnings("UnnecessaryContinue") @Override public Term apply(Term term) { while (true) { - if (term instanceof StableWHNF || term instanceof FreeTerm) return term; - // ConCall for point constructors are always in WHNF - if (term instanceof ConCall con && !con.ref().hasEq()) return con; + var alreadyWHNF = term instanceof StableWHNF || + term instanceof FreeTerm || + // ConCall for point constructors are always in WHNF + (term instanceof ConCall con && !con.ref().hasEq()); + if (alreadyWHNF && !usePostTerm) return term; var descentedTerm = term.descent(this); + if (alreadyWHNF && usePostTerm) return descentedTerm; // descent may change the java type of term, i.e. beta reduce, // and can also reduce the subterms. We intend to return the reduction // result when it beta reduces, so keep `descentedTerm` both when in NF mode or @@ -75,15 +78,11 @@ public final class Normalizer implements UnaryOperator { } case FnCall(JitFn instance, int ulift, var args) -> { var result = instance.invoke(args); - if (result instanceof FnCall resultCall && - resultCall.ref() == instance && - resultCall.args().sameElements(args, true) - ) { - return defaultValue; - } else { - term = result.elevate(ulift); - continue; - } + if (result instanceof FnCall(var ref, _, var newArgs) && + ref == instance && newArgs.sameElements(args, true) + ) return defaultValue; + term = result.elevate(ulift); + continue; } case FnCall(FnDef.Delegate delegate, int ulift, var args) -> { FnDef core = delegate.core(); @@ -172,7 +171,10 @@ case MatchCall(Matchy clauses, var discr, var captures) -> { } case MatchCall(JitMatchy fn, var discr, var captures) -> { var result = fn.invoke(captures, discr); - if (result == null) return defaultValue; + if (result instanceof MatchCall(var ref, var newDiscr, var newCaptures) && + ref == fn && newDiscr.sameElements(discr, true) && + newCaptures.sameElements(captures, true) + ) return defaultValue; term = result; continue; } @@ -216,7 +218,7 @@ private boolean isOpaque(@NotNull FnDef fn) { private class Full implements UnaryOperator { { usePostTerm = true; } - @Override public Term apply(Term term) { return Normalizer.this.apply(term).descent(this); } + @Override public Term apply(Term term) { return Normalizer.this.apply(term); } } /** diff --git a/jit-compiler/src/main/java/org/aya/compiler/free/Constants.java b/jit-compiler/src/main/java/org/aya/compiler/free/Constants.java index 28f980c29..4e765b1dd 100644 --- a/jit-compiler/src/main/java/org/aya/compiler/free/Constants.java +++ b/jit-compiler/src/main/java/org/aya/compiler/free/Constants.java @@ -9,7 +9,6 @@ import kala.control.Result; import org.aya.compiler.free.data.FieldRef; import org.aya.compiler.free.data.MethodRef; -import org.aya.generic.stmt.Shaped; import org.aya.syntax.compile.JitClass; import org.aya.syntax.compile.JitCon; import org.aya.syntax.compile.JitData; diff --git a/jit-compiler/src/main/java/org/aya/compiler/serializers/MatchySerializer.java b/jit-compiler/src/main/java/org/aya/compiler/serializers/MatchySerializer.java index 776955d0a..574eb6951 100644 --- a/jit-compiler/src/main/java/org/aya/compiler/serializers/MatchySerializer.java +++ b/jit-compiler/src/main/java/org/aya/compiler/serializers/MatchySerializer.java @@ -12,9 +12,13 @@ import org.aya.syntax.compile.CompiledAya; import org.aya.syntax.compile.JitMatchy; import org.aya.syntax.core.def.Matchy; +import org.aya.syntax.core.def.MatchyLike; import org.aya.syntax.core.repr.CodeShape; +import org.aya.syntax.core.term.Term; +import org.aya.syntax.core.term.call.MatchCall; import org.jetbrains.annotations.NotNull; +import java.lang.constant.ClassDesc; import java.util.function.Consumer; public class MatchySerializer extends ClassTargetSerializer { @@ -37,15 +41,32 @@ public MatchySerializer(ModuleSerializer.@NotNull MatchyRecorder recorder) { return NameSerializer.javifyClassName(unit.matchy.qualifiedName().module(), unit.matchy.qualifiedName().name()); } - /** - * @see JitMatchy#invoke(Seq, Seq) - */ - private void buildInvoke(@NotNull FreeCodeBuilder builder, @NotNull MatchyData data, @NotNull LocalVariable captures, @NotNull LocalVariable args) { + public static @NotNull MethodRef resolveInvoke(@NotNull ClassDesc owner, int capturec, int argc) { + return new MethodRef.Default( + owner, "invoke", + Constants.CD_Term, ImmutableSeq.fill(capturec + argc, Constants.CD_Term), + false + ); + } + + private void buildInvoke( + @NotNull FreeCodeBuilder builder, @NotNull MatchyData data, + @NotNull ImmutableSeq captures, @NotNull ImmutableSeq args + ) { var unit = data.matchy; - int argc = data.argsSize; - Consumer onFailed = b -> b.returnWith(b.aconstNull(Constants.CD_Term)); + var captureExprs = captures.map(LocalVariable::ref); + var argExprs = args.map(LocalVariable::ref); - if (argc == 0) { + Consumer onFailed = b -> { + var result = b.mkNew(MatchCall.class, ImmutableSeq.of( + AbstractExprializer.getInstance(b, NameSerializer.getClassDesc(data.matchy)), + AbstractExprializer.makeImmutableSeq(b, Term.class, captureExprs), + AbstractExprializer.makeImmutableSeq(b, Term.class, argExprs) + )); + b.returnWith(result); + }; + + if (args.isEmpty()) { onFailed.accept(builder); return; } @@ -54,17 +75,34 @@ private void buildInvoke(@NotNull FreeCodeBuilder builder, @NotNull MatchyData d new PatternSerializer.Matching(clause.bindCount(), clause.patterns(), (ps, cb, bindCount) -> { var resultSeq = AbstractExprializer.fromSeq(cb, Constants.CD_Term, ps.result.ref(), bindCount); - var captureSeq = AbstractExprializer.fromSeq(cb, Constants.CD_Term, captures.ref(), data.capturesSize); - var fullSeq = resultSeq.appendedAll(captureSeq); + var fullSeq = resultSeq.appendedAll(captureExprs); var returns = serializeTermUnderTele(cb, clause.body(), fullSeq); cb.returnWith(returns); }) ); - new PatternSerializer(AbstractExprializer.fromSeq(builder, Constants.CD_Term, args.ref(), argc), onFailed, false) + new PatternSerializer(argExprs, onFailed, false) .serialize(builder, matching); } + /** + * @see JitMatchy#invoke(Seq, Seq) + */ + private void buildInvoke( + @NotNull FreeCodeBuilder builder, @NotNull MatchyData data, + @NotNull LocalVariable captures, @NotNull LocalVariable args + ) { + var capturec = data.capturesSize; + int argc = data.argsSize; + var invokeRef = resolveInvoke(NameSerializer.getClassDesc(data.matchy), capturec, argc); + var invokeExpr = builder.invoke(invokeRef, builder.thisRef(), + AbstractExprializer.fromSeq(builder, Constants.CD_Term, captures.ref(), capturec) + .appendedAll(AbstractExprializer.fromSeq(builder, Constants.CD_Term, args.ref(), argc)) + ); + + builder.returnWith(invokeExpr); + } + /** @see JitMatchy#type */ private void buildType(@NotNull FreeCodeBuilder builder, @NotNull MatchyData data, @NotNull LocalVariable captures, @NotNull LocalVariable args) { var captureSeq = AbstractExprializer.fromSeq(builder, Constants.CD_Term, captures.ref(), data.capturesSize); @@ -84,6 +122,16 @@ private void buildType(@NotNull FreeCodeBuilder builder, @NotNull MatchyData dat @Override public @NotNull ClassTargetSerializer serialize(@NotNull FreeClassBuilder builder0, MatchyData unit) { buildFramework(builder0, unit, builder -> { + var capturec = unit.capturesSize; + var argc = unit.argsSize; + + builder.buildMethod(Constants.CD_Term, "invoke", ImmutableSeq.fill(capturec + argc, Constants.CD_Term), + (ap, cb) -> { + var captures = ImmutableSeq.fill(capturec, ap::arg); + var args = ImmutableSeq.fill(argc, i -> ap.arg(i + capturec)); + buildInvoke(cb, unit, captures, args); + }); + builder.buildMethod(Constants.CD_Term, "invoke", ImmutableSeq.of( Constants.CD_Seq, Constants.CD_Seq ), (ap, cb) -> { diff --git a/jit-compiler/src/main/java/org/aya/compiler/serializers/TermExprializer.java b/jit-compiler/src/main/java/org/aya/compiler/serializers/TermExprializer.java index 2f1e7fe7a..a053447c6 100644 --- a/jit-compiler/src/main/java/org/aya/compiler/serializers/TermExprializer.java +++ b/jit-compiler/src/main/java/org/aya/compiler/serializers/TermExprializer.java @@ -107,20 +107,27 @@ private TermExprializer( } private @NotNull FreeJavaExpr - buildFnInvoke(@NotNull ClassDesc defClass, int ulift, @NotNull ImmutableSeq args) { - var argsExpr = args.map(this::doSerialize); + buildFnInvoke(@NotNull ClassDesc defClass, int ulift, @NotNull ImmutableSeq args) { var invokeExpr = builder.invoke( - FnSerializer.resolveInvoke(defClass, args.size()), - getInstance(builder, defClass), - argsExpr - ); + FnSerializer.resolveInvoke(defClass, args.size()), getInstance(builder, defClass), args); if (ulift != 0) { - assert ulift > 0; - invokeExpr = builder.invoke(Constants.ELEVATE, invokeExpr, ImmutableSeq.of(builder.iconst(ulift))); - } + return builder.invoke(Constants.ELEVATE, invokeExpr, ImmutableSeq.of(builder.iconst(ulift))); + } else return invokeExpr; + } - return invokeExpr; + // There is a chance I need to add lifting to match, so keep a function for us to + // add the if-else in it + private @NotNull FreeJavaExpr buildMatchyInvoke( + @NotNull ClassDesc matchyClass, + @NotNull ImmutableSeq args, + @NotNull ImmutableSeq captures + ) { + return builder.invoke( + MatchySerializer.resolveInvoke(matchyClass, captures.size(), args.size()), + getInstance(builder, matchyClass), + captures.appendedAll(args) + ); } @Override protected @NotNull FreeJavaExpr doSerialize(@NotNull Term term) { @@ -164,7 +171,9 @@ case ConCall(var head, var args) -> builder.mkNew(ConCall.class, ImmutableSeq.of builder.iconst(head.ulift()), serializeToImmutableSeq(Term.class, args) )); - case FnCall(var ref, var ulift, var args) -> buildFnInvoke(NameSerializer.getClassDesc(ref), ulift, args); + case FnCall(var ref, var ulift, var args) -> buildFnInvoke( + NameSerializer.getClassDesc(ref), ulift, + args.map(this::doSerialize)); case RuleReducer.Con(var rule, int ulift, var ownerArgs, var conArgs) -> { var onStuck = builder.mkNew(RuleReducer.Con.class, ImmutableSeq.of( serializeApplicable(rule), @@ -174,7 +183,7 @@ case ConCall(var head, var args) -> builder.mkNew(ConCall.class, ImmutableSeq.of )); yield builder.invoke(Constants.RULEREDUCER_MAKE, onStuck, ImmutableSeq.empty()); } - case RuleReducer.Fn (var rule, int ulift, var args) -> { + case RuleReducer.Fn(var rule, int ulift, var args) -> { var onStuck = builder.mkNew(RuleReducer.Fn.class, ImmutableSeq.of( serializeApplicable(rule), builder.iconst(ulift), @@ -244,11 +253,8 @@ case ClassCastTerm(var classRef, var subterm, var rember, var forgor) -> )); case MatchCall(var ref, var args, var captures) -> { if (ref instanceof Matchy matchy) recorder.addMatchy(matchy, args.size(), captures.size()); - yield builder.mkNew(MatchCall.class, ImmutableSeq.of( - getInstance(builder, NameSerializer.getClassDesc(ref)), - serializeToImmutableSeq(Term.class, args), - serializeToImmutableSeq(Term.class, captures) - )); + yield buildMatchyInvoke(NameSerializer.getClassDesc(ref), + args.map(this::doSerialize), captures.map(this::doSerialize)); } case NewTerm(var classCall) -> builder.mkNew(NewTerm.class, ImmutableSeq.of(doSerialize(classCall))); }; diff --git a/jit-compiler/src/test/resources/TreeSort.aya b/jit-compiler/src/test/resources/TreeSort.aya index 2f660bef3..d2eb2f5ee 100644 --- a/jit-compiler/src/test/resources/TreeSort.aya +++ b/jit-compiler/src/test/resources/TreeSort.aya @@ -46,13 +46,12 @@ def balanceRight Color (RBTree A) A (RBTree A) : RBTree A rbNode red (rbNode black l v b) y (rbNode black c z d) | c, l, v, b => rbNode c l v b -def insert_lemma (dec_le : Decider A) (a a1 : A) (c : Color) (l1 l2 : RBTree A) (b : Bool) : RBTree A elim b -| True => balanceRight c l1 a1 (insert a l2 dec_le) -| False => balanceLeft c (insert a l1 dec_le) a1 l2 - def insert (a : A) (node : RBTree A) (dec_le : Decider A) : RBTree A elim node | rbLeaf => rbNode red rbLeaf a rbLeaf -| rbNode c l1 a1 l2 => insert_lemma dec_le a a1 c l1 l2 (dec_le a1 a) +| rbNode c l1 a1 l2 => match dec_le a1 a + { True => balanceRight c l1 a1 (insert a l2 dec_le) + | False => balanceLeft c (insert a l1 dec_le) a1 l2 + } private def aux (ls : List A) (r : RBTree A) (dec_le : Decider A) : RBTree A elim ls | [] => r diff --git a/syntax/src/main/java/org/aya/prettier/BasePrettier.java b/syntax/src/main/java/org/aya/prettier/BasePrettier.java index cd76218db..36bc7a18b 100644 --- a/syntax/src/main/java/org/aya/prettier/BasePrettier.java +++ b/syntax/src/main/java/org/aya/prettier/BasePrettier.java @@ -14,6 +14,7 @@ import org.aya.pretty.doc.Link; import org.aya.pretty.doc.Style; import org.aya.pretty.style.AyaStyleKey; +import org.aya.syntax.compile.JitCon; import org.aya.syntax.compile.JitDef; import org.aya.syntax.concrete.stmt.QualifiedID; import org.aya.syntax.concrete.stmt.decl.*; @@ -109,7 +110,11 @@ private static BooleanSeq computeLicitFromDef(@NotNull AnyDef var, int size) { : inner.ref.signature == null ? BooleanSeq.fill(size, true) : inner.ref.signature.params().mapToBooleanTo(MutableBooleanList.create(), Param::explicit); - case JitTele jit -> MutableBooleanList.from(jit.telescopeLicit); + case JitTele jit -> { + var rawLicit = MutableBooleanList.from(jit.telescopeLicit); + if (jit instanceof JitCon con) yield rawLicit.takeLast(con.selfTeleSize()); + yield rawLicit; + } default -> Panic.unreachable(); }; } diff --git a/syntax/src/main/java/org/aya/syntax/compile/JitMatchy.java b/syntax/src/main/java/org/aya/syntax/compile/JitMatchy.java index 224d8009e..964b55ab0 100644 --- a/syntax/src/main/java/org/aya/syntax/compile/JitMatchy.java +++ b/syntax/src/main/java/org/aya/syntax/compile/JitMatchy.java @@ -6,9 +6,7 @@ import org.aya.syntax.core.def.MatchyLike; import org.aya.syntax.core.term.Term; import org.aya.syntax.ref.QName; -import org.aya.util.error.Panic; import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; public non-sealed abstract class JitMatchy extends JitUnit implements MatchyLike { protected JitMatchy() { @@ -16,7 +14,7 @@ protected JitMatchy() { } /** @return null if stuck */ - public abstract @Nullable Term invoke( + public abstract @NotNull Term invoke( @NotNull Seq<@NotNull Term> captures, @NotNull Seq<@NotNull Term> args ); diff --git a/syntax/src/main/java/org/aya/syntax/compile/JitUnit.java b/syntax/src/main/java/org/aya/syntax/compile/JitUnit.java index 9a98d9586..fe1705bf3 100644 --- a/syntax/src/main/java/org/aya/syntax/compile/JitUnit.java +++ b/syntax/src/main/java/org/aya/syntax/compile/JitUnit.java @@ -18,7 +18,7 @@ public abstract class JitUnit { return metadata; } - public @NotNull String name() { return metadata.name(); } + public @NotNull String name() { return metadata().name(); } public @NotNull ModulePath module() { return new ModulePath(rawModule()); } private @NotNull ImmutableArray rawModule() { diff --git a/syntax/src/main/java/org/aya/syntax/core/pat/Pat.java b/syntax/src/main/java/org/aya/syntax/core/pat/Pat.java index 2bcefb89b..d33682034 100644 --- a/syntax/src/main/java/org/aya/syntax/core/pat/Pat.java +++ b/syntax/src/main/java/org/aya/syntax/core/pat/Pat.java @@ -37,7 +37,7 @@ * * @author kiva, ice1000, HoshinoTented */ -@Debug.Renderer(text = "PatToTerm.visit(this).debuggerOnlyToString()") +@Debug.Renderer(text = "PatToTerm.visit(this).easyToString()") public sealed interface Pat { default @NotNull Pat descentPat(@NotNull UnaryOperator op) { return this; } default @NotNull Pat descentTerm(@NotNull UnaryOperator op) { return this; }