Skip to content

Commit

Permalink
merge: reduction rule (#995)
Browse files Browse the repository at this point in the history
TODO List:
+ [x] ~~FnShape and matcher~~
+ [x] Reduction rule
+ [ ] Fix bugs
+ [ ] Close this PR, and rePR them in several small PRs.
  • Loading branch information
ice1000 authored Dec 13, 2023
2 parents 247390f + 9a364e9 commit cd88eeb
Show file tree
Hide file tree
Showing 45 changed files with 755 additions and 203 deletions.
20 changes: 14 additions & 6 deletions base/src/main/java/org/aya/core/pat/Pat.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.aya.core.visitor.Subst;
import org.aya.generic.AyaDocile;
import org.aya.generic.Shaped;
import org.aya.util.error.InternalException;
import org.aya.prettier.AyaPrettierOptions;
import org.aya.prettier.BasePrettier;
import org.aya.prettier.CorePrettier;
Expand All @@ -29,12 +28,14 @@
import org.aya.tyck.pat.ClauseTycker;
import org.aya.tyck.tycker.ConcreteAwareTycker;
import org.aya.util.Arg;
import org.aya.util.error.InternalException;
import org.aya.util.error.SourcePos;
import org.aya.util.prettier.PrettierOptions;
import org.jetbrains.annotations.Debug;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.function.IntUnaryOperator;
import java.util.function.UnaryOperator;

/**
Expand All @@ -54,7 +55,7 @@ public sealed interface Pat extends AyaDocile {

@NotNull Pat zonk(@NotNull ConcreteAwareTycker tycker);
/**
* Make sure you are inline all patterns in order
* Make sure you inline all patterns in order
*
* @param ctx when null, the solutions will not be inlined
* @return inlined patterns
Expand Down Expand Up @@ -188,10 +189,11 @@ record Tuple(@NotNull ImmutableSeq<Arg<Pat>> pats) implements Pat {
record Ctor(
@NotNull DefVar<CtorDef, TeleDecl.DataCtor> ref,
@NotNull ImmutableSeq<Arg<Pat>> params,
@Nullable ShapeRecognition typeRecog,
@NotNull DataCall type
) implements Pat {
public @NotNull Ctor update(@NotNull ImmutableSeq<Arg<Pat>> params, @NotNull DataCall type) {
return type == type() && params.sameElements(params(), true) ? this : new Ctor(ref, params, type);
return type == type() && params.sameElements(params(), true) ? this : new Ctor(ref, params, typeRecog, type);
}

@Override public @NotNull Ctor descent(@NotNull UnaryOperator<Pat> f, @NotNull UnaryOperator<Term> g) {
Expand All @@ -205,13 +207,14 @@ record Ctor(
@Override public @NotNull Pat zonk(@NotNull ConcreteAwareTycker tycker) {
return new Ctor(ref,
params.map(pat -> pat.descent(x -> x.zonk(tycker))),
typeRecog,
// The cast must succeed
(DataCall) tycker.zonk(type));
}

@Override public @NotNull Pat inline(@Nullable LocalCtx ctx) {
var params = this.params.map(p -> p.descent(x -> x.inline(ctx)));
return new Ctor(ref, params, (DataCall) ClauseTycker.inlineTerm(type));
return new Ctor(ref, params, typeRecog, (DataCall) ClauseTycker.inlineTerm(type));
}
}

Expand Down Expand Up @@ -243,16 +246,21 @@ public ShapedInt update(DataCall type) {
}

@Override public @NotNull Pat makeZero(@NotNull CtorDef zero) {
return new Pat.Ctor(zero.ref, ImmutableSeq.empty(), type);
return new Pat.Ctor(zero.ref, ImmutableSeq.empty(), recognition, type);
}

@Override public @NotNull Pat makeSuc(@NotNull CtorDef suc, @NotNull Arg<Pat> pat) {
return new Pat.Ctor(suc.ref, ImmutableSeq.of(pat), type);
return new Pat.Ctor(suc.ref, ImmutableSeq.of(pat), recognition, type);
}

@Override public @NotNull Pat destruct(int repr) {
return new Pat.ShapedInt(repr, this.recognition, this.type);
}

@Override
public @NotNull ShapedInt map(@NotNull IntUnaryOperator f) {
return new ShapedInt(f.applyAsInt(repr), recognition, type);
}
}

/**
Expand Down
8 changes: 3 additions & 5 deletions base/src/main/java/org/aya/core/pat/PatMatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import org.aya.core.term.*;
import org.aya.core.visitor.EndoTerm;
import org.aya.core.visitor.Subst;
import org.aya.util.error.InternalException;
import org.aya.util.Arg;
import org.aya.util.error.InternalException;
import org.jetbrains.annotations.NotNull;

import java.util.function.UnaryOperator;
Expand Down Expand Up @@ -59,13 +59,11 @@ private void match(@NotNull Pat pat, @NotNull Term term) throws Mismatch {
case Pat.Ctor ctor -> {
term = pre.apply(term);
switch (term) {
case ConCall conCall -> {
case ConCallLike conCall -> {
if (ctor.ref() != conCall.ref()) throw new Mismatch(false);
visitList(ctor.params(), conCall.conArgs());
}
case MetaPatTerm metaPat -> solve(pat, metaPat);
// TODO[literal]: We may convert constructor call to literals to avoid possible stack overflow?
case IntegerTerm litTerm -> match(ctor, litTerm.constructorForm());
case ListTerm litTerm -> match(ctor, litTerm.constructorForm());
default -> throw new Mismatch(true);
}
Expand Down Expand Up @@ -119,7 +117,7 @@ private void solve(@NotNull Pat pat, @NotNull MetaPatTerm metaPat) throws Mismat
// solve as pat
metaPat.ref().solution().set(metalized);
} else {
// a MetaPat that has solution <==> the solution
// a MetaPat that has a solution <==> the solution
match(pat, todo.toTerm());
}
}
Expand Down
13 changes: 12 additions & 1 deletion base/src/main/java/org/aya/core/pat/PatToTerm.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.core.pat;

import kala.collection.immutable.ImmutableSeq;
import org.aya.core.term.*;
import org.aya.ref.LocalVar;
import org.aya.tyck.repr.ShapeFactory;
import org.aya.util.Arg;
import org.jetbrains.annotations.NotNull;

Expand All @@ -30,8 +32,17 @@ public Term visit(@NotNull Pat pat) {
}

protected @NotNull Term visitCtor(Pat.@NotNull Ctor ctor) {
var data = (DataCall) ctor.type();
var data = ctor.type();
var args = Arg.mapSeq(ctor.params(), this::visit);

if (ctor.typeRecog() != null) {
var head = ShapeFactory.ofCtor(ctor.ref(), ctor.typeRecog(), data);
// Not a ShapedCtor, even it's data is a ShapedData
if (head != null) {
return new RuleReducer.Con(head, data.ulift(), ImmutableSeq.empty(), args);
}
}

return new ConCall(data.ref(), ctor.ref(), data.args(), data.ulift(), args);
}
}
3 changes: 1 addition & 2 deletions base/src/main/java/org/aya/core/repr/AyaShape.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import kala.control.Option;
import kala.tuple.Tuple;
import kala.tuple.Tuple2;
import org.aya.core.def.Def;
import org.aya.core.def.GenericDef;
import org.jetbrains.annotations.NotNull;

Expand Down Expand Up @@ -140,7 +139,7 @@ class Factory {
.toImmutableSeq();
}

public @NotNull Option<ShapeRecognition> find(@NotNull Def def) {
public @NotNull Option<ShapeRecognition> find(@NotNull GenericDef def) {
return discovered.getOption(def);
}

Expand Down
10 changes: 7 additions & 3 deletions base/src/main/java/org/aya/core/serde/CompiledAya.java
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,12 @@ private record Serialization(
@NotNull MutableList<SerDef.SerOp> serOps
) {
private void ser(@NotNull ImmutableSeq<GenericDef> defs) {
defs.forEach(this::serDef);
var factory = resolveInfo.shapeFactory();
defs.forEach(x -> serDef(factory, x));
}

private void serDef(@NotNull GenericDef def) {
var serDef = new Serializer(state).serialize(def);
private void serDef(@NotNull AyaShape.Factory factory, @NotNull GenericDef def) {
var serDef = new Serializer(state, factory).serialize(def);
serDefs.append(serDef);
serOp(serDef, def);
switch (serDef) {
Expand Down Expand Up @@ -280,6 +281,9 @@ private void de(@NotNull AyaShape.Factory shapeFactory, @NotNull PhysicalModuleC
var mod = context.modulePath();
var def = serDef.de(state);
assert def.ref().core != null;
if (serDef instanceof SerShapable serShapeDef && serShapeDef.shapeResult() != null) {
shapeFactory.discovered.put(def, serShapeDef.shapeResult().de(state));
}
shapeFactory.bonjour(def);
switch (serDef) {
case SerDef.Fn fn -> {
Expand Down
12 changes: 9 additions & 3 deletions base/src/main/java/org/aya/core/serde/SerDef.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.aya.util.binop.OpDecl;
import org.aya.util.error.InternalException;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.Serializable;
import java.util.EnumSet;
Expand All @@ -41,8 +42,9 @@ record Fn(
@NotNull ImmutableSeq<SerTerm.SerParam> telescope,
@NotNull Either<SerTerm, ImmutableSeq<SerPat.Clause>> body,
@NotNull EnumSet<Modifier> modifiers,
@NotNull SerTerm result
) implements SerDef {
@NotNull SerTerm result,
@Override @Nullable SerShapeResult shapeResult
) implements SerDef, SerShapable {
@Override public @NotNull Def de(SerTerm.@NotNull DeState state) {
return new FnDef(
state.def(name), telescope.map(tele -> tele.de(state)),
Expand Down Expand Up @@ -157,18 +159,22 @@ record SerShapeResult(

/** serialized {@link AyaShape} */
enum SerAyaShape implements Serializable {
NAT, LIST;
NAT, LIST, PLUSL, PLUSR;

public @NotNull AyaShape de() {
return switch (this) {
case NAT -> AyaShape.NAT_SHAPE;
case LIST -> AyaShape.LIST_SHAPE;
case PLUSL -> AyaShape.PLUS_LEFT_SHAPE;
case PLUSR -> AyaShape.PLUS_RIGHT_SHAPE;
};
}

public static @NotNull SerAyaShape serialize(@NotNull AyaShape shape) {
if (shape == AyaShape.NAT_SHAPE) return NAT;
if (shape == AyaShape.LIST_SHAPE) return LIST;
if (shape == AyaShape.PLUS_LEFT_SHAPE) return PLUSL;
if (shape == AyaShape.PLUS_RIGHT_SHAPE) return PLUSR;
throw new InternalException("unexpected shape: " + shape.getClass());
}
}
Expand Down
4 changes: 4 additions & 0 deletions base/src/main/java/org/aya/core/serde/SerPat.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.aya.util.Arg;
import org.aya.util.error.SourcePos;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.Serializable;

Expand Down Expand Up @@ -45,11 +46,14 @@ record Ctor(
boolean explicit,
@NotNull SerDef.QName name,
@NotNull ImmutableSeq<SerPat> params,
@Nullable SerDef.SerShapeResult shapeResult,
@NotNull SerTerm.Data ty
) implements SerPat {
@Override public @NotNull Arg<Pat> de(SerTerm.@NotNull DeState state) {
var shapeRecog = this.shapeResult != null ? this.shapeResult.de(state) : null;
return new Arg<>(new Pat.Ctor(state.resolve(name),
params.map(param -> param.de(state)),
shapeRecog,
ty.de(state)), explicit);
}
}
Expand Down
11 changes: 11 additions & 0 deletions base/src/main/java/org/aya/core/serde/SerShapable.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) 2020-2023 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.core.serde;

