Skip to content

Commit

Permalink
[experimental] Support custom discriminator properties for discrimina…
Browse files Browse the repository at this point in the history
…ted unions (#190)

* Support custom discriminator properties for discriminated union

* Regexify strings when testing code

For more reliable code assertions

* Add more tests

* Fix name transformer input
  • Loading branch information
lorisleiva authored Mar 28, 2024
1 parent 38943c1 commit 2f79032
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 43 deletions.
5 changes: 5 additions & 0 deletions .changeset/light-meals-know.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@metaplex-foundation/kinobi": patch
---

Support custom discriminator properties for discriminated union types in JS experimental renderer
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Data Enum Helpers.
{% for variant in typeNode.variants %}
{% if variant.kind === 'enumStructVariantTypeNode' %}
export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>): GetDiscriminatedUnionVariant<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>;
export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>): GetDiscriminatedUnionVariant<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>;
{% elif variant.kind === 'enumTupleVariantTypeNode' %}
export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>['fields']): GetDiscriminatedUnionVariant<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>;
export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>['fields']): GetDiscriminatedUnionVariant<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>;
{% else %}
export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}'): GetDiscriminatedUnionVariant<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>;
export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}'): GetDiscriminatedUnionVariant<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>;
{% endif %}
{% endfor %}
export function {{ discriminatedUnionFunction }}<K extends {{ looseName }}['{{ discriminatedUnionDiscriminator }}'], Data>(
Expand Down
28 changes: 22 additions & 6 deletions src/renderers/js-experimental/getTypeManifestVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
structTypeNode,
structTypeNodeFromInstructionArgumentNodes,
} from '../../nodes';
import { camelCase, jsDocblock, pipe } from '../../shared';
import { camelCase, jsDocblock, mainCase, pipe } from '../../shared';
import { Visitor, extendVisitor, staticVisitor, visit } from '../../visitors';
import { ImportMap } from './ImportMap';
import { TypeManifest, mergeManifests } from './TypeManifest';
Expand Down Expand Up @@ -138,8 +138,6 @@ export function getTypeManifestVisitor(input: {

visitEnumType(enumType, { self }) {
const currentParentName = parentName;
parentName = null;

const encoderImports = new ImportMap();
const decoderImports = new ImportMap();
const encoderOptions: string[] = [];
Expand All @@ -153,6 +151,14 @@ export function getTypeManifestVisitor(input: {
decoderOptions.push(`size: ${sizeManifest.decoder.render}`);
}

const discriminator = nameApi.discriminatedUnionDiscriminator(
mainCase(currentParentName?.strict ?? '')
);
if (!isScalarEnum(enumType) && discriminator !== '__kind') {
encoderOptions.push(`discriminator: '${discriminator}'`);
decoderOptions.push(`discriminator: '${discriminator}'`);
}

const encoderOptionsAsString =
encoderOptions.length > 0
? `, { ${encoderOptions.join(', ')} }`
Expand Down Expand Up @@ -226,7 +232,7 @@ export function getTypeManifestVisitor(input: {

visitEnumEmptyVariantType(enumEmptyVariantType) {
const discriminator = nameApi.discriminatedUnionDiscriminator(
enumEmptyVariantType.name
mainCase(parentName?.strict ?? '')
);
const name = nameApi.discriminatedUnionVariant(
enumEmptyVariantType.name
Expand All @@ -248,14 +254,19 @@ export function getTypeManifestVisitor(input: {
},

visitEnumStructVariantType(enumStructVariantType, { self }) {
const currentParentName = parentName;
const discriminator = nameApi.discriminatedUnionDiscriminator(
enumStructVariantType.name
mainCase(currentParentName?.strict ?? '')
);
const name = nameApi.discriminatedUnionVariant(
enumStructVariantType.name
);
const kindAttribute = `${discriminator}: "${name}"`;

parentName = null;
const structManifest = visit(enumStructVariantType.struct, self);
parentName = currentParentName;

structManifest.strictType.mapRender(
(r) => `{ ${kindAttribute},${r.slice(1, -1)}}`
);
Expand All @@ -268,8 +279,9 @@ export function getTypeManifestVisitor(input: {
},

visitEnumTupleVariantType(enumTupleVariantType, { self }) {
const currentParentName = parentName;
const discriminator = nameApi.discriminatedUnionDiscriminator(
enumTupleVariantType.name
mainCase(currentParentName?.strict ?? '')
);
const name = nameApi.discriminatedUnionVariant(
enumTupleVariantType.name
Expand All @@ -281,7 +293,11 @@ export function getTypeManifestVisitor(input: {
type: enumTupleVariantType.tuple,
}),
]);

parentName = null;
const structManifest = visit(struct, self);
parentName = currentParentName;

structManifest.strictType.mapRender(
(r) => `{ ${kindAttribute},${r.slice(1, -1)}}`
);
Expand Down
37 changes: 21 additions & 16 deletions test/renderers/js-experimental/_setup.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { format } from '@prettier/sync';
import type { ExecutionContext } from 'ava';
import chalk from 'chalk';
import { type Options as PrettierOptions } from 'prettier';
import type { RenderMap } from '../../../src';

Expand Down Expand Up @@ -32,18 +31,18 @@ export function codeContains(
) {
const expectedArray = Array.isArray(expected) ? expected : [expected];
const normalizedActual = normalizeCode(actual);
expectedArray.forEach((e) => {
if (typeof e === 'string') {
const normalizeExpected = normalizeCode(e);
t.true(
normalizedActual.includes(normalizeExpected),
`The following expected code is missing from the actual content:\n\n` +
`${chalk.blue(normalizeExpected)}\n\n` +
`Actual content:\n\n` +
`${chalk.blue(normalizedActual)}`
);
expectedArray.forEach((expectedResult) => {
if (typeof expectedResult === 'string') {
const stringAsRegex = escapeRegex(expectedResult)
// Transform spaces between words into required whitespace.
.replace(/(\w)\s+(\w)/g, '$1\\s+$2')
// Do it again for single-character words — e.g. "as[ ]a[ ]token".
.replace(/(\w)\s+(\w)/g, '$1\\s+$2')
// Transform other spaces into optional whitespace.
.replace(/\s+/g, '\\s*');
t.regex(normalizedActual, new RegExp(stringAsRegex));
} else {
t.regex(normalizedActual, e);
t.regex(normalizedActual, expectedResult);
}
});
}
Expand All @@ -63,7 +62,7 @@ export function codeContainsImports(
actual: string,
expectedImports: Record<string, string[]>
) {
const normalizedActual = normalizeCode(actual);
const normalizedActual = inlineCode(actual);
const importPairs = Object.entries(expectedImports).flatMap(
([key, value]) => {
return value.map((v) => [key, v] as const);
Expand All @@ -82,9 +81,15 @@ function normalizeCode(code: string) {
try {
code = format(code, PRETTIER_OPTIONS);
} catch (e) {}
return code.trim();
}

return code
function inlineCode(code: string) {
return normalizeCode(code)
.replace(/\s+/g, ' ')
.replace(/\s*(\W)\s*/g, '$1')
.trim();
.replace(/\s*(\W)\s*/g, '$1');
}

function escapeRegex(stringAsRegex: string) {
return stringAsRegex.replace(/[-\/\\^$*+?.()|[\]{}]/g, '\\$&');
}
38 changes: 20 additions & 18 deletions test/renderers/js-experimental/programsPage.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ test('it renders an enum of all available accounts for a program', (t) => {

// Then we expect the following program account enum.
renderMapContains(t, renderMap, 'programs/splToken.ts', [
'export enum SplTokenAccount { Mint, Token };',
'export enum SplTokenAccount { Mint, Token }',
]);
});

Expand Down Expand Up @@ -91,11 +91,11 @@ test('it renders an function that identifies accounts in a program', (t) => {
// Then we expect the following identifier function to be rendered.
// Notice it does not include the `mint` account because it has no discriminators.
renderMapContains(t, renderMap, 'programs/splToken.ts', [
`export function identifySplTokenAccount(account: { data: Uint8Array } | Uint8Array): SplTokenAccount {\n` +
`const data = account instanceof Uint8Array ? account : account.data;\n` +
`if (memcmp(data, getU8Encoder().encode(5), 0)) { return SplTokenAccount.Metadata; }\n` +
`if (data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4)) { return SplTokenAccount.Token; }\n` +
`throw new Error('The provided account could not be identified as a splToken account.')\n` +
`export function identifySplTokenAccount( account: { data: Uint8Array } | Uint8Array ): SplTokenAccount { ` +
`const data = account instanceof Uint8Array ? account : account.data; ` +
`if ( memcmp(data, getU8Encoder().encode(5), 0) ) { return SplTokenAccount.Metadata; } ` +
`if ( data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4) ) { return SplTokenAccount.Token; } ` +
`throw new Error ( 'The provided account could not be identified as a splToken account.' ); ` +
`}`,
]);

Expand All @@ -122,7 +122,7 @@ test('it renders an enum of all available instructions for a program', (t) => {

// Then we expect the following program instruction enum.
renderMapContains(t, renderMap, 'programs/splToken.ts', [
'export enum SplTokenInstruction { MintTokens, TransferTokens, UpdateAuthority };',
'export enum SplTokenInstruction { MintTokens, TransferTokens, UpdateAuthority }',
]);
});

Expand Down Expand Up @@ -163,11 +163,11 @@ test('it renders an function that identifies instructions in a program', (t) =>
// Then we expect the following identifier function to be rendered.
// Notice it does not include the `updateAuthority` instruction because it has no discriminators.
renderMapContains(t, renderMap, 'programs/splToken.ts', [
`export function identifySplTokenInstruction(instruction: { data: Uint8Array } | Uint8Array): SplTokenInstruction {\n` +
`const data = instruction instanceof Uint8Array ? instruction : instruction.data;\n` +
`if (memcmp(data, getU8Encoder().encode(1), 0)) { return SplTokenInstruction.MintTokens; }\n` +
`if (data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4)) { return SplTokenInstruction.TransferTokens; }\n` +
`throw new Error('The provided instruction could not be identified as a splToken instruction.')\n` +
`export function identifySplTokenInstruction ( instruction: { data: Uint8Array } | Uint8Array ): SplTokenInstruction { ` +
`const data = instruction instanceof Uint8Array ? instruction : instruction.data; ` +
`if ( memcmp(data, getU8Encoder().encode(1), 0) ) { return SplTokenInstruction.MintTokens; } ` +
`if ( data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4) ) { return SplTokenInstruction.TransferTokens; } ` +
`throw new Error( 'The provided instruction could not be identified as a splToken instruction.' ); ` +
`}`,
]);

Expand Down Expand Up @@ -232,8 +232,10 @@ test('it checks the discriminator of sub-instructions before their parents.', (t

// Then we expect the sub-instruction condition to be rendered before the parent instruction condition.
renderMapContains(t, renderMap, 'programs/splToken.ts', [
`if (memcmp(data, getU8Encoder().encode(1), 0) && memcmp(data, getU32Encoder().encode(1), 1)) { return SplTokenInstruction.MintTokensV1; }\n` +
`if (memcmp(data, getU8Encoder().encode(1), 0)) { return SplTokenInstruction.MintTokens; }`,
`if ( memcmp(data, getU8Encoder().encode(1), 0) && memcmp(data, getU32Encoder().encode(1), 1) ) ` +
`{ return SplTokenInstruction.MintTokensV1; } ` +
`if ( memcmp(data, getU8Encoder().encode(1), 0) ) ` +
`{ return SplTokenInstruction.MintTokens; }`,
]);
});

Expand All @@ -254,9 +256,9 @@ test('it renders a parsed union type of all available instructions for a program

// Then we expect the following program parsed instruction union type.
renderMapContains(t, renderMap, 'programs/splToken.ts', [
"export type ParsedSplTokenInstruction<TProgram extends string='TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA'>=",
'|({instructionType: SplTokenInstruction.MintTokens;} & ParsedMintTokensInstruction<TProgram>)',
'|({instructionType: SplTokenInstruction.TransferTokens;} & ParsedTransferTokensInstruction<TProgram>)',
'|({instructionType: SplTokenInstruction.UpdateAuthority;} & ParsedUpdateAuthorityInstruction<TProgram>)',
"export type ParsedSplTokenInstruction < TProgram extends string = 'TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA' >",
'| ({ instructionType: SplTokenInstruction.MintTokens; } & ParsedMintTokensInstruction<TProgram>)',
'| ({ instructionType: SplTokenInstruction.TransferTokens; } & ParsedTransferTokensInstruction<TProgram>)',
'| ({ instructionType: SplTokenInstruction.UpdateAuthority; } & ParsedUpdateAuthorityInstruction<TProgram>)',
]);
});
130 changes: 130 additions & 0 deletions test/renderers/js-experimental/types/discriminatedUnions.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import test from 'ava';
import {
definedTypeNode,
enumEmptyVariantTypeNode,
enumStructVariantTypeNode,
enumTupleVariantTypeNode,
enumTypeNode,
numberTypeNode,
stringTypeNode,
structFieldTypeNode,
structTypeNode,
tupleTypeNode,
visit,
} from '../../../../src';
import { getRenderMapVisitor } from '../../../../src/renderers/js-experimental/getRenderMapVisitor';
import { renderMapContains } from '../_setup';

// Given the following event discriminated union.
const eventTypeNode = definedTypeNode({
name: 'event',
type: enumTypeNode([
enumEmptyVariantTypeNode('quit'),
enumTupleVariantTypeNode('write', tupleTypeNode([stringTypeNode()])),
enumStructVariantTypeNode(
'move',
structTypeNode([
structFieldTypeNode({ name: 'x', type: numberTypeNode('u32') }),
structFieldTypeNode({ name: 'y', type: numberTypeNode('u32') }),
])
),
]),
});

test('it exports discriminated union types', (t) => {
// When we render a discriminated union.
const renderMap = visit(eventTypeNode, getRenderMapVisitor());

// Then we expect the following types to be exported.
renderMapContains(t, renderMap, 'types/event.ts', [
'export type Event =',
"| { __kind: 'Quit' }",
"| { __kind: 'Write'; fields: readonly [string] }",
"| { __kind: 'Move'; x: number; y: number }",
]);
});

test('it exports discriminated union codecs', (t) => {
// When we render a discriminated union.
const renderMap = visit(eventTypeNode, getRenderMapVisitor());

// Then we expect the following codec functions to be exported.
renderMapContains(t, renderMap, 'types/event.ts', [
'export function getEventEncoder(): Encoder< EventArgs >',
'export function getEventDecoder(): Decoder< Event >',
'export function getEventCodec(): Codec< EventArgs, Event >',
]);
});

test('it exports discriminated union helpers', (t) => {
// When we render a discriminated union.
const renderMap = visit(eventTypeNode, getRenderMapVisitor());

// Then we expect the following helpers to be exported.
renderMapContains(t, renderMap, 'types/event.ts', [
"export function event( kind: 'Quit' ): GetDiscriminatedUnionVariant< EventArgs, '__kind', 'Quit' >;",
"export function event( kind: 'Write', data: GetDiscriminatedUnionVariantContent< EventArgs, '__kind', 'Write' >[ 'fields' ] ): GetDiscriminatedUnionVariant< EventArgs, '__kind', 'Write' >;",
"export function event( kind: 'Move', data: GetDiscriminatedUnionVariantContent< EventArgs, '__kind', 'Move' > ): GetDiscriminatedUnionVariant< EventArgs, '__kind', 'Move' >;",
"export function isEvent< K extends Event['__kind'] >( kind: K, value: Event ): value is Event & { __kind: K }",
]);
});

test('it exports discriminated union with custom discriminator properties', (t) => {
// When we render a discriminated union with a custom discriminator property.
const renderMap = visit(
eventTypeNode,
getRenderMapVisitor({
nameTransformers: { discriminatedUnionDiscriminator: () => `type` },
})
);

// Then we expect the discriminator property to be used instead of __kind.
renderMapContains(t, renderMap, 'types/event.ts', [
"{ discriminator: 'type' }",
"| { type: 'Quit' }",
"| { type: 'Write'; fields: readonly [string] }",
"| { type: 'Move'; x: number; y: number }",
"export function event( kind: 'Quit' ): GetDiscriminatedUnionVariant< EventArgs, 'type', 'Quit' >;",
"export function event( kind: 'Write', data: GetDiscriminatedUnionVariantContent< EventArgs, 'type', 'Write' >[ 'fields' ] ): GetDiscriminatedUnionVariant< EventArgs, 'type', 'Write' >;",
"export function event( kind: 'Move', data: GetDiscriminatedUnionVariantContent< EventArgs, 'type', 'Move' > ): GetDiscriminatedUnionVariant< EventArgs, 'type', 'Move' >;",
"export function isEvent< K extends Event['type'] >( kind: K, value: Event ): value is Event & { type: K }",
]);
});

test('it use a custom discriminator property for selected unions', (t) => {
// Given two discriminated unions A and B.
const eventTypeNodeA = definedTypeNode({ ...eventTypeNode, name: 'eventA' });
const eventTypeNodeB = definedTypeNode({ ...eventTypeNode, name: 'eventB' });

// And given we use different discriminator properties for each union.
const nameTransformers = {
discriminatedUnionDiscriminator: (union: string) =>
union === 'eventA' ? 'typeA' : `typeB`,
};

// When we render both discriminated unions.
const renderMapA = visit(
eventTypeNodeA,
getRenderMapVisitor({ nameTransformers })
);
const renderMapB = visit(
eventTypeNodeB,
getRenderMapVisitor({ nameTransformers })
);

// Then we expect discriminated union A to use 'typeA' as its discriminator property.
renderMapContains(t, renderMapA, 'types/eventA.ts', [
"{ discriminator: 'typeA' }",
"| { typeA: 'Quit' }",
"| { typeA: 'Write'; fields: readonly [string] }",
"| { typeA: 'Move'; x: number; y: number }",
]);

// And discriminated union B to use 'typeB' as its discriminator property.
renderMapContains(t, renderMapB, 'types/eventB.ts', [
"{ discriminator: 'typeB' }",
"| { typeB: 'Quit' }",
"| { typeB: 'Write'; fields: readonly [string] }",
"| { typeB: 'Move'; x: number; y: number }",
]);
});

0 comments on commit 2f79032

Please sign in to comment.