[MOREL-53] Optimize core language by inlining expressions
The paper "Secrets of the Glasgow Haskell Compiler inliner"
(Peyton Jones and Marlow, 1999, revised 2002) describes the
approach.

Given the query

let
  val emp = scott.emps
in
  from e in emp
    yield e.deptno
end

we need the compiler (in particular the Calcite compiler) to
know that emp is always equivalent to scott.emps and
therefore can be translated to a Calcite TableScan. Without
inlining, to be safe, we would have to generate a Calcite
plan involving a TableFunctionScan (i.e. indirecting at
runtime rather than compile time) and that seriously limits
query optimization opportunities on the Calcite side.

Inlining also has some other nice effects, such as:

* "let val f = fn x => x + 1 in f 3 end"
  becomes
  "3 + 1"
* "let val x = 3 in isOdd x end"
  becomes
  "isOdd 3"
* "(fn x => x + 1) 5"
  becomes
  "let val x = 5 in x + 1 end" (beta reduction)
* "let val x = 1 and y = 2 in y + 3 end"
  becomes
  "let val y = 2 in y + 3 end" (remove dead declarations)
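
The let-inlining and dead-declaration effects listed above can be sketched
on a toy expression tree. This is a minimal illustration, not Morel's
Inliner: the classes InlineSketch, Lit, Id, Plus and Let are hypothetical,
the toy language has no functions (so no beta reduction) and no side
effects, and it assumes every variable name is unique.

```java
/** Toy inliner; hypothetical classes, not Morel's Core or Inliner. */
final class InlineSketch {
  interface Exp {}

  static final class Lit implements Exp {
    final int value;
    Lit(int value) { this.value = value; }
    @Override public String toString() { return Integer.toString(value); }
  }

  static final class Id implements Exp {
    final String name;
    Id(String name) { this.name = name; }
    @Override public String toString() { return name; }
  }

  static final class Plus implements Exp {
    final Exp left;
    final Exp right;
    Plus(Exp left, Exp right) { this.left = left; this.right = right; }
    @Override public String toString() { return left + " + " + right; }
  }

  static final class Let implements Exp {
    final String name;
    final Exp value;
    final Exp body;
    Let(String name, Exp value, Exp body) {
      this.name = name;
      this.value = value;
      this.body = body;
    }
    @Override public String toString() {
      return "let val " + name + " = " + value + " in " + body + " end";
    }
  }

  /** Counts references to a variable; relies on names being unique. */
  static int count(String name, Exp e) {
    if (e instanceof Id) {
      return ((Id) e).name.equals(name) ? 1 : 0;
    }
    if (e instanceof Plus) {
      return count(name, ((Plus) e).left) + count(name, ((Plus) e).right);
    }
    if (e instanceof Let) {
      return count(name, ((Let) e).value) + count(name, ((Let) e).body);
    }
    return 0; // literal
  }

  /** Replaces every reference to a variable with the given expression. */
  static Exp subst(String name, Exp value, Exp e) {
    if (e instanceof Id) {
      return ((Id) e).name.equals(name) ? value : e;
    }
    if (e instanceof Plus) {
      Plus p = (Plus) e;
      return new Plus(subst(name, value, p.left), subst(name, value, p.right));
    }
    if (e instanceof Let) {
      Let l = (Let) e;
      return new Let(l.name, subst(name, value, l.value),
          subst(name, value, l.body));
    }
    return e; // literal
  }

  /** Inlines let-bound variables, bottom-up. */
  static Exp inline(Exp e) {
    if (e instanceof Plus) {
      Plus p = (Plus) e;
      return new Plus(inline(p.left), inline(p.right));
    }
    if (e instanceof Let) {
      Let let = (Let) e;
      Exp value = inline(let.value);
      Exp body = inline(let.body);
      int n = count(let.name, body);
      if (n == 0) {
        return body; // dead declaration; safe because nothing has side effects
      }
      boolean atomic = value instanceof Lit || value instanceof Id;
      if (n == 1 || atomic) {
        return subst(let.name, value, body); // no work or code is duplicated
      }
      return new Let(let.name, value, body); // used several times: keep it
    }
    return e;
  }

  public static void main(String[] args) {
    Exp e = new Let("x", new Lit(3), new Plus(new Id("x"), new Lit(1)));
    System.out.println(inline(e)); // prints "3 + 1"
  }
}
```

A variable that is used twice and bound to a non-trivial expression, as in
"let val x = 1 + 2 in x + x end", is left alone by this sketch, because
substituting it would evaluate "1 + 2" twice.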

Inlining is implemented in class Inliner, which is a shuttle
that makes multiple passes over the expression tree. It also
converts references to built-in functions (say
"String.length" or "#length String") into function literals
which can be more easily matched by subsequent optimization
rules.

The paper has guidelines for when inlining can be done safely
(without causing code size or runtime to increase). We
implement those guidelines in class Analyzer.
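
As a rough illustration of the kind of analysis such guidelines rely on,
the sketch below counts how each let-bound variable is used and flags a
single use that sits inside a function body, where inlining could repeat
work on every call. The class OccurrenceSketch, its toy expression nodes,
and the Use categories are hypothetical simplifications; they are not
taken from Morel's Analyzer, and the sketch assumes unique variable names.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/** Toy occurrence analysis; hypothetical, not Morel's Analyzer. */
final class OccurrenceSketch {
  /** How a let-bound variable is used (simplified categories). */
  enum Use { DEAD, ONCE_SAFE, ONCE_UNSAFE, MULTI }

  interface Exp {}

  static final class Lit implements Exp {
    final int value;
    Lit(int value) { this.value = value; }
  }

  static final class Id implements Exp {
    final String name;
    Id(String name) { this.name = name; }
  }

  static final class Plus implements Exp {
    final Exp left;
    final Exp right;
    Plus(Exp left, Exp right) { this.left = left; this.right = right; }
  }

  static final class Fn implements Exp {
    final String param;
    final Exp body;
    Fn(String param, Exp body) { this.param = param; this.body = body; }
  }

  static final class Let implements Exp {
    final String name;
    final Exp value;
    final Exp body;
    Let(String name, Exp value, Exp body) {
      this.name = name;
      this.value = value;
      this.body = body;
    }
  }

  private final Map<String, Use> uses = new LinkedHashMap<>();
  private final Map<String, Integer> bindingDepth = new HashMap<>();

  /** Returns the use of each let-bound variable, in declaration order. */
  static Map<String, Use> analyze(Exp e) {
    OccurrenceSketch analyzer = new OccurrenceSketch();
    analyzer.walk(e, 0);
    return analyzer.uses;
  }

  /** Walks the tree; fnDepth is the number of enclosing functions. */
  private void walk(Exp e, int fnDepth) {
    if (e instanceof Let) {
      Let let = (Let) e;
      walk(let.value, fnDepth);
      uses.put(let.name, Use.DEAD); // no references seen yet
      bindingDepth.put(let.name, fnDepth);
      walk(let.body, fnDepth);
    } else if (e instanceof Fn) {
      walk(((Fn) e).body, fnDepth + 1);
    } else if (e instanceof Plus) {
      walk(((Plus) e).left, fnDepth);
      walk(((Plus) e).right, fnDepth);
    } else if (e instanceof Id) {
      String name = ((Id) e).name;
      Use previous = uses.get(name);
      if (previous == null) {
        return; // not let-bound, for example a function parameter
      }
      if (previous != Use.DEAD) {
        uses.put(name, Use.MULTI);
      } else if (fnDepth > bindingDepth.get(name)) {
        uses.put(name, Use.ONCE_UNSAFE); // inside a function created later
      } else {
        uses.put(name, Use.ONCE_SAFE);
      }
    }
  }

  public static void main(String[] args) {
    // let val b = 2 + 3 in let val f = fn p => b + p in 0 end end
    Exp e = new Let("b", new Plus(new Lit(2), new Lit(3)),
        new Let("f", new Fn("p", new Plus(new Id("b"), new Id("p"))),
            new Lit(0)));
    System.out.println(analyze(e)); // prints "{b=ONCE_UNSAFE, f=DEAD}"
  }
}
```

An inliner built on such a result could drop DEAD declarations, substitute
ONCE_SAFE ones unconditionally, and treat the other categories with more
caution.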

Refactorings

1. Add class EnvVisitor, and improve Visitor, Shuttle,
EnvShuttle.

2. Simplify Core.Fn. Whereas Ast.Fn has a list of matches,
each with a potentially complex pattern, Core.Fn now has
just one IdPat and one Exp rather than a match-list. If the
function has alternate branches or a complex pattern,
those become a Core.Case.

3. Split out Core.Local from Core.Let.
Let has a ValDecl; Local has a DataType.

Local has much less effect on inlining than Let, so it is
cleaner to separate it. (We still don't support 'local' in
the Morel parser.)

Why does Core.Local have a DataType, whereas Ast.Let contains
a DatatypeDecl (and therefore several DataType declarations)?
Because ML datatypes are not recursive; therefore, unlike 'val',
we never need a simultaneous 'local'.

DatatypeDecl now only occurs at top-level; only top-level
programs need to declare multiple DataTypes simultaneously.

4. Change the type of Core.Let.pat from Pat to IdPat.
After this change, Let patterns are always simple, which
makes transformations such as inlining easier. Complex
patterns, such as

  let val (x, y) = (1, 2) in f (x + y) end

are now represented using a single-branch 'case':

  let val v = (1, 2) in case v of (x, y) => f (x + y) end

Patterns on lists and datatype constructors are handled
similarly. Note that we have introduced an intermediate
variable, 'v'.

Core.Case is now the only element of the Core language that
can deconstruct (pattern-match), via its sub-element
Core.Match. Contrast with the Ast, where there are several
places, such as 'fun', 'fn', 'let' (Ast.FunDecl, Ast.Fn,
Ast.Let) in addition to 'case' (Ast.Case).

5. Simplify Resolver.toCore(Ast.Let) and .toCore(Ast.ValDecl).
These are complex because there may be several intermingled
'let' and 'local' declarations, and we don't know whether we
want an Exp or DatatypeDecl at the end of it all. To help,
introduce class ResolvedDecl as an intermediate data structure.

6. Each variable reference (Core.Id) used to reference a
variable declaration by name (String) but now contains the
variable declaration (Core.IdPat) explicitly. Variable
declarations are uniquely identified by a name plus an
ordinal, so that all variables in the program are unique,
and shadowing of variables doesn't occur when an expression
is inlined in a different scope.

7. Resolver now maintains an environment. This is necessary
to generate those unique variable names.

8. Use Util.skip(List) in a few places.

9. In class Binding, replace the name and type fields with
an IdPat; add exp field (which we use for inlining).

10. When translating 'fn () => E' to core, we don't need a 'case'.
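
The name-plus-ordinal scheme from items 6 and 7 can be sketched as follows.
This is a minimal illustration under assumptions: NameGeneratorSketch and
its fresh method are hypothetical and say nothing about how Resolver
actually allocates ordinals. The IdPat here is a cut-down stand-in for
Core.IdPat (no type), and its toString follows the "v", "v#1" convention
of AstWriter.id(String, int) in this commit.

```java
import java.util.HashMap;
import java.util.Map;

/** Hands out unique (name, ordinal) pairs; hypothetical helper. */
final class NameGeneratorSketch {
  /** Cut-down stand-in for Core.IdPat: a name plus an ordinal. */
  static final class IdPat {
    final String name;
    final int i;

    IdPat(String name, int i) {
      this.name = name;
      this.i = i;
    }

    @Override public int hashCode() {
      return name.hashCode() + i;
    }

    @Override public boolean equals(Object o) {
      return o == this
          || o instanceof IdPat
          && ((IdPat) o).name.equals(name)
          && ((IdPat) o).i == i;
    }

    /** Prints "v" for ordinal 0, "v#1" for ordinal 1, and so forth. */
    @Override public String toString() {
      return i == 0 ? name : name + "#" + i;
    }
  }

  /** Number of declarations issued so far for each name. */
  private final Map<String, Integer> issued = new HashMap<>();

  /** Returns a declaration whose (name, ordinal) pair was never issued. */
  IdPat fresh(String name) {
    int ordinal = issued.merge(name, 1, Integer::sum) - 1;
    return new IdPat(name, ordinal);
  }

  public static void main(String[] args) {
    NameGeneratorSketch generator = new NameGeneratorSketch();
    IdPat outer = generator.fresh("x");
    IdPat inner = generator.fresh("x");
    System.out.println(outer + " " + inner); // prints "x x#1"
  }
}
```

Because the two declarations of x compare unequal, an expression that
refers to the outer x can be moved into the scope of the inner x without
its reference being captured.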
julianhyde committed Jun 25, 2021
1 parent c1a890b commit ab22295
Showing 35 changed files with 1,709 additions and 383 deletions.
14 changes: 13 additions & 1 deletion src/main/java/net/hydromatic/morel/ast/AstWriter.java
@@ -48,6 +48,18 @@ public AstWriter id(String s) {
return this;
}

/** Appends an ordinal-qualified identifier to the output.
*
* <p>Prints "v" for {@code id("v", 0)}, "v#1" for {@code id("v", 1)},
* and so forth. */
public AstWriter id(String s, int i) {
b.append(s);
if (i > 0) {
b.append('#').append(i);
}
return this;
}

/** Appends a call to an infix operator. */
public AstWriter infix(int left, AstNode a0, Op op, AstNode a1, int right) {
if (op == Op.APPLY && a0.op == Op.ID) {
@@ -65,7 +77,7 @@ public AstWriter infix(int left, AstNode a0, Op op, AstNode a1, int right) {
// be a function literal, and we would use a reverse mapping to
// figure out which built-in operator it implements, and whether it
// is infix (e.g. "+") or in a namespace (e.g. "#translate String")
final Op op2 = Op.BY_OP_NAME.get(((Core.Id) a0).name);
final Op op2 = Op.BY_OP_NAME.get(((Core.Id) a0).idPat);
if (op2 != null && op2.left > 0) {
final List<Core.Exp> args = ((Core.Tuple) a1).args;
return infix(left, args.get(0), op2, args.get(1), right);
156 changes: 111 additions & 45 deletions src/main/java/net/hydromatic/morel/ast/Core.java
@@ -43,8 +43,8 @@
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.ObjIntConsumer;
import javax.annotation.Nullable;

@@ -103,27 +103,58 @@ public Type type() {
}

/** Named pattern.
*
* <p>Implements {@link Comparable} so that names are sorted correctly
* for record fields (see {@link RecordType#ORDERING}).
*
* @see Ast.Id */
public static class IdPat extends Pat {
public static class IdPat extends Pat implements Comparable<IdPat> {
public final String name;
public final int i;

IdPat(Type type, String name) {
IdPat(Type type, String name, int i) {
super(Op.ID_PAT, type);
this.name = name;
this.i = i;
}

@Override public int hashCode() {
return name.hashCode() + i;
}

@Override public boolean equals(Object obj) {
return obj == this
|| obj instanceof IdPat
&& ((IdPat) obj).name.equals(name)
&& ((IdPat) obj).i == i;
}

/** {@inheritDoc}
*
* <p>Collate first on name, then on ordinal. */
@Override public int compareTo(IdPat o) {
final int c = RecordType.compareNames(name, o.name);
if (c != 0) {
return c;
}
return Integer.compare(i, o.i);
}

@Override AstWriter unparse(AstWriter w, int left, int right) {
return w.id(name);
return w.id(name, i);
}

@Override public Pat accept(Shuttle shuttle) {
@Override public IdPat accept(Shuttle shuttle) {
return shuttle.visit(this);
}

@Override public void accept(Visitor visitor) {
visitor.visit(this);
}

public IdPat withType(Type type) {
return type == this.type ? this : new IdPat(type, name, i);
}
}

/** Literal pattern, the pattern analog of the {@link Literal} expression.
@@ -406,22 +437,22 @@ public Type type() {
* Core; for example, compare {@link Ast.Con0Pat#tyCon}
* with {@link Con0Pat#tyCon}. */
public static class Id extends Exp {
public final String name;
public final IdPat idPat;

/** Creates an Id. */
Id(String name, Type type) {
super(Op.ID, type);
this.name = requireNonNull(name);
Id(IdPat idPat) {
super(Op.ID, idPat.type);
this.idPat = requireNonNull(idPat);
}

@Override public int hashCode() {
return name.hashCode();
return idPat.hashCode();
}

@Override public boolean equals(Object o) {
return o == this
|| o instanceof Id
&& this.name.equals(((Id) o).name);
&& this.idPat.equals(((Id) o).idPat);
}

@Override public Exp accept(Shuttle shuttle) {
@@ -433,7 +464,7 @@ public static class Id extends Exp {
}

@Override AstWriter unparse(AstWriter w, int left, int right) {
return w.id(name);
return w.id(idPat.name, idPat.i);
}
}

@@ -565,7 +596,7 @@ public static class DatatypeDecl extends Decl {
return w;
}

@Override public Decl accept(Shuttle shuttle) {
@Override public DatatypeDecl accept(Shuttle shuttle) {
return shuttle.visit(this);
}

@@ -577,10 +608,10 @@ public static class DatatypeDecl extends Decl {
/** Value declaration. */
public static class ValDecl extends Decl {
public final boolean rec;
public final Pat pat;
public final IdPat pat;
public final Exp exp;

ValDecl(boolean rec, Pat pat, Exp exp) {
ValDecl(boolean rec, IdPat pat, Exp exp) {
super(Op.VAL_DECL);
this.rec = rec;
this.pat = pat;
@@ -612,7 +643,7 @@ public static class ValDecl extends Decl {
visitor.visit(this);
}

public ValDecl copy(boolean rec, Pat pat, Exp exp) {
public ValDecl copy(boolean rec, IdPat pat, Exp exp) {
return rec == this.rec && pat == this.pat && exp == this.exp ? this
: core.valDecl(rec, pat, exp);
}
@@ -658,10 +689,10 @@ public Tuple copy(TypeSystem typeSystem, List<Exp> args) {

/** "Let" expression. */
public static class Let extends Exp {
public final Decl decl;
public final ValDecl decl;
public final Exp exp;

Let(Decl decl, Exp exp) {
Let(ValDecl decl, Exp exp) {
super(Op.LET, exp.type);
this.decl = requireNonNull(decl);
this.exp = requireNonNull(exp);
@@ -681,13 +712,50 @@ public static class Let extends Exp {
visitor.visit(this);
}

public Exp copy(Decl decl, Exp exp) {
public Exp copy(ValDecl decl, Exp exp) {
return decl == this.decl && exp == this.exp ? this
: core.let(decl, exp);
}
}

/** Match. */
/** "Local" expression. */
public static class Local extends Exp {
public final DataType dataType;
public final Exp exp;

Local(DataType dataType, Exp exp) {
super(Op.LOCAL, exp.type);
this.dataType = requireNonNull(dataType);
this.exp = requireNonNull(exp);
}

@Override AstWriter unparse(AstWriter w, int left, int right) {
return w.append("local datatype ").append(dataType.toString())
.append(" in ").append(exp, 0, 0)
.append(" end");
}

@Override public Exp accept(Shuttle shuttle) {
return shuttle.visit(this);
}

@Override public void accept(Visitor visitor) {
visitor.visit(this);
}

public Exp copy(DataType dataType, Exp exp) {
return dataType == this.dataType && exp == this.exp ? this
: core.local(dataType, exp);
}
}

/** Match.
*
* <p>In AST, there are several places that can deconstruct values via
* patterns: {@link Ast.FunDecl fun}, {@link Ast.Fn fn}, {@link Ast.Let let},
* {@link Ast.Case case}. But in Core, there is only {@code Match}, and
* {@code Match} only occurs within {@link Case case}. This makes the Core
* language a little more verbose than AST but a lot more uniform. */
public static class Match extends BaseNode {
public final Pat pat;
public final Exp exp;
@@ -718,20 +786,22 @@ public Match copy(Pat pat, Exp exp) {

/** Lambda expression. */
public static class Fn extends Exp {
public final List<Match> matchList;
public final IdPat idPat;
public final Exp exp;

Fn(FnType type, ImmutableList<Match> matchList) {
Fn(FnType type, IdPat idPat, Exp exp) {
super(Op.FN, type);
this.matchList = requireNonNull(matchList);
checkArgument(!matchList.isEmpty());
this.idPat = requireNonNull(idPat);
this.exp = requireNonNull(exp);
}

@Override public FnType type() {
return (FnType) type;
}

@Override AstWriter unparse(AstWriter w, int left, int right) {
return w.append("fn ").appendAll(matchList, 0, Op.BAR, right);
return w.append("fn ")
.append(idPat, 0, 0).append(" => ").append(exp, 0, right);
}

@Override public Exp accept(Shuttle shuttle) {
@@ -742,9 +812,9 @@ public static class Fn extends Exp {
visitor.visit(this);
}

public Exp copy(List<Match> matchList) {
return matchList.equals(this.matchList) ? this
: core.fn(type(), matchList);
public Fn copy(IdPat idPat, Exp exp) {
return idPat == this.idPat && exp == this.exp ? this
: core.fn(type(), idPat, exp);
}
}

@@ -850,8 +920,7 @@ public abstract static class FromStep extends BaseNode {
* <p>By default, a step outputs the same fields as it inputs.
*/
public void deriveOutBindings(Iterable<Binding> inBindings,
BiFunction<String, Type, Binding> binder,
Consumer<Binding> outBindings) {
Function<IdPat, Binding> binder, Consumer<Binding> outBindings) {
inBindings.forEach(outBindings);
}

@@ -944,11 +1013,11 @@ public OrderItem copy(Exp exp, Ast.Direction direction) {

/** A {@code group} clause in a {@code from} expression. */
public static class Group extends FromStep {
public final SortedMap<String, Exp> groupExps;
public final SortedMap<String, Aggregate> aggregates;
public final SortedMap<Core.IdPat, Exp> groupExps;
public final SortedMap<Core.IdPat, Aggregate> aggregates;

Group(ImmutableSortedMap<String, Exp> groupExps,
ImmutableSortedMap<String, Aggregate> aggregates) {
Group(ImmutableSortedMap<Core.IdPat, Exp> groupExps,
ImmutableSortedMap<Core.IdPat, Aggregate> aggregates) {
super(Op.GROUP);
this.groupExps = groupExps;
this.aggregates = aggregates;
@@ -965,24 +1034,21 @@ public static class Group extends FromStep {
@Override AstWriter unparse(AstWriter w, int left, int right) {
Pair.forEachIndexed(groupExps, (i, id, exp) ->
w.append(i == 0 ? "group " : ", ")
.id(id).append(" = ").append(exp, 0, 0));
.append(id, 0, 0).append(" = ").append(exp, 0, 0));
Pair.forEachIndexed(aggregates, (i, name, aggregate) ->
w.append(i == 0 ? " compute " : ", ")
.id(name).append(" = ").append(aggregate, 0, 0));
.append(name, 0, 0).append(" = ").append(aggregate, 0, 0));
return w;
}

@Override public void deriveOutBindings(Iterable<Binding> inBindings,
BiFunction<String, Type, Binding> binder,
Consumer<Binding> outBindings) {
groupExps.forEach((id, exp) ->
outBindings.accept(binder.apply(id, exp.type)));
aggregates.forEach((id, aggregate) ->
outBindings.accept(binder.apply(id, aggregate.type)));
Function<Core.IdPat, Binding> binder, Consumer<Binding> outBindings) {
groupExps.keySet().forEach(id -> outBindings.accept(binder.apply(id)));
aggregates.keySet().forEach(id -> outBindings.accept(binder.apply(id)));
}

public Group copy(Map<String, Exp> groupExps,
Map<String, Aggregate> aggregates) {
public Group copy(SortedMap<Core.IdPat, Exp> groupExps,
SortedMap<Core.IdPat, Aggregate> aggregates) {
return groupExps.equals(this.groupExps)
&& aggregates.equals(this.aggregates)
? this
@@ -1101,7 +1167,7 @@ private static boolean isValidValue(Exp exp, Object o) {
return false;
}
if (o instanceof Id) {
final String name = ((Id) exp).name;
final String name = ((Id) exp).idPat.name;
return !("true".equals(name) || "false".equals(name));
}
return true;