import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.Nullable;

public interface SerShapable {
@Contract(pure = true)
@Nullable SerDef.SerShapeResult shapeResult();
}
55 changes: 55 additions & 0 deletions base/src/main/java/org/aya/core/serde/SerTerm.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableHashMap;
import kala.collection.mutable.MutableMap;
import kala.control.Either;
import kala.tuple.Tuple;
import org.aya.concrete.stmt.decl.TeleDecl;
import org.aya.core.def.CtorDef;
import org.aya.core.def.FnDef;
import org.aya.core.def.PrimDef;
import org.aya.core.term.*;
import org.aya.generic.Shaped;
import org.aya.generic.SortKind;
import org.aya.guest0x0.cubical.Formula;
import org.aya.guest0x0.cubical.Partial;
Expand Down Expand Up @@ -163,6 +168,30 @@ record Fn(@NotNull SerDef.QName name, @NotNull CallData data) implements SerTerm
}
}

record FnReduceRule(@NotNull SerTerm.SerShapedApplicable head, @NotNull CallData data) implements SerTerm {
@Override
public @NotNull Term de(@NotNull DeState state) {
return new RuleReducer.Fn(
(Shaped.Applicable<Term, FnDef, TeleDecl.FnDecl>) head.deShape(state),
data.ulift, data.de(state)
);
}
}

