Skip to content

Commit

Permalink
Deduce type of polymorphic field in tuple or record
Browse files Browse the repository at this point in the history
The problem was that after applying an action during
unification (when we have deduced the type of a record
selector function) we did not check whether the variable had
already been expanded by unification.

Create a lighter weight Substitution class, for use during
the unification algorithm, whose map is not immutable and not
sorted.
  • Loading branch information
julianhyde committed May 3, 2020
1 parent 75da291 commit 93b9a0e
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 36 deletions.
9 changes: 5 additions & 4 deletions src/main/java/net/hydromatic/morel/compile/TypeResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ private static String str(int i) {
private Ast.RecordSelector deduceRecordSelectorType(TypeEnv env,
Unifier.Variable vResult, Unifier.Variable vArg,
Ast.RecordSelector recordSelector) {
actionMap.put(vArg, (v, t, termPairs) -> {
actionMap.put(vArg, (v, t, substitution, termPairs) -> {
// We now know that the type arg, say "{a: int, b: real}".
// So, now we can declare that the type of vResult, say "#b", is
// "real".
Expand All @@ -499,9 +499,10 @@ private Ast.RecordSelector deduceRecordSelectorType(TypeEnv env,
if (fieldList != null) {
int i = fieldList.indexOf(recordSelector.name);
if (i >= 0) {
termPairs.add(
new Unifier.TermTerm(vResult,
sequence.terms.get(i)));
final Unifier.Term result2 = substitution.resolve(vResult);
final Unifier.Term term = sequence.terms.get(i);
final Unifier.Term term2 = substitution.resolve(term);
termPairs.add(new Unifier.TermTerm(result2, term2));
recordSelector.slot = i;
}
}
Expand Down
43 changes: 26 additions & 17 deletions src/main/java/net/hydromatic/morel/util/MartelliUnifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public class MartelliUnifier extends Unifier {
final Map<Variable, Term> result = new LinkedHashMap<>();
for (;;) {
if (termPairs.isEmpty()) {
return new Substitution(result);
return SubstitutionResult.create(result);
}
int i = findDelete(termPairs);
if (i >= 0) {
Expand Down Expand Up @@ -103,42 +103,51 @@ public class MartelliUnifier extends Unifier {
return failure("cycle: variable " + variable + " in " + term);
}
tracer.onVariable(variable, term);
final Map<Variable, Term> map = ImmutableMap.of(variable, term);
result.put(variable, term);
act(variable, term, termPairs, termActions, 0);
for (int j = 0; j < termPairs.size(); j++) {
final TermTerm pair2 = termPairs.get(j);
final Term left2 = pair2.left.apply(map);
final Term right2 = pair2.right.apply(map);
if (left2 != pair2.left
|| right2 != pair2.right) {
tracer.onSubstitute(pair2.left, pair2.right, left2, right2);
termPairs.set(j, new TermTerm(left2, right2));
}
}
act(variable, term, termPairs, new Substitution(result),
termActions, 0);
substituteList(termPairs, tracer, ImmutableMap.of(variable, term));
}
}
}

/** Applies a mapping to all term pairs in a list, modifying them in place. */
private void substituteList(List<TermTerm> termPairs, Tracer tracer,
Map<Variable, Term> map) {
for (int j = 0; j < termPairs.size(); j++) {
final TermTerm pair2 = termPairs.get(j);
final Term left2 = pair2.left.apply(map);
final Term right2 = pair2.right.apply(map);
if (left2 != pair2.left
|| right2 != pair2.right) {
tracer.onSubstitute(pair2.left, pair2.right, left2, right2);
termPairs.set(j, new TermTerm(left2, right2));
}
}
}

private void act(Variable variable, Term term, List<TermTerm> termPairs,
Map<Variable, Action> termActions, int depth) {
Substitution substitution, Map<Variable, Action> termActions,
int depth) {
final Action action = termActions.get(variable);
if (action != null) {
action.accept(variable, term, termPairs);
action.accept(variable, term, substitution, termPairs);
}
if (term instanceof Variable) {
// Copy list to prevent concurrent modification, in case the action
// appends to the list. Limit on depth, to prevent infinite recursion.
final List<TermTerm> termPairsCopy = new ArrayList<>(termPairs);
termPairsCopy.forEach(termPair -> {
if (termPair.left.equals(term) && depth < 2) {
act(variable, termPair.right, termPairs, termActions, depth + 1);
act(variable, termPair.right, termPairs, substitution, termActions,
depth + 1);
}
});
// If the term is a variable, recurse to see whether there is an
// action for that variable. Limit on depth to prevent swapping back.
if (depth < 1) {
act((Variable) term, variable, termPairs, termActions, depth + 1);
act((Variable) term, variable, termPairs, substitution, termActions,
depth + 1);
}
}
}
Expand Down
18 changes: 10 additions & 8 deletions src/main/java/net/hydromatic/morel/util/RobinsonUnifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package net.hydromatic.morel.util;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;

import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -46,7 +46,7 @@ static Map<Variable, Term> compose(Map<Variable, Term> s1,
return failure("sequences have different operator: " + lhs + ", " + rhs);
}
if (lhs.terms.isEmpty()) {
return EMPTY;
return SubstitutionResult.EMPTY;
}
Term firstLhs = lhs.terms.get(0);
Term firstRhs = rhs.terms.get(0);
Expand All @@ -64,10 +64,12 @@ static Map<Variable, Term> compose(Map<Variable, Term> s1,
return r2;
}
final Substitution subs2 = (Substitution) r2;
Map<Variable, Term> joined = new HashMap<>();
joined.putAll(subs1.resultMap);
joined.putAll(subs2.resultMap);
return new Substitution(joined);
final Map<Variable, Term> joined =
ImmutableSortedMap.<Variable, Term>naturalOrder()
.putAll(subs1.resultMap)
.putAll(subs2.resultMap)
.build();
return SubstitutionResult.create(joined);
}

static <E> List<E> skip(List<E> list) {
Expand All @@ -86,10 +88,10 @@ static <E> List<E> skip(List<E> list) {

public @Nonnull Result unify(Term lhs, Term rhs) {
if (lhs instanceof Variable) {
return new Substitution(ImmutableMap.of((Variable) lhs, rhs));
return SubstitutionResult.create((Variable) lhs, rhs);
}
if (rhs instanceof Variable) {
return new Substitution(ImmutableMap.of((Variable) rhs, lhs));
return SubstitutionResult.create((Variable) rhs, lhs);
}
if (lhs instanceof Sequence && rhs instanceof Sequence) {
return sequenceUnify((Sequence) lhs, (Sequence) rhs);
Expand Down
38 changes: 31 additions & 7 deletions src/main/java/net/hydromatic/morel/util/Unifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
/** Given pairs of terms, finds a substitution to minimize those pairs of
* terms. */
public abstract class Unifier {
static final Substitution EMPTY = new Substitution(ImmutableMap.of());

private int varId;
private final Map<String, Variable> variableMap = new HashMap<>();
private final Map<String, Sequence> atomMap = new HashMap<>();
Expand Down Expand Up @@ -127,7 +125,8 @@ protected Failure failure(String reason) {
/** Called by the unifier when a Term's type becomes known. */
@FunctionalInterface
public interface Action {
void accept(Variable variable, Term term, List<TermTerm> termPairs);
void accept(Variable variable, Term term, Substitution substitution,
List<TermTerm> termPairs);
}

/** Result of attempting unification. A success is {@link Substitution},
Expand All @@ -142,15 +141,40 @@ public static class Failure implements Result {
/** The results of a successful unification. Gives access to the raw variable
* mapping that resulted from the algorithm, but can also resolve a variable
* to the fullest extent possible with the {@link #resolve} method. */
public static final class Substitution implements Result {
public static final class SubstitutionResult extends Substitution
implements Result {
private SubstitutionResult(Map<Variable, Term> resultMap) {
super(ImmutableSortedMap.copyOf(resultMap, Ordering.natural()));
}

/** Empty substitution result. */
public static final SubstitutionResult EMPTY =
create(ImmutableSortedMap.of());

/** Creates a substitution result from a map. */
public static SubstitutionResult create(Map<Variable, Term> resultMap) {
return new SubstitutionResult(
ImmutableSortedMap.copyOf(resultMap, Ordering.natural()));
}

/** Creates a substitution result with one (variable, term) entry. */
public static SubstitutionResult create(Variable v, Term t) {
return new SubstitutionResult(ImmutableSortedMap.of(v, t));
}
}

/** Map from variables to terms.
*
* <p>Quicker to create than its sub-class {@link SubstitutionResult}
* because the map is mutable and not sorted. */
public static class Substitution {
/** The result of the unification algorithm proper. This does not have
* everything completely resolved: some variable substitutions are required
* before getting the most atom-y representation. */
public final Map<Variable, Term> resultMap;

Substitution(Map<Variable, Term> resultMap) {
this.resultMap =
ImmutableSortedMap.copyOf(resultMap, Ordering.natural());
this.resultMap = resultMap;
}

@Override public int hashCode() {
Expand All @@ -175,7 +199,7 @@ public StringBuilder accept(StringBuilder b) {
return b.append("]");
}

Term resolve(Term term) {
public Term resolve(Term term) {
Term previous;
Term current = term;
do {
Expand Down
20 changes: 20 additions & 0 deletions src/test/resources/script/datatype.sml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,26 @@ max_int (2, 3);
fun max_real (x, y) = if x < y then y + 0.0 else x;
max_real (2.0, 3.0);
(*) Tuple type with a polymorphic member
let
val r = (fn x => x, 2)
in
(#1 r) 1
end;
(*) Record type with a polymorphic member
let
val r = {a = fn x => x, b = 2}
in
r.a 1
end;
let
val r = {a = fn x => x, b = 2}
in
(r.a "x", r.b)
end;
(*) A recursive type, without generics
datatype inttree = Empty | Node of inttree * int * inttree;
fun max (x, y) = if x < y then y + 0 else x;
Expand Down
26 changes: 26 additions & 0 deletions src/test/resources/script/datatype.sml.out
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,32 @@ max_real (2.0, 3.0);
val it = 3.0 : real


(*) Tuple type with a polymorphic member
let
val r = (fn x => x, 2)
in
(#1 r) 1
end;
val it = 1 : int


(*) Record type with a polymorphic member
let
val r = {a = fn x => x, b = 2}
in
r.a 1
end;
val it = 1 : int


let
val r = {a = fn x => x, b = 2}
in
(r.a "x", r.b)
end;
val it = ("x",2) : string * int


(*) A recursive type, without generics
datatype inttree = Empty | Node of inttree * int * inttree;
datatype inttree = Empty | Node of inttree * int * inttree
Expand Down

0 comments on commit 93b9a0e

Please sign in to comment.