Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sumtype: reduce template overhead of match #9087

Merged
merged 5 commits into from
Nov 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 106 additions & 76 deletions std/sumtype.d
Original file line number Diff line number Diff line change
Expand Up @@ -1860,88 +1860,65 @@ private template Iota(size_t n)
assert(Iota!3 == AliasSeq!(0, 1, 2));
}

/* The number that the dim-th argument's tag is multiplied by when
* converting TagTuples to and from case indices ("caseIds").
*
* Named by analogy to the stride that the dim-th index into a
* multidimensional static array is multiplied by to calculate the
* offset of a specific element.
*/
private size_t stride(size_t dim, lengths...)()
{
import core.checkedint : mulu;

size_t result = 1;
bool overflow = false;

static foreach (i; 0 .. dim)
{
result = mulu(result, lengths[i], overflow);
}

/* The largest number matchImpl uses, numCases, is calculated with
* stride!(SumTypes.length), so as long as this overflow check
* passes, we don't need to check for overflow anywhere else.
*/
assert(!overflow, "Integer overflow");
return result;
}

private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
{
auto ref matchImpl(SumTypes...)(auto ref SumTypes args)
if (allSatisfy!(isSumType, SumTypes) && args.length > 0)
{
alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes));
alias TagTuple = .TagTuple!(SumTypes);

/*
* A list of arguments to be passed to a handler needed for the case
* labeled with `caseId`.
*/
template handlerArgs(size_t caseId)
// Single dispatch (fast path)
static if (args.length == 1)
{
enum tags = TagTuple.fromCaseId(caseId);
enum argsFrom(size_t i : tags.length) = "";
enum argsFrom(size_t i) = "args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~
".Types[" ~ toCtString!(tags[i]) ~ "])(), " ~ argsFrom!(i + 1);
enum handlerArgs = argsFrom!0;
}
/* When there's only one argument, the caseId is just that
* argument's tag, so there's no need for TagTuple.
*/
enum handlerArgs(size_t caseId) =
"args[0].get!(SumTypes[0].Types[" ~ toCtString!caseId ~ "])()";

/* An AliasSeq of the types of the member values in the argument list
* returned by `handlerArgs!caseId`.
*
* Note that these are the actual (that is, qualified) types of the
* member values, which may not be the same as the types listed in
* the arguments' `.Types` properties.
*/
template valueTypes(size_t caseId)
alias valueTypes(size_t caseId) =
typeof(args[0].get!(SumTypes[0].Types[caseId])());

enum numCases = SumTypes[0].Types.length;
}
// Multiple dispatch (slow path)
else
{
enum tags = TagTuple.fromCaseId(caseId);
alias typeCounts = Map!(typeCount, SumTypes);
alias stride(size_t i) = .stride!(i, typeCounts);
alias TagTuple = .TagTuple!typeCounts;

alias handlerArgs(size_t caseId) = .handlerArgs!(caseId, typeCounts);

template getType(size_t i)
/* An AliasSeq of the types of the member values in the argument list
* returned by `handlerArgs!caseId`.
*
* Note that these are the actual (that is, qualified) types of the
* member values, which may not be the same as the types listed in
* the arguments' `.Types` properties.
*/
template valueTypes(size_t caseId)
{
enum tid = tags[i];
alias T = SumTypes[i].Types[tid];
alias getType = typeof(args[i].get!T());
enum tags = TagTuple.fromCaseId(caseId);

template getType(size_t i)
{
enum tid = tags[i];
alias T = SumTypes[i].Types[tid];
alias getType = typeof(args[i].get!T());
}

alias valueTypes = Map!(getType, Iota!(tags.length));
}

alias valueTypes = Map!(getType, Iota!(tags.length));
/* The total number of cases is
*
* Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length
*
* Conveniently, this is equal to stride!(SumTypes.length), so we can
* use that function to compute it.
*/
enum numCases = stride!(SumTypes.length);
}

