Skip to content

Commit

Permalink
merge: Fix JIT of matchy calls (#1230)
Browse files Browse the repository at this point in the history
  • Loading branch information
ice1000 authored Dec 17, 2024
2 parents d82d632 + af782ba commit 6d061c9
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 53 deletions.
30 changes: 16 additions & 14 deletions base/src/main/java/org/aya/normalize/Normalizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ public final class Normalizer implements UnaryOperator<Term> {
*/
@SuppressWarnings("UnnecessaryContinue") @Override public Term apply(Term term) {
while (true) {
if (term instanceof StableWHNF || term instanceof FreeTerm) return term;
// ConCall for point constructors are always in WHNF
if (term instanceof ConCall con && !con.ref().hasEq()) return con;
var alreadyWHNF = term instanceof StableWHNF ||
term instanceof FreeTerm ||
// ConCall for point constructors are always in WHNF
(term instanceof ConCall con && !con.ref().hasEq());
if (alreadyWHNF && !usePostTerm) return term;
var descentedTerm = term.descent(this);
if (alreadyWHNF && usePostTerm) return descentedTerm;
// descent may change the java type of term, i.e. beta reduce,
// and can also reduce the subterms. We intend to return the reduction
// result when it beta reduces, so keep `descentedTerm` both when in NF mode or
Expand All @@ -75,15 +78,11 @@ public final class Normalizer implements UnaryOperator<Term> {
}
case FnCall(JitFn instance, int ulift, var args) -> {
var result = instance.invoke(args);
if (result instanceof FnCall resultCall &&
resultCall.ref() == instance &&
resultCall.args().sameElements(args, true)
) {
return defaultValue;
} else {
term = result.elevate(ulift);
continue;
}
if (result instanceof FnCall(var ref, _, var newArgs) &&
ref == instance && newArgs.sameElements(args, true)
) return defaultValue;
term = result.elevate(ulift);
continue;
}
case FnCall(FnDef.Delegate delegate, int ulift, var args) -> {
FnDef core = delegate.core();
Expand Down Expand Up @@ -172,7 +171,10 @@ case MatchCall(Matchy clauses, var discr, var captures) -> {
}
case MatchCall(JitMatchy fn, var discr, var captures) -> {
var result = fn.invoke(captures, discr);
if (result == null) return defaultValue;
if (result instanceof MatchCall(var ref, var newDiscr, var newCaptures) &&
ref == fn && newDiscr.sameElements(discr, true) &&
newCaptures.sameElements(captures, true)
) return defaultValue;
term = result;
continue;
}
Expand Down Expand Up @@ -216,7 +218,7 @@ private boolean isOpaque(@NotNull FnDef fn) {
private class Full implements UnaryOperator<Term> {
{ usePostTerm = true; }

@Override public Term apply(Term term) { return Normalizer.this.apply(term).descent(this); }
@Override public Term apply(Term term) { return Normalizer.this.apply(term); }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import kala.control.Result;
import org.aya.compiler.free.data.FieldRef;
import org.aya.compiler.free.data.MethodRef;
import org.aya.generic.stmt.Shaped;
import org.aya.syntax.compile.JitClass;
import org.aya.syntax.compile.JitCon;
import org.aya.syntax.compile.JitData;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import org.aya.syntax.compile.CompiledAya;
import org.aya.syntax.compile.JitMatchy;
import org.aya.syntax.core.def.Matchy;
import org.aya.syntax.core.def.MatchyLike;
import org.aya.syntax.core.repr.CodeShape;
import org.aya.syntax.core.term.Term;
import org.aya.syntax.core.term.call.MatchCall;
import org.jetbrains.annotations.NotNull;

import java.lang.constant.ClassDesc;
import java.util.function.Consumer;

public class MatchySerializer extends ClassTargetSerializer<MatchySerializer.MatchyData> {
Expand All @@ -37,15 +41,32 @@ public MatchySerializer(ModuleSerializer.@NotNull MatchyRecorder recorder) {
return NameSerializer.javifyClassName(unit.matchy.qualifiedName().module(), unit.matchy.qualifiedName().name());
}

/**
* @see JitMatchy#invoke(Seq, Seq)
*/
private void buildInvoke(@NotNull FreeCodeBuilder builder, @NotNull MatchyData data, @NotNull LocalVariable captures, @NotNull LocalVariable args) {
public static @NotNull MethodRef resolveInvoke(@NotNull ClassDesc owner, int capturec, int argc) {
return new MethodRef.Default(
owner, "invoke",
Constants.CD_Term, ImmutableSeq.fill(capturec + argc, Constants.CD_Term),
false
);
}

private void buildInvoke(
@NotNull FreeCodeBuilder builder, @NotNull MatchyData data,
@NotNull ImmutableSeq<LocalVariable> captures, @NotNull ImmutableSeq<LocalVariable> args
) {
var unit = data.matchy;
int argc = data.argsSize;
Consumer<FreeCodeBuilder> onFailed = b -> b.returnWith(b.aconstNull(Constants.CD_Term));
var captureExprs = captures.map(LocalVariable::ref);
var argExprs = args.map(LocalVariable::ref);

if (argc == 0) {
Consumer<FreeCodeBuilder> onFailed = b -> {
var result = b.mkNew(MatchCall.class, ImmutableSeq.of(
AbstractExprializer.getInstance(b, NameSerializer.getClassDesc(data.matchy)),
AbstractExprializer.makeImmutableSeq(b, Term.class, captureExprs),
AbstractExprializer.makeImmutableSeq(b, Term.class, argExprs)
));
b.returnWith(result);
};

if (args.isEmpty()) {
onFailed.accept(builder);
return;
}
Expand All @@ -54,17 +75,34 @@ private void buildInvoke(@NotNull FreeCodeBuilder builder, @NotNull MatchyData d
new PatternSerializer.Matching(clause.bindCount(), clause.patterns(),
(ps, cb, bindCount) -> {
var resultSeq = AbstractExprializer.fromSeq(cb, Constants.CD_Term, ps.result.ref(), bindCount);
var captureSeq = AbstractExprializer.fromSeq(cb, Constants.CD_Term, captures.ref(), data.capturesSize);
var fullSeq = resultSeq.appendedAll(captureSeq);
var fullSeq = resultSeq.appendedAll(captureExprs);
var returns = serializeTermUnderTele(cb, clause.body(), fullSeq);
cb.returnWith(returns);
})
);

new PatternSerializer(AbstractExprializer.fromSeq(builder, Constants.CD_Term, args.ref(), argc), onFailed, false)
new PatternSerializer(argExprs, onFailed, false)
.serialize(builder, matching);
}

/**
* @see JitMatchy#invoke(Seq, Seq)
*/
private void buildInvoke(
@NotNull FreeCodeBuilder builder, @NotNull MatchyData data,
@NotNull LocalVariable captures, @NotNull LocalVariable args
) {
var capturec = data.capturesSize;
int argc = data.argsSize;
var invokeRef = resolveInvoke(NameSerializer.getClassDesc(data.matchy), capturec, argc);
var invokeExpr = builder.invoke(invokeRef, builder.thisRef(),
AbstractExprializer.fromSeq(builder, Constants.CD_Term, captures.ref(), capturec)
.appendedAll(AbstractExprializer.fromSeq(builder, Constants.CD_Term, args.ref(), argc))
);

builder.returnWith(invokeExpr);
}

/** @see JitMatchy#type */
private void buildType(@NotNull FreeCodeBuilder builder, @NotNull MatchyData data, @NotNull LocalVariable captures, @NotNull LocalVariable args) {
var captureSeq = AbstractExprializer.fromSeq(builder, Constants.CD_Term, captures.ref(), data.capturesSize);
Expand All @@ -84,6 +122,16 @@ private void buildType(@NotNull FreeCodeBuilder builder, @NotNull MatchyData dat
@Override public @NotNull ClassTargetSerializer<MatchyData>
serialize(@NotNull FreeClassBuilder builder0, MatchyData unit) {
buildFramework(builder0, unit, builder -> {
var capturec = unit.capturesSize;
var argc = unit.argsSize;

builder.buildMethod(Constants.CD_Term, "invoke", ImmutableSeq.fill(capturec + argc, Constants.CD_Term),
(ap, cb) -> {
var captures = ImmutableSeq.fill(capturec, ap::arg);
var args = ImmutableSeq.fill(argc, i -> ap.arg(i + capturec));
buildInvoke(cb, unit, captures, args);
});

builder.buildMethod(Constants.CD_Term, "invoke", ImmutableSeq.of(
Constants.CD_Seq, Constants.CD_Seq
), (ap, cb) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,27 @@ private TermExprializer(
}

private @NotNull FreeJavaExpr
buildFnInvoke(@NotNull ClassDesc defClass, int ulift, @NotNull ImmutableSeq<Term> args) {
var argsExpr = args.map(this::doSerialize);
buildFnInvoke(@NotNull ClassDesc defClass, int ulift, @NotNull ImmutableSeq<FreeJavaExpr> args) {
var invokeExpr = builder.invoke(
FnSerializer.resolveInvoke(defClass, args.size()),
getInstance(builder, defClass),
argsExpr
);
FnSerializer.resolveInvoke(defClass, args.size()), getInstance(builder, defClass), args);

if (ulift != 0) {
assert ulift > 0;
invokeExpr = builder.invoke(Constants.ELEVATE, invokeExpr, ImmutableSeq.of(builder.iconst(ulift)));
}
return builder.invoke(Constants.ELEVATE, invokeExpr, ImmutableSeq.of(builder.iconst(ulift)));
} else return invokeExpr;
}

return invokeExpr;
// There is a chance I need to add lifting to match, so keep a function for us to
// add the if-else in it
private @NotNull FreeJavaExpr buildMatchyInvoke(
@NotNull ClassDesc matchyClass,
@NotNull ImmutableSeq<FreeJavaExpr> args,
@NotNull ImmutableSeq<FreeJavaExpr> captures
) {
return builder.invoke(
MatchySerializer.resolveInvoke(matchyClass, captures.size(), args.size()),
getInstance(builder, matchyClass),
captures.appendedAll(args)
);
}

@Override protected @NotNull FreeJavaExpr doSerialize(@NotNull Term term) {
Expand Down Expand Up @@ -164,7 +171,9 @@ case ConCall(var head, var args) -> builder.mkNew(ConCall.class, ImmutableSeq.of
builder.iconst(head.ulift()),
serializeToImmutableSeq(Term.class, args)
));
case FnCall(var ref, var ulift, var args) -> buildFnInvoke(NameSerializer.getClassDesc(ref), ulift, args);
case FnCall(var ref, var ulift, var args) -> buildFnInvoke(
NameSerializer.getClassDesc(ref), ulift,
args.map(this::doSerialize));
case RuleReducer.Con(var rule, int ulift, var ownerArgs, var conArgs) -> {
var onStuck = builder.mkNew(RuleReducer.Con.class, ImmutableSeq.of(
serializeApplicable(rule),
Expand All @@ -174,7 +183,7 @@ case ConCall(var head, var args) -> builder.mkNew(ConCall.class, ImmutableSeq.of
));
yield builder.invoke(Constants.RULEREDUCER_MAKE, onStuck, ImmutableSeq.empty());
}
case RuleReducer.Fn (var rule, int ulift, var args) -> {
case RuleReducer.Fn(var rule, int ulift, var args) -> {
var onStuck = builder.mkNew(RuleReducer.Fn.class, ImmutableSeq.of(
serializeApplicable(rule),
builder.iconst(ulift),
Expand Down Expand Up @@ -244,11 +253,8 @@ case ClassCastTerm(var classRef, var subterm, var rember, var forgor) ->
));
case MatchCall(var ref, var args, var captures) -> {
if (ref instanceof Matchy matchy) recorder.addMatchy(matchy, args.size(), captures.size());
yield builder.mkNew(MatchCall.class, ImmutableSeq.of(
getInstance(builder, NameSerializer.getClassDesc(ref)),
serializeToImmutableSeq(Term.class, args),
serializeToImmutableSeq(Term.class, captures)
));
yield buildMatchyInvoke(NameSerializer.getClassDesc(ref),
args.map(this::doSerialize), captures.map(this::doSerialize));
}
case NewTerm(var classCall) -> builder.mkNew(NewTerm.class, ImmutableSeq.of(doSerialize(classCall)));
};
Expand Down
9 changes: 4 additions & 5 deletions jit-compiler/src/test/resources/TreeSort.aya
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,12 @@ def balanceRight Color (RBTree A) A (RBTree A) : RBTree A
rbNode red (rbNode black l v b) y (rbNode black c z d)
| c, l, v, b => rbNode c l v b

def insert_lemma (dec_le : Decider A) (a a1 : A) (c : Color) (l1 l2 : RBTree A) (b : Bool) : RBTree A elim b
| True => balanceRight c l1 a1 (insert a l2 dec_le)
| False => balanceLeft c (insert a l1 dec_le) a1 l2

def insert (a : A) (node : RBTree A) (dec_le : Decider A) : RBTree A elim node
| rbLeaf => rbNode red rbLeaf a rbLeaf
| rbNode c l1 a1 l2 => insert_lemma dec_le a a1 c l1 l2 (dec_le a1 a)
| rbNode c l1 a1 l2 => match dec_le a1 a
{ True => balanceRight c l1 a1 (insert a l2 dec_le)
| False => balanceLeft c (insert a l1 dec_le) a1 l2
}

private def aux (ls : List A) (r : RBTree A) (dec_le : Decider A) : RBTree A elim ls
| [] => r
Expand Down
7 changes: 6 additions & 1 deletion syntax/src/main/java/org/aya/prettier/BasePrettier.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.aya.pretty.doc.Link;
import org.aya.pretty.doc.Style;
import org.aya.pretty.style.AyaStyleKey;
import org.aya.syntax.compile.JitCon;
import org.aya.syntax.compile.JitDef;
import org.aya.syntax.concrete.stmt.QualifiedID;
import org.aya.syntax.concrete.stmt.decl.*;
Expand Down Expand Up @@ -109,7 +110,11 @@ private static BooleanSeq computeLicitFromDef(@NotNull AnyDef var, int size) {
: inner.ref.signature == null
? BooleanSeq.fill(size, true)
: inner.ref.signature.params().mapToBooleanTo(MutableBooleanList.create(), Param::explicit);
case JitTele jit -> MutableBooleanList.from(jit.telescopeLicit);
case JitTele jit -> {
var rawLicit = MutableBooleanList.from(jit.telescopeLicit);
if (jit instanceof JitCon con) yield rawLicit.takeLast(con.selfTeleSize());
yield rawLicit;
}
default -> Panic.unreachable();
};
}
Expand Down
4 changes: 1 addition & 3 deletions syntax/src/main/java/org/aya/syntax/compile/JitMatchy.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@
import org.aya.syntax.core.def.MatchyLike;
import org.aya.syntax.core.term.Term;
import org.aya.syntax.ref.QName;
import org.aya.util.error.Panic;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public non-sealed abstract class JitMatchy extends JitUnit implements MatchyLike {
protected JitMatchy() {
super();
}

/** @return null if stuck */
public abstract @Nullable Term invoke(
public abstract @NotNull Term invoke(
@NotNull Seq<@NotNull Term> captures,
@NotNull Seq<@NotNull Term> args
);
Expand Down
2 changes: 1 addition & 1 deletion syntax/src/main/java/org/aya/syntax/compile/JitUnit.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public abstract class JitUnit {
return metadata;
}

public @NotNull String name() { return metadata.name(); }
public @NotNull String name() { return metadata().name(); }
public @NotNull ModulePath module() { return new ModulePath(rawModule()); }

private @NotNull ImmutableArray<String> rawModule() {
Expand Down
2 changes: 1 addition & 1 deletion syntax/src/main/java/org/aya/syntax/core/pat/Pat.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
*
* @author kiva, ice1000, HoshinoTented
*/
@Debug.Renderer(text = "PatToTerm.visit(this).debuggerOnlyToString()")
@Debug.Renderer(text = "PatToTerm.visit(this).easyToString()")
public sealed interface Pat {
default @NotNull Pat descentPat(@NotNull UnaryOperator<Pat> op) { return this; }
default @NotNull Pat descentTerm(@NotNull UnaryOperator<Term> op) { return this; }
Expand Down

0 comments on commit 6d061c9

Please sign in to comment.