Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JIT of matchy calls #1230

Merged
merged 7 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions base/src/main/java/org/aya/normalize/Normalizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@
*/
@SuppressWarnings("UnnecessaryContinue") @Override public Term apply(Term term) {
while (true) {
if (term instanceof StableWHNF || term instanceof FreeTerm) return term;
// ConCall for point constructors are always in WHNF
if (term instanceof ConCall con && !con.ref().hasEq()) return con;
var alreadyWHNF = term instanceof StableWHNF ||
term instanceof FreeTerm ||
// ConCall for point constructors are always in WHNF
(term instanceof ConCall con && !con.ref().hasEq());
if (alreadyWHNF && !usePostTerm) return term;
var descentedTerm = term.descent(this);
if (alreadyWHNF && usePostTerm) return descentedTerm;
// descent may change the java type of term, i.e. beta reduce,
// and can also reduce the subterms. We intend to return the reduction
// result when it beta reduces, so keep `descentedTerm` both when in NF mode or
Expand All @@ -75,15 +78,11 @@
}
case FnCall(JitFn instance, int ulift, var args) -> {
var result = instance.invoke(args);
if (result instanceof FnCall resultCall &&
resultCall.ref() == instance &&
resultCall.args().sameElements(args, true)
) {
return defaultValue;
} else {
term = result.elevate(ulift);
continue;
}
if (result instanceof FnCall(var ref, _, var newArgs) &&
ref == instance && newArgs.sameElements(args, true)
) return defaultValue;
term = result.elevate(ulift);
continue;
}
case FnCall(FnDef.Delegate delegate, int ulift, var args) -> {
FnDef core = delegate.core();
Expand Down Expand Up @@ -172,7 +171,10 @@
}
case MatchCall(JitMatchy fn, var discr, var captures) -> {
var result = fn.invoke(captures, discr);
if (result == null) return defaultValue;
if (result instanceof MatchCall(var ref, var newDiscr, var newCaptures) &&
ref == fn && newDiscr.sameElements(discr, true) &&
newCaptures.sameElements(captures, true)
) return defaultValue;

Check warning on line 177 in base/src/main/java/org/aya/normalize/Normalizer.java

View check run for this annotation

Codecov / codecov/patch

base/src/main/java/org/aya/normalize/Normalizer.java#L177

Added line #L177 was not covered by tests
term = result;
continue;
}
Expand Down Expand Up @@ -216,7 +218,7 @@
private class Full implements UnaryOperator<Term> {
{ usePostTerm = true; }

@Override public Term apply(Term term) { return Normalizer.this.apply(term).descent(this); }
@Override public Term apply(Term term) { return Normalizer.this.apply(term); }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import kala.control.Result;
import org.aya.compiler.free.data.FieldRef;
import org.aya.compiler.free.data.MethodRef;
import org.aya.generic.stmt.Shaped;
import org.aya.syntax.compile.JitClass;
import org.aya.syntax.compile.JitCon;
import org.aya.syntax.compile.JitData;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import org.aya.syntax.compile.CompiledAya;
import org.aya.syntax.compile.JitMatchy;
import org.aya.syntax.core.def.Matchy;
import org.aya.syntax.core.def.MatchyLike;
import org.aya.syntax.core.repr.CodeShape;
import org.aya.syntax.core.term.Term;
import org.aya.syntax.core.term.call.MatchCall;
import org.jetbrains.annotations.NotNull;

import java.lang.constant.ClassDesc;
import java.util.function.Consumer;

public class MatchySerializer extends ClassTargetSerializer<MatchySerializer.MatchyData> {
Expand All @@ -37,15 +41,32 @@ public MatchySerializer(ModuleSerializer.@NotNull MatchyRecorder recorder) {
return NameSerializer.javifyClassName(unit.matchy.qualifiedName().module(), unit.matchy.qualifiedName().name());
}

/**
* @see JitMatchy#invoke(Seq, Seq)
*/
private void buildInvoke(@NotNull FreeCodeBuilder builder, @NotNull MatchyData data, @NotNull LocalVariable captures, @NotNull LocalVariable args) {
public static @NotNull MethodRef resolveInvoke(@NotNull ClassDesc owner, int capturec, int argc) {
return new MethodRef.Default(
owner, "invoke",
Constants.CD_Term, ImmutableSeq.fill(capturec + argc, Constants.CD_Term),
false
);
}

private void buildInvoke(
@NotNull FreeCodeBuilder builder, @NotNull MatchyData data,
@NotNull ImmutableSeq<LocalVariable> captures, @NotNull ImmutableSeq<LocalVariable> args
) {
var unit = data.matchy;
int argc = data.argsSize;
Consumer<FreeCodeBuilder> onFailed = b -> b.returnWith(b.aconstNull(Constants.CD_Term));
var captureExprs = captures.map(LocalVariable::ref);
var argExprs = args.map(LocalVariable::ref);

if (argc == 0) {
Consumer<FreeCodeBuilder> onFailed = b -> {
var result = b.mkNew(MatchCall.class, ImmutableSeq.of(
AbstractExprializer.getInstance(b, NameSerializer.getClassDesc(data.matchy)),
AbstractExprializer.makeImmutableSeq(b, Term.class, captureExprs),
AbstractExprializer.makeImmutableSeq(b, Term.class, argExprs)
));
b.returnWith(result);
};

if (args.isEmpty()) {
onFailed.accept(builder);
return;
}
Expand All @@ -54,17 +75,34 @@ private void buildInvoke(@NotNull FreeCodeBuilder builder, @NotNull MatchyData d
new PatternSerializer.Matching(clause.bindCount(), clause.patterns(),
(ps, cb, bindCount) -> {
var resultSeq = AbstractExprializer.fromSeq(cb, Constants.CD_Term, ps.result.ref(), bindCount);
var captureSeq = AbstractExprializer.fromSeq(cb, Constants.CD_Term, captures.ref(), data.capturesSize);
var fullSeq = resultSeq.appendedAll(captureSeq);
var fullSeq = resultSeq.appendedAll(captureExprs);
var returns = serializeTermUnderTele(cb, clause.body(), fullSeq);
cb.returnWith(returns);
})
);

new PatternSerializer(AbstractExprializer.fromSeq(builder, Constants.CD_Term, args.ref(), argc), onFailed, false)
new PatternSerializer(argExprs, onFailed, false)
.serialize(builder, matching);
}

/**
* @see JitMatchy#invoke(Seq, Seq)
*/
private void buildInvoke(
@NotNull FreeCodeBuilder builder, @NotNull MatchyData data,
@NotNull LocalVariable captures, @NotNull LocalVariable args
) {
var capturec = data.capturesSize;
int argc = data.argsSize;
var invokeRef = resolveInvoke(NameSerializer.getClassDesc(data.matchy), capturec, argc);
var invokeExpr = builder.invoke(invokeRef, builder.thisRef(),
AbstractExprializer.fromSeq(builder, Constants.CD_Term, captures.ref(), capturec)
.appendedAll(AbstractExprializer.fromSeq(builder, Constants.CD_Term, args.ref(), argc))
);

builder.returnWith(invokeExpr);
}

/** @see JitMatchy#type */
private void buildType(@NotNull FreeCodeBuilder builder, @NotNull MatchyData data, @NotNull LocalVariable captures, @NotNull LocalVariable args) {
var captureSeq = AbstractExprializer.fromSeq(builder, Constants.CD_Term, captures.ref(), data.capturesSize);
Expand All @@ -84,6 +122,16 @@ private void buildType(@NotNull FreeCodeBuilder builder, @NotNull MatchyData dat
@Override public @NotNull ClassTargetSerializer<MatchyData>
serialize(@NotNull FreeClassBuilder builder0, MatchyData unit) {
buildFramework(builder0, unit, builder -> {
var capturec = unit.capturesSize;
var argc = unit.argsSize;

builder.buildMethod(Constants.CD_Term, "invoke", ImmutableSeq.fill(capturec + argc, Constants.CD_Term),
(ap, cb) -> {
var captures = ImmutableSeq.fill(capturec, ap::arg);
var args = ImmutableSeq.fill(argc, i -> ap.arg(i + capturec));
buildInvoke(cb, unit, captures, args);
});

builder.buildMethod(Constants.CD_Term, "invoke", ImmutableSeq.of(
Constants.CD_Seq, Constants.CD_Seq
), (ap, cb) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,27 @@
}

private @NotNull FreeJavaExpr
buildFnInvoke(@NotNull ClassDesc defClass, int ulift, @NotNull ImmutableSeq<Term> args) {
var argsExpr = args.map(this::doSerialize);
buildFnInvoke(@NotNull ClassDesc defClass, int ulift, @NotNull ImmutableSeq<FreeJavaExpr> args) {
var invokeExpr = builder.invoke(
FnSerializer.resolveInvoke(defClass, args.size()),
getInstance(builder, defClass),
argsExpr
);
FnSerializer.resolveInvoke(defClass, args.size()), getInstance(builder, defClass), args);

if (ulift != 0) {
assert ulift > 0;
invokeExpr = builder.invoke(Constants.ELEVATE, invokeExpr, ImmutableSeq.of(builder.iconst(ulift)));
}
return builder.invoke(Constants.ELEVATE, invokeExpr, ImmutableSeq.of(builder.iconst(ulift)));

Check warning on line 115 in jit-compiler/src/main/java/org/aya/compiler/serializers/TermExprializer.java

View check run for this annotation

Codecov / codecov/patch

jit-compiler/src/main/java/org/aya/compiler/serializers/TermExprializer.java#L115

Added line #L115 was not covered by tests
} else return invokeExpr;
}

return invokeExpr;
// There is a chance I need to add lifting to match, so keep a function for us to
// add the if-else in it
private @NotNull FreeJavaExpr buildMatchyInvoke(
@NotNull ClassDesc matchyClass,
@NotNull ImmutableSeq<FreeJavaExpr> args,
@NotNull ImmutableSeq<FreeJavaExpr> captures
) {
return builder.invoke(
MatchySerializer.resolveInvoke(matchyClass, captures.size(), args.size()),
getInstance(builder, matchyClass),
captures.appendedAll(args)
);
}

@Override protected @NotNull FreeJavaExpr doSerialize(@NotNull Term term) {
Expand Down Expand Up @@ -164,7 +171,9 @@
builder.iconst(head.ulift()),
serializeToImmutableSeq(Term.class, args)
));
case FnCall(var ref, var ulift, var args) -> buildFnInvoke(NameSerializer.getClassDesc(ref), ulift, args);
case FnCall(var ref, var ulift, var args) -> buildFnInvoke(
NameSerializer.getClassDesc(ref), ulift,
args.map(this::doSerialize));
case RuleReducer.Con(var rule, int ulift, var ownerArgs, var conArgs) -> {
var onStuck = builder.mkNew(RuleReducer.Con.class, ImmutableSeq.of(
serializeApplicable(rule),
Expand All @@ -174,7 +183,7 @@
));
yield builder.invoke(Constants.RULEREDUCER_MAKE, onStuck, ImmutableSeq.empty());
}
case RuleReducer.Fn (var rule, int ulift, var args) -> {
case RuleReducer.Fn(var rule, int ulift, var args) -> {
var onStuck = builder.mkNew(RuleReducer.Fn.class, ImmutableSeq.of(
serializeApplicable(rule),
builder.iconst(ulift),
Expand Down Expand Up @@ -244,11 +253,8 @@
));
case MatchCall(var ref, var args, var captures) -> {
if (ref instanceof Matchy matchy) recorder.addMatchy(matchy, args.size(), captures.size());
yield builder.mkNew(MatchCall.class, ImmutableSeq.of(
getInstance(builder, NameSerializer.getClassDesc(ref)),
serializeToImmutableSeq(Term.class, args),
serializeToImmutableSeq(Term.class, captures)
));
yield buildMatchyInvoke(NameSerializer.getClassDesc(ref),
args.map(this::doSerialize), captures.map(this::doSerialize));
}
case NewTerm(var classCall) -> builder.mkNew(NewTerm.class, ImmutableSeq.of(doSerialize(classCall)));
};
Expand Down
9 changes: 4 additions & 5 deletions jit-compiler/src/test/resources/TreeSort.aya
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,12 @@ def balanceRight Color (RBTree A) A (RBTree A) : RBTree A
rbNode red (rbNode black l v b) y (rbNode black c z d)
| c, l, v, b => rbNode c l v b

def insert_lemma (dec_le : Decider A) (a a1 : A) (c : Color) (l1 l2 : RBTree A) (b : Bool) : RBTree A elim b
| True => balanceRight c l1 a1 (insert a l2 dec_le)
| False => balanceLeft c (insert a l1 dec_le) a1 l2

def insert (a : A) (node : RBTree A) (dec_le : Decider A) : RBTree A elim node
| rbLeaf => rbNode red rbLeaf a rbLeaf
| rbNode c l1 a1 l2 => insert_lemma dec_le a a1 c l1 l2 (dec_le a1 a)
| rbNode c l1 a1 l2 => match dec_le a1 a
{ True => balanceRight c l1 a1 (insert a l2 dec_le)
| False => balanceLeft c (insert a l1 dec_le) a1 l2
}

private def aux (ls : List A) (r : RBTree A) (dec_le : Decider A) : RBTree A elim ls
| [] => r
Expand Down
7 changes: 6 additions & 1 deletion syntax/src/main/java/org/aya/prettier/BasePrettier.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.aya.pretty.doc.Link;
import org.aya.pretty.doc.Style;
import org.aya.pretty.style.AyaStyleKey;
import org.aya.syntax.compile.JitCon;
import org.aya.syntax.compile.JitDef;
import org.aya.syntax.concrete.stmt.QualifiedID;
import org.aya.syntax.concrete.stmt.decl.*;
Expand Down Expand Up @@ -109,7 +110,11 @@ private static BooleanSeq computeLicitFromDef(@NotNull AnyDef var, int size) {
: inner.ref.signature == null
? BooleanSeq.fill(size, true)
: inner.ref.signature.params().mapToBooleanTo(MutableBooleanList.create(), Param::explicit);
case JitTele jit -> MutableBooleanList.from(jit.telescopeLicit);
case JitTele jit -> {
var rawLicit = MutableBooleanList.from(jit.telescopeLicit);
if (jit instanceof JitCon con) yield rawLicit.takeLast(con.selfTeleSize());
yield rawLicit;
}
default -> Panic.unreachable();
};
}
Expand Down
4 changes: 1 addition & 3 deletions syntax/src/main/java/org/aya/syntax/compile/JitMatchy.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@
import org.aya.syntax.core.def.MatchyLike;
import org.aya.syntax.core.term.Term;
import org.aya.syntax.ref.QName;
import org.aya.util.error.Panic;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public non-sealed abstract class JitMatchy extends JitUnit implements MatchyLike {
protected JitMatchy() {
super();
}

/** @return null if stuck */
public abstract @Nullable Term invoke(
public abstract @NotNull Term invoke(
@NotNull Seq<@NotNull Term> captures,
@NotNull Seq<@NotNull Term> args
);
Expand Down
2 changes: 1 addition & 1 deletion syntax/src/main/java/org/aya/syntax/compile/JitUnit.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public abstract class JitUnit {
return metadata;
}

public @NotNull String name() { return metadata.name(); }
public @NotNull String name() { return metadata().name(); }
public @NotNull ModulePath module() { return new ModulePath(rawModule()); }

private @NotNull ImmutableArray<String> rawModule() {
Expand Down
2 changes: 1 addition & 1 deletion syntax/src/main/java/org/aya/syntax/core/pat/Pat.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
*
* @author kiva, ice1000, HoshinoTented
*/
@Debug.Renderer(text = "PatToTerm.visit(this).debuggerOnlyToString()")
@Debug.Renderer(text = "PatToTerm.visit(this).easyToString()")
public sealed interface Pat {
default @NotNull Pat descentPat(@NotNull UnaryOperator<Pat> op) { return this; }
default @NotNull Pat descentTerm(@NotNull UnaryOperator<Term> op) { return this; }
Expand Down
Loading