record ConReduceRule(
@NotNull SerTerm.SerShapedApplicable head,
@NotNull CallData dataArgs,
@NotNull ImmutableSeq<SerArg> conArgs
) implements SerTerm {
@Override
public @NotNull Term de(@NotNull DeState state) {
return new RuleReducer.Con(
(Shaped.Applicable<Term, CtorDef, TeleDecl.DataCtor>) head.deShape(state),
dataArgs().ulift, dataArgs.de(state), conArgs.map(x -> x.de(state))
);
}
}

record Data(@NotNull SerDef.QName name, @NotNull CallData data) implements SerTerm {
@Override public @NotNull DataCall de(@NotNull DeState state) {
return new DataCall(state.resolve(name), data.ulift, data.de(state));
Expand Down Expand Up @@ -320,4 +349,30 @@ record OutS(@NotNull SerTerm phi, @NotNull SerTerm par, @NotNull SerTerm u) impl
return new OutTerm(phi.de(state), par.de(state), u.de(state));
}
}

/// region ShapedApplicable

sealed interface SerShapedApplicable extends Serializable permits SerIntegerOps {
@NotNull Shaped.Applicable<Term, ?, ?> deShape(@NotNull DeState state);
}

record ConInfo(
SerDef.SerShapeResult result,
SerTerm.Data data
) implements Serializable {}

record SerIntegerOps(
@NotNull SerDef.QName ref,
@NotNull Either<ConInfo, IntegerOps.FnRule.Kind> data
) implements SerShapedApplicable {
@Override
public @NotNull Shaped.Applicable<Term, ?, ?> deShape(@NotNull DeState state) {
return data.fold(
left -> new IntegerOps.ConRule(state.resolve(ref), left.result.de(state), left.data.de(state)),
right -> new IntegerOps.FnRule(state.resolve(ref), right)
);
}
}

/// endregion ShapedApplicable
}
Loading

0 comments on commit cd88eeb

Please sign in to comment.