/* The total number of cases is
*
* Π SumTypes[i].Types.length for 0 ≤ i < SumTypes.length
*
* Or, equivalently,
*
* ubyte[SumTypes[0].Types.length]...[SumTypes[$-1].Types.length].sizeof
*
* Conveniently, this is equal to stride!(SumTypes.length), so we can
* use that function to compute it.
*/
enum numCases = stride!(SumTypes.length);

/* Guaranteed to never be a valid handler index, since
* handlers.length <= size_t.max.
*/
Expand Down Expand Up @@ -1998,7 +1975,12 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
mixin("alias ", handlerName!hid, " = handler;");
}

immutable argsId = TagTuple(args).toCaseId;
// Single dispatch (fast path)
static if (args.length == 1)
immutable argsId = args[0].tag;
// Multiple dispatch (slow path)
else
immutable argsId = TagTuple(args).toCaseId;

final switch (argsId)
{
Expand Down Expand Up @@ -2029,10 +2011,11 @@ private template matchImpl(Flag!"exhaustive" exhaustive, handlers...)
}
}

// Predicate for staticMap
private enum typeCount(SumType) = SumType.Types.length;

/* A TagTuple represents a single possible set of tags that `args`
* could have at runtime.
/* A TagTuple represents a single possible set of tags that the arguments to
* `matchImpl` could have at runtime.
*
* Because D does not allow a struct to be the controlling expression
* of a switch statement, we cannot dispatch on the TagTuple directly.
Expand All @@ -2054,22 +2037,23 @@ private enum typeCount(SumType) = SumType.Types.length;
* When there is only one argument, the caseId is equal to that
* argument's tag.
*/
private struct TagTuple(SumTypes...)
private struct TagTuple(typeCounts...)
{
size_t[SumTypes.length] tags;
size_t[typeCounts.length] tags;
alias tags this;

alias stride(size_t i) = .stride!(i, Map!(typeCount, SumTypes));
alias stride(size_t i) = .stride!(i, typeCounts);

invariant
{
static foreach (i; 0 .. tags.length)
{
assert(tags[i] < SumTypes[i].Types.length, "Invalid tag");
assert(tags[i] < typeCounts[i], "Invalid tag");
}
}

this(ref const(SumTypes) args)
this(SumTypes...)(ref const SumTypes args)
if (allSatisfy!(isSumType, SumTypes) && args.length == typeCounts.length)
{
static foreach (i; 0 .. tags.length)
{
Expand Down Expand Up @@ -2104,6 +2088,52 @@ private struct TagTuple(SumTypes...)
}
}

/* The number that the dim-th argument's tag is multiplied by when
* converting TagTuples to and from case indices ("caseIds").
*
* Named by analogy to the stride that the dim-th index into a
* multidimensional static array is multiplied by to calculate the
* offset of a specific element.
*/
private size_t stride(size_t dim, lengths...)()
{
import core.checkedint : mulu;

size_t result = 1;
bool overflow = false;

static foreach (i; 0 .. dim)
{
result = mulu(result, lengths[i], overflow);
}

/* The largest number matchImpl uses, numCases, is calculated with
* stride!(SumTypes.length), so as long as this overflow check
* passes, we don't need to check for overflow anywhere else.
*/
assert(!overflow, "Integer overflow");
return result;
}

/* A list of arguments to be passed to a handler needed for the case
* labeled with `caseId`.
*/
private template handlerArgs(size_t caseId, typeCounts...)
{
enum tags = TagTuple!typeCounts.fromCaseId(caseId);

alias handlerArgs = AliasSeq!();

static foreach (i; 0 .. tags.length)
{
handlerArgs = AliasSeq!(
handlerArgs,
"args[" ~ toCtString!i ~ "].get!(SumTypes[" ~ toCtString!i ~ "]" ~
".Types[" ~ toCtString!(tags[i]) ~ "])(), "
);
}
}

// Matching
@safe unittest
{
Expand Down
Loading