Skip to content

Commit

Permalink
Typed continuations: cont.new instructions (#6308)
Browse files Browse the repository at this point in the history
This PR is part of a series that adds basic support for the [typed
continuations/wasmfx proposal](https://github.com/wasmfx/specfx).

This particular PR adds support for the `cont.new` instruction for creating
continuations, documented [here(https://github.com/wasmfx/specfx/blob/main/proposals/continuations/Overview.md#instructions).

In short, these instructions are of the form `(cont.new $ct)` where `$ct` must
be a continuation type. The instruction takes a single (nullable) function
reference as its argument, which means that the folded representation of the
instruction is of the form `(cont.new $ct (foo ...))`. 

Support for the instruction is implemented in both the old and the new wat
parser.

Note that this PR does not implement validation of the new instruction.
  • Loading branch information
frank-emrich authored Feb 22, 2024
1 parent 2ceff4d commit e2420f0
Show file tree
Hide file tree
Showing 30 changed files with 280 additions and 38 deletions.
1 change: 1 addition & 0 deletions scripts/fuzz_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def is_git_repo():
# the fuzzer does not support typed continuations
'typed_continuations.wast',
'typed_continuations_resume.wast',
'typed_continuations_contnew.wast',
# New EH implementation is in progress
'exception-handling.wast',
'translate-eh-old-to-new.wast',
Expand Down
1 change: 1 addition & 0 deletions scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@
("call_ref", "makeCallRef(s, /*isReturn=*/false)"),
("return_call_ref", "makeCallRef(s, /*isReturn=*/true)"),
# Typed continuations instructions
("cont.new", "makeContNew(s)"),
("resume", "makeResume(s)"),
# GC
("i31.new", "makeRefI31(s)"), # deprecated
Expand Down
75 changes: 47 additions & 28 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -149,21 +149,29 @@ switch (buf[0]) {
}
}
case 'c': {
switch (buf[4]) {
case '\0':
if (op == "call"sv) { return makeCall(s, /*isReturn=*/false); }
goto parse_error;
case '_': {
switch (buf[5]) {
case 'i':
if (op == "call_indirect"sv) { return makeCallIndirect(s, /*isReturn=*/false); }
goto parse_error;
case 'r':
if (op == "call_ref"sv) { return makeCallRef(s, /*isReturn=*/false); }
switch (buf[1]) {
case 'a': {
switch (buf[4]) {
case '\0':
if (op == "call"sv) { return makeCall(s, /*isReturn=*/false); }
goto parse_error;
case '_': {
switch (buf[5]) {
case 'i':
if (op == "call_indirect"sv) { return makeCallIndirect(s, /*isReturn=*/false); }
goto parse_error;
case 'r':
if (op == "call_ref"sv) { return makeCallRef(s, /*isReturn=*/false); }
goto parse_error;
default: goto parse_error;
}
}
default: goto parse_error;
}
}
case 'o':
if (op == "cont.new"sv) { return makeContNew(s); }
goto parse_error;
default: goto parse_error;
}
}
Expand Down Expand Up @@ -3816,30 +3824,41 @@ switch (buf[0]) {
}
}
case 'c': {
switch (buf[4]) {
case '\0':
if (op == "call"sv) {
CHECK_ERR(makeCall(ctx, pos, /*isReturn=*/false));
return Ok{};
}
goto parse_error;
case '_': {
switch (buf[5]) {
case 'i':
if (op == "call_indirect"sv) {
CHECK_ERR(makeCallIndirect(ctx, pos, /*isReturn=*/false));
switch (buf[1]) {
case 'a': {
switch (buf[4]) {
case '\0':
if (op == "call"sv) {
CHECK_ERR(makeCall(ctx, pos, /*isReturn=*/false));
return Ok{};
}
goto parse_error;
case 'r':
if (op == "call_ref"sv) {
CHECK_ERR(makeCallRef(ctx, pos, /*isReturn=*/false));
return Ok{};
case '_': {
switch (buf[5]) {
case 'i':
if (op == "call_indirect"sv) {
CHECK_ERR(makeCallIndirect(ctx, pos, /*isReturn=*/false));
return Ok{};
}
goto parse_error;
case 'r':
if (op == "call_ref"sv) {
CHECK_ERR(makeCallRef(ctx, pos, /*isReturn=*/false));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
goto parse_error;
}
default: goto parse_error;
}
}
case 'o':
if (op == "cont.new"sv) {
CHECK_ERR(makeContNew(ctx, pos));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ir/ReFinalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ void ReFinalize::visitStringSliceIter(StringSliceIter* curr) {
curr->finalize();
}

void ReFinalize::visitContNew(ContNew* curr) { curr->finalize(); }
void ReFinalize::visitResume(Resume* curr) { curr->finalize(); }

void ReFinalize::visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); }
Expand Down
4 changes: 4 additions & 0 deletions src/ir/cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,10 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
return 8 + visit(curr->ref) + visit(curr->num);
}

CostType visitContNew(ContNew* curr) {
// Some arbitrary "high" value, reflecting that this may allocate a stack
return 14 + visit(curr->func);
}
CostType visitResume(Resume* curr) {
// Inspired by indirect calls, but twice the cost.
return 12 + visit(curr->cont);
Expand Down
4 changes: 4 additions & 0 deletions src/ir/effects.h
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,10 @@ class EffectAnalyzer {
parent.implicitTrap = true;
}

void visitContNew(ContNew* curr) {
// traps when curr->func is null ref.
parent.implicitTrap = true;
}
void visitResume(Resume* curr) {
// This acts as a kitchen sink effect.
parent.calls = true;
Expand Down
2 changes: 2 additions & 0 deletions src/ir/module-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ struct CodeScanner
counts.include(get->type);
} else if (auto* set = curr->dynCast<ArraySet>()) {
counts.note(set->ref->type);
} else if (auto* contNew = curr->dynCast<ContNew>()) {
counts.note(contNew->contType);
} else if (auto* resume = curr->dynCast<Resume>()) {
counts.note(resume->contType);
} else if (Properties::isControlFlowStructure(curr)) {
Expand Down
4 changes: 4 additions & 0 deletions src/ir/possible-contents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,10 @@ struct InfoCollector

void visitReturn(Return* curr) { addResult(curr->value); }

void visitContNew(ContNew* curr) {
// TODO: optimize when possible
addRoot(curr);
}
void visitResume(Resume* curr) {
// TODO: optimize when possible
addRoot(curr);
Expand Down
1 change: 1 addition & 0 deletions src/ir/subtype-exprs.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ struct SubtypingDiscoverer : public OverriddenVisitor<SubType> {
void visitStringSliceWTF(StringSliceWTF* curr) {}
void visitStringSliceIter(StringSliceIter* curr) {}

void visitContNew(ContNew* curr) { WASM_UNREACHABLE("not implemented"); }
void visitResume(Resume* curr) { WASM_UNREACHABLE("not implemented"); }
};

Expand Down
7 changes: 7 additions & 0 deletions src/parser/contexts.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,9 @@ struct NullInstrParserCtx {
Result<> makeStringIterMove(Index, StringIterMoveOp) { return Ok{}; }
Result<> makeStringSliceWTF(Index, StringSliceWTFOp) { return Ok{}; }
Result<> makeStringSliceIter(Index) { return Ok{}; }
template<typename HeapTypeT> Result<> makeContNew(Index, HeapTypeT) {
return Ok{};
}
template<typename HeapTypeT>
Result<> makeResume(Index, HeapTypeT, const TagLabelListT&) {
return Ok{};
Expand Down Expand Up @@ -2010,6 +2013,10 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
return withLoc(pos, irBuilder.makeStringSliceIter());
}

Result<> makeContNew(Index pos, HeapType type) {
return withLoc(pos, irBuilder.makeContNew(type));
}

Result<>
makeResume(Index pos, HeapType type, const TagLabelListT& tagLabels) {
std::vector<Name> tags;
Expand Down
8 changes: 8 additions & 0 deletions src/parser/parsers.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ Result<> makeStringIterMove(Ctx&, Index, StringIterMoveOp op);
template<typename Ctx>
Result<> makeStringSliceWTF(Ctx&, Index, StringSliceWTFOp op);
template<typename Ctx> Result<> makeStringSliceIter(Ctx&, Index);
template<typename Ctx> Result<> makeContNew(Ctx&, Index);
template<typename Ctx> Result<> makeResume(Ctx&, Index);

// Modules
Expand Down Expand Up @@ -1990,6 +1991,13 @@ template<typename Ctx> Result<> makeStringSliceIter(Ctx& ctx, Index pos) {
return ctx.makeStringSliceIter(pos);
}

template<typename Ctx> Result<> makeContNew(Ctx& ctx, Index pos) {
auto type = typeidx(ctx);
CHECK_ERR(type);

return ctx.makeContNew(pos, *type);
}

// resume ::= 'resume' typeidx ('(' 'tag' tagidx labelidx ')')*
template<typename Ctx> Result<> makeResume(Ctx& ctx, Index pos) {
auto type = typeidx(ctx);
Expand Down
5 changes: 5 additions & 0 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2379,6 +2379,11 @@ struct PrintExpressionContents
printMedium(o, "stringview_iter.slice");
}

void visitContNew(ContNew* curr) {
printMedium(o, "cont.new ");
printHeapType(curr->contType);
}

void visitResume(Resume* curr) {
printMedium(o, "resume");

Expand Down
1 change: 1 addition & 0 deletions src/passes/TypeGeneralizing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ struct TransferFn : OverriddenVisitor<TransferFn> {
void visitStringSliceWTF(StringSliceWTF* curr) { WASM_UNREACHABLE("TODO"); }
void visitStringSliceIter(StringSliceIter* curr) { WASM_UNREACHABLE("TODO"); }

void visitContNew(ContNew* curr) { WASM_UNREACHABLE("TODO"); }
void visitResume(Resume* curr) { WASM_UNREACHABLE("TODO"); }
};

Expand Down
2 changes: 2 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,7 @@ enum ASTNodes {
StringNewUTF8ArrayTry = 0xb8,

// typed continuation opcodes
ContNew = 0xe0,
Resume = 0xe3,

};
Expand Down Expand Up @@ -1926,6 +1927,7 @@ class WasmBinaryReader {
void visitCallRef(CallRef* curr);
void visitRefAsCast(RefCast* curr, uint32_t code);
void visitRefAs(RefAs* curr, uint8_t code);
void visitContNew(ContNew* curr);
void visitResume(Resume* curr);

[[noreturn]] void throwError(std::string text);
Expand Down
7 changes: 7 additions & 0 deletions src/wasm-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,13 @@ class Builder {
return ret;
}

ContNew* makeContNew(HeapType contType, Expression* func) {
auto* ret = wasm.allocator.alloc<ContNew>();
ret->contType = contType;
ret->func = func;
ret->finalize();
return ret;
}
Resume* makeResume(HeapType contType,
const std::vector<Name>& handlerTags,
const std::vector<Name>& handlerBlocks,
Expand Down
7 changes: 7 additions & 0 deletions src/wasm-delegations-fields.def
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,13 @@ switch (DELEGATE_ID) {
break;
}

case Expression::Id::ContNewId: {
DELEGATE_START(ContNew);
DELEGATE_FIELD_CHILD(ContNew, func);
DELEGATE_FIELD_HEAPTYPE(ContNew, contType);
DELEGATE_END(ContNew);
break;
}
case Expression::Id::ResumeId: {
DELEGATE_START(Resume);
DELEGATE_FIELD_TYPE_VECTOR(Resume, sentTypes);
Expand Down
1 change: 1 addition & 0 deletions src/wasm-delegations.def
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ DELEGATE(StringIterNext);
DELEGATE(StringIterMove);
DELEGATE(StringSliceWTF);
DELEGATE(StringSliceIter);
DELEGATE(ContNew);
DELEGATE(Resume);

#undef DELEGATE
2 changes: 2 additions & 0 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2401,6 +2401,7 @@ class ConstantExpressionRunner : public ExpressionRunner<SubType> {
}
return ExpressionRunner<SubType>::visitRefAs(curr);
}
Flow visitContNew(ContNew* curr) { WASM_UNREACHABLE("unimplemented"); }
Flow visitResume(Resume* curr) { WASM_UNREACHABLE("unimplemented"); }

void trap(const char* why) override { throw NonconstantException(); }
Expand Down Expand Up @@ -3968,6 +3969,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
multiValues.pop_back();
return ret;
}
Flow visitContNew(ContNew* curr) { return Flow(NONCONSTANT_FLOW); }
Flow visitResume(Resume* curr) { return Flow(NONCONSTANT_FLOW); }

void trap(const char* why) override { externalInterface->trap(why); }
Expand Down
1 change: 1 addition & 0 deletions src/wasm-ir-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
[[nodiscard]] Result<> makeStringIterMove(StringIterMoveOp op);
[[nodiscard]] Result<> makeStringSliceWTF(StringSliceWTFOp op);
[[nodiscard]] Result<> makeStringSliceIter();
[[nodiscard]] Result<> makeContNew(HeapType ct);
[[nodiscard]] Result<> makeResume(HeapType ct,
const std::vector<Name>& tags,
const std::vector<Index>& labels);
Expand Down
1 change: 1 addition & 0 deletions src/wasm-s-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ class SExpressionWasmBuilder {
Expression* makeStringIterMove(Element& s, StringIterMoveOp op);
Expression* makeStringSliceWTF(Element& s, StringSliceWTFOp op);
Expression* makeStringSliceIter(Element& s);
Expression* makeContNew(Element& s);
Expression* makeResume(Element& s);

// Helper functions
Expand Down
12 changes: 12 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ class Expression {
StringIterMoveId,
StringSliceWTFId,
StringSliceIterId,
ContNewId,
ResumeId,
NumExpressionIds
};
Expand Down Expand Up @@ -1997,6 +1998,17 @@ class StringSliceIter
void finalize();
};

class ContNew : public SpecificExpression<Expression::ContNewId> {
public:
ContNew() = default;
ContNew(MixedArena& allocator) {}

HeapType contType;
Expression* func;

void finalize();
};

class Resume : public SpecificExpression<Expression::ResumeId> {
public:
Resume(MixedArena& allocator)
Expand Down
20 changes: 20 additions & 0 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4046,6 +4046,12 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
visitCallRef(call);
break;
}
case BinaryConsts::ContNew: {
auto contNew = allocator.alloc<ContNew>();
curr = contNew;
visitContNew(contNew);
break;
}
case BinaryConsts::Resume: {
visitResume((curr = allocator.alloc<Resume>())->cast<Resume>());
break;
Expand Down Expand Up @@ -7762,6 +7768,20 @@ void WasmBinaryReader::visitRefAs(RefAs* curr, uint8_t code) {
curr->finalize();
}

void WasmBinaryReader::visitContNew(ContNew* curr) {
BYN_TRACE("zz node: ContNew\n");

auto contTypeIndex = getU32LEB();
curr->contType = getTypeByIndex(contTypeIndex);
if (!curr->contType.isContinuation()) {
throwError("non-continuation type in cont.new instruction " +
curr->contType.toString());
}

curr->func = popNonVoidExpression();
curr->finalize();
}

void WasmBinaryReader::visitResume(Resume* curr) {
BYN_TRACE("zz node: Resume\n");

Expand Down
11 changes: 11 additions & 0 deletions src/wasm/wasm-ir-builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,17 @@ Result<> IRBuilder::makeStringSliceIter() {
return Ok{};
}

Result<> IRBuilder::makeContNew(HeapType ct) {
if (!ct.isContinuation()) {
return Err{"expected continuation type"};
}
ContNew curr;
CHECK_ERR(visitContNew(&curr));

push(builder.makeContNew(ct, curr.func));
return Ok{};
}

Result<> IRBuilder::makeResume(HeapType ct,
const std::vector<Name>& tags,
const std::vector<Index>& labels) {
Expand Down
Loading

0 comments on commit e2420f0

Please sign in to comment.