diff --git a/.changeset/light-meals-know.md b/.changeset/light-meals-know.md new file mode 100644 index 00000000..e869acda --- /dev/null +++ b/.changeset/light-meals-know.md @@ -0,0 +1,5 @@ +--- +"@metaplex-foundation/kinobi": patch +--- + +Support custom discriminator properties for discriminated union types in JS experimental renderer diff --git a/src/renderers/js-experimental/fragments/typeDiscriminatedUnionHelpers.njk b/src/renderers/js-experimental/fragments/typeDiscriminatedUnionHelpers.njk index cc1a038d..f2d4e5b9 100644 --- a/src/renderers/js-experimental/fragments/typeDiscriminatedUnionHelpers.njk +++ b/src/renderers/js-experimental/fragments/typeDiscriminatedUnionHelpers.njk @@ -1,11 +1,11 @@ // Data Enum Helpers. {% for variant in typeNode.variants %} {% if variant.kind === 'enumStructVariantTypeNode' %} - export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>): GetDiscriminatedUnionVariant<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>; + export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>): GetDiscriminatedUnionVariant<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>; {% elif variant.kind === 'enumTupleVariantTypeNode' %} - export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>['fields']): GetDiscriminatedUnionVariant<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>; + export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}', data: GetDiscriminatedUnionVariantContent<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>['fields']): GetDiscriminatedUnionVariant<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>; {% else %} - export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}'): GetDiscriminatedUnionVariant<{{ looseName }}, '__kind', '{{ getVariant(variant.name) }}'>; + export function {{ discriminatedUnionFunction }}(kind: '{{ getVariant(variant.name) }}'): GetDiscriminatedUnionVariant<{{ looseName }}, '{{ discriminatedUnionDiscriminator }}', '{{ getVariant(variant.name) }}'>; {% endif %} {% endfor %} export function {{ discriminatedUnionFunction }}( diff --git a/src/renderers/js-experimental/getTypeManifestVisitor.ts b/src/renderers/js-experimental/getTypeManifestVisitor.ts index 99225bf2..f91d4fee 100644 --- a/src/renderers/js-experimental/getTypeManifestVisitor.ts +++ b/src/renderers/js-experimental/getTypeManifestVisitor.ts @@ -7,7 +7,7 @@ import { structTypeNode, structTypeNodeFromInstructionArgumentNodes, } from '../../nodes'; -import { camelCase, jsDocblock, pipe } from '../../shared'; +import { camelCase, jsDocblock, mainCase, pipe } from '../../shared'; import { Visitor, extendVisitor, staticVisitor, visit } from '../../visitors'; import { ImportMap } from './ImportMap'; import { TypeManifest, mergeManifests } from './TypeManifest'; @@ -138,8 +138,6 @@ export function getTypeManifestVisitor(input: { visitEnumType(enumType, { self }) { const currentParentName = parentName; - parentName = null; - const encoderImports = new ImportMap(); const decoderImports = new ImportMap(); const encoderOptions: string[] = []; @@ -153,6 +151,14 @@ export function getTypeManifestVisitor(input: { decoderOptions.push(`size: ${sizeManifest.decoder.render}`); } + const discriminator = nameApi.discriminatedUnionDiscriminator( + mainCase(currentParentName?.strict ?? '') + ); + if (!isScalarEnum(enumType) && discriminator !== '__kind') { + encoderOptions.push(`discriminator: '${discriminator}'`); + decoderOptions.push(`discriminator: '${discriminator}'`); + } + const encoderOptionsAsString = encoderOptions.length > 0 ? `, { ${encoderOptions.join(', ')} }` @@ -226,7 +232,7 @@ export function getTypeManifestVisitor(input: { visitEnumEmptyVariantType(enumEmptyVariantType) { const discriminator = nameApi.discriminatedUnionDiscriminator( - enumEmptyVariantType.name + mainCase(parentName?.strict ?? '') ); const name = nameApi.discriminatedUnionVariant( enumEmptyVariantType.name @@ -248,14 +254,19 @@ export function getTypeManifestVisitor(input: { }, visitEnumStructVariantType(enumStructVariantType, { self }) { + const currentParentName = parentName; const discriminator = nameApi.discriminatedUnionDiscriminator( - enumStructVariantType.name + mainCase(currentParentName?.strict ?? '') ); const name = nameApi.discriminatedUnionVariant( enumStructVariantType.name ); const kindAttribute = `${discriminator}: "${name}"`; + + parentName = null; const structManifest = visit(enumStructVariantType.struct, self); + parentName = currentParentName; + structManifest.strictType.mapRender( (r) => `{ ${kindAttribute},${r.slice(1, -1)}}` ); @@ -268,8 +279,9 @@ export function getTypeManifestVisitor(input: { }, visitEnumTupleVariantType(enumTupleVariantType, { self }) { + const currentParentName = parentName; const discriminator = nameApi.discriminatedUnionDiscriminator( - enumTupleVariantType.name + mainCase(currentParentName?.strict ?? '') ); const name = nameApi.discriminatedUnionVariant( enumTupleVariantType.name @@ -281,7 +293,11 @@ export function getTypeManifestVisitor(input: { type: enumTupleVariantType.tuple, }), ]); + + parentName = null; const structManifest = visit(struct, self); + parentName = currentParentName; + structManifest.strictType.mapRender( (r) => `{ ${kindAttribute},${r.slice(1, -1)}}` ); diff --git a/test/renderers/js-experimental/_setup.ts b/test/renderers/js-experimental/_setup.ts index e39dc5db..c8b6e371 100644 --- a/test/renderers/js-experimental/_setup.ts +++ b/test/renderers/js-experimental/_setup.ts @@ -1,6 +1,5 @@ import { format } from '@prettier/sync'; import type { ExecutionContext } from 'ava'; -import chalk from 'chalk'; import { type Options as PrettierOptions } from 'prettier'; import type { RenderMap } from '../../../src'; @@ -32,18 +31,18 @@ export function codeContains( ) { const expectedArray = Array.isArray(expected) ? expected : [expected]; const normalizedActual = normalizeCode(actual); - expectedArray.forEach((e) => { - if (typeof e === 'string') { - const normalizeExpected = normalizeCode(e); - t.true( - normalizedActual.includes(normalizeExpected), - `The following expected code is missing from the actual content:\n\n` + - `${chalk.blue(normalizeExpected)}\n\n` + - `Actual content:\n\n` + - `${chalk.blue(normalizedActual)}` - ); + expectedArray.forEach((expectedResult) => { + if (typeof expectedResult === 'string') { + const stringAsRegex = escapeRegex(expectedResult) + // Transform spaces between words into required whitespace. + .replace(/(\w)\s+(\w)/g, '$1\\s+$2') + // Do it again for single-character words — e.g. "as[ ]a[ ]token". + .replace(/(\w)\s+(\w)/g, '$1\\s+$2') + // Transform other spaces into optional whitespace. + .replace(/\s+/g, '\\s*'); + t.regex(normalizedActual, new RegExp(stringAsRegex)); } else { - t.regex(normalizedActual, e); + t.regex(normalizedActual, expectedResult); } }); } @@ -63,7 +62,7 @@ export function codeContainsImports( actual: string, expectedImports: Record ) { - const normalizedActual = normalizeCode(actual); + const normalizedActual = inlineCode(actual); const importPairs = Object.entries(expectedImports).flatMap( ([key, value]) => { return value.map((v) => [key, v] as const); @@ -82,9 +81,15 @@ function normalizeCode(code: string) { try { code = format(code, PRETTIER_OPTIONS); } catch (e) {} + return code.trim(); +} - return code +function inlineCode(code: string) { + return normalizeCode(code) .replace(/\s+/g, ' ') - .replace(/\s*(\W)\s*/g, '$1') - .trim(); + .replace(/\s*(\W)\s*/g, '$1'); +} + +function escapeRegex(stringAsRegex: string) { + return stringAsRegex.replace(/[-\/\\^$*+?.()|[\]{}]/g, '\\$&'); } diff --git a/test/renderers/js-experimental/programsPage.test.ts b/test/renderers/js-experimental/programsPage.test.ts index 8cf1fab3..5d707945 100644 --- a/test/renderers/js-experimental/programsPage.test.ts +++ b/test/renderers/js-experimental/programsPage.test.ts @@ -50,7 +50,7 @@ test('it renders an enum of all available accounts for a program', (t) => { // Then we expect the following program account enum. renderMapContains(t, renderMap, 'programs/splToken.ts', [ - 'export enum SplTokenAccount { Mint, Token };', + 'export enum SplTokenAccount { Mint, Token }', ]); }); @@ -91,11 +91,11 @@ test('it renders an function that identifies accounts in a program', (t) => { // Then we expect the following identifier function to be rendered. // Notice it does not include the `mint` account because it has no discriminators. renderMapContains(t, renderMap, 'programs/splToken.ts', [ - `export function identifySplTokenAccount(account: { data: Uint8Array } | Uint8Array): SplTokenAccount {\n` + - `const data = account instanceof Uint8Array ? account : account.data;\n` + - `if (memcmp(data, getU8Encoder().encode(5), 0)) { return SplTokenAccount.Metadata; }\n` + - `if (data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4)) { return SplTokenAccount.Token; }\n` + - `throw new Error('The provided account could not be identified as a splToken account.')\n` + + `export function identifySplTokenAccount( account: { data: Uint8Array } | Uint8Array ): SplTokenAccount { ` + + `const data = account instanceof Uint8Array ? account : account.data; ` + + `if ( memcmp(data, getU8Encoder().encode(5), 0) ) { return SplTokenAccount.Metadata; } ` + + `if ( data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4) ) { return SplTokenAccount.Token; } ` + + `throw new Error ( 'The provided account could not be identified as a splToken account.' ); ` + `}`, ]); @@ -122,7 +122,7 @@ test('it renders an enum of all available instructions for a program', (t) => { // Then we expect the following program instruction enum. renderMapContains(t, renderMap, 'programs/splToken.ts', [ - 'export enum SplTokenInstruction { MintTokens, TransferTokens, UpdateAuthority };', + 'export enum SplTokenInstruction { MintTokens, TransferTokens, UpdateAuthority }', ]); }); @@ -163,11 +163,11 @@ test('it renders an function that identifies instructions in a program', (t) => // Then we expect the following identifier function to be rendered. // Notice it does not include the `updateAuthority` instruction because it has no discriminators. renderMapContains(t, renderMap, 'programs/splToken.ts', [ - `export function identifySplTokenInstruction(instruction: { data: Uint8Array } | Uint8Array): SplTokenInstruction {\n` + - `const data = instruction instanceof Uint8Array ? instruction : instruction.data;\n` + - `if (memcmp(data, getU8Encoder().encode(1), 0)) { return SplTokenInstruction.MintTokens; }\n` + - `if (data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4)) { return SplTokenInstruction.TransferTokens; }\n` + - `throw new Error('The provided instruction could not be identified as a splToken instruction.')\n` + + `export function identifySplTokenInstruction ( instruction: { data: Uint8Array } | Uint8Array ): SplTokenInstruction { ` + + `const data = instruction instanceof Uint8Array ? instruction : instruction.data; ` + + `if ( memcmp(data, getU8Encoder().encode(1), 0) ) { return SplTokenInstruction.MintTokens; } ` + + `if ( data.length === 72 && memcmp(data, new Uint8Array([1, 2, 3]), 4) ) { return SplTokenInstruction.TransferTokens; } ` + + `throw new Error( 'The provided instruction could not be identified as a splToken instruction.' ); ` + `}`, ]); @@ -232,8 +232,10 @@ test('it checks the discriminator of sub-instructions before their parents.', (t // Then we expect the sub-instruction condition to be rendered before the parent instruction condition. renderMapContains(t, renderMap, 'programs/splToken.ts', [ - `if (memcmp(data, getU8Encoder().encode(1), 0) && memcmp(data, getU32Encoder().encode(1), 1)) { return SplTokenInstruction.MintTokensV1; }\n` + - `if (memcmp(data, getU8Encoder().encode(1), 0)) { return SplTokenInstruction.MintTokens; }`, + `if ( memcmp(data, getU8Encoder().encode(1), 0) && memcmp(data, getU32Encoder().encode(1), 1) ) ` + + `{ return SplTokenInstruction.MintTokensV1; } ` + + `if ( memcmp(data, getU8Encoder().encode(1), 0) ) ` + + `{ return SplTokenInstruction.MintTokens; }`, ]); }); @@ -254,9 +256,9 @@ test('it renders a parsed union type of all available instructions for a program // Then we expect the following program parsed instruction union type. renderMapContains(t, renderMap, 'programs/splToken.ts', [ - "export type ParsedSplTokenInstruction=", - '|({instructionType: SplTokenInstruction.MintTokens;} & ParsedMintTokensInstruction)', - '|({instructionType: SplTokenInstruction.TransferTokens;} & ParsedTransferTokensInstruction)', - '|({instructionType: SplTokenInstruction.UpdateAuthority;} & ParsedUpdateAuthorityInstruction)', + "export type ParsedSplTokenInstruction < TProgram extends string = 'TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA' >", + '| ({ instructionType: SplTokenInstruction.MintTokens; } & ParsedMintTokensInstruction)', + '| ({ instructionType: SplTokenInstruction.TransferTokens; } & ParsedTransferTokensInstruction)', + '| ({ instructionType: SplTokenInstruction.UpdateAuthority; } & ParsedUpdateAuthorityInstruction)', ]); }); diff --git a/test/renderers/js-experimental/types/discriminatedUnions.test.ts b/test/renderers/js-experimental/types/discriminatedUnions.test.ts new file mode 100644 index 00000000..5c534260 --- /dev/null +++ b/test/renderers/js-experimental/types/discriminatedUnions.test.ts @@ -0,0 +1,130 @@ +import test from 'ava'; +import { + definedTypeNode, + enumEmptyVariantTypeNode, + enumStructVariantTypeNode, + enumTupleVariantTypeNode, + enumTypeNode, + numberTypeNode, + stringTypeNode, + structFieldTypeNode, + structTypeNode, + tupleTypeNode, + visit, +} from '../../../../src'; +import { getRenderMapVisitor } from '../../../../src/renderers/js-experimental/getRenderMapVisitor'; +import { renderMapContains } from '../_setup'; + +// Given the following event discriminated union. +const eventTypeNode = definedTypeNode({ + name: 'event', + type: enumTypeNode([ + enumEmptyVariantTypeNode('quit'), + enumTupleVariantTypeNode('write', tupleTypeNode([stringTypeNode()])), + enumStructVariantTypeNode( + 'move', + structTypeNode([ + structFieldTypeNode({ name: 'x', type: numberTypeNode('u32') }), + structFieldTypeNode({ name: 'y', type: numberTypeNode('u32') }), + ]) + ), + ]), +}); + +test('it exports discriminated union types', (t) => { + // When we render a discriminated union. + const renderMap = visit(eventTypeNode, getRenderMapVisitor()); + + // Then we expect the following types to be exported. + renderMapContains(t, renderMap, 'types/event.ts', [ + 'export type Event =', + "| { __kind: 'Quit' }", + "| { __kind: 'Write'; fields: readonly [string] }", + "| { __kind: 'Move'; x: number; y: number }", + ]); +}); + +test('it exports discriminated union codecs', (t) => { + // When we render a discriminated union. + const renderMap = visit(eventTypeNode, getRenderMapVisitor()); + + // Then we expect the following codec functions to be exported. + renderMapContains(t, renderMap, 'types/event.ts', [ + 'export function getEventEncoder(): Encoder< EventArgs >', + 'export function getEventDecoder(): Decoder< Event >', + 'export function getEventCodec(): Codec< EventArgs, Event >', + ]); +}); + +test('it exports discriminated union helpers', (t) => { + // When we render a discriminated union. + const renderMap = visit(eventTypeNode, getRenderMapVisitor()); + + // Then we expect the following helpers to be exported. + renderMapContains(t, renderMap, 'types/event.ts', [ + "export function event( kind: 'Quit' ): GetDiscriminatedUnionVariant< EventArgs, '__kind', 'Quit' >;", + "export function event( kind: 'Write', data: GetDiscriminatedUnionVariantContent< EventArgs, '__kind', 'Write' >[ 'fields' ] ): GetDiscriminatedUnionVariant< EventArgs, '__kind', 'Write' >;", + "export function event( kind: 'Move', data: GetDiscriminatedUnionVariantContent< EventArgs, '__kind', 'Move' > ): GetDiscriminatedUnionVariant< EventArgs, '__kind', 'Move' >;", + "export function isEvent< K extends Event['__kind'] >( kind: K, value: Event ): value is Event & { __kind: K }", + ]); +}); + +test('it exports discriminated union with custom discriminator properties', (t) => { + // When we render a discriminated union with a custom discriminator property. + const renderMap = visit( + eventTypeNode, + getRenderMapVisitor({ + nameTransformers: { discriminatedUnionDiscriminator: () => `type` }, + }) + ); + + // Then we expect the discriminator property to be used instead of __kind. + renderMapContains(t, renderMap, 'types/event.ts', [ + "{ discriminator: 'type' }", + "| { type: 'Quit' }", + "| { type: 'Write'; fields: readonly [string] }", + "| { type: 'Move'; x: number; y: number }", + "export function event( kind: 'Quit' ): GetDiscriminatedUnionVariant< EventArgs, 'type', 'Quit' >;", + "export function event( kind: 'Write', data: GetDiscriminatedUnionVariantContent< EventArgs, 'type', 'Write' >[ 'fields' ] ): GetDiscriminatedUnionVariant< EventArgs, 'type', 'Write' >;", + "export function event( kind: 'Move', data: GetDiscriminatedUnionVariantContent< EventArgs, 'type', 'Move' > ): GetDiscriminatedUnionVariant< EventArgs, 'type', 'Move' >;", + "export function isEvent< K extends Event['type'] >( kind: K, value: Event ): value is Event & { type: K }", + ]); +}); + +test('it use a custom discriminator property for selected unions', (t) => { + // Given two discriminated unions A and B. + const eventTypeNodeA = definedTypeNode({ ...eventTypeNode, name: 'eventA' }); + const eventTypeNodeB = definedTypeNode({ ...eventTypeNode, name: 'eventB' }); + + // And given we use different discriminator properties for each union. + const nameTransformers = { + discriminatedUnionDiscriminator: (union: string) => + union === 'eventA' ? 'typeA' : `typeB`, + }; + + // When we render both discriminated unions. + const renderMapA = visit( + eventTypeNodeA, + getRenderMapVisitor({ nameTransformers }) + ); + const renderMapB = visit( + eventTypeNodeB, + getRenderMapVisitor({ nameTransformers }) + ); + + // Then we expect discriminated union A to use 'typeA' as its discriminator property. + renderMapContains(t, renderMapA, 'types/eventA.ts', [ + "{ discriminator: 'typeA' }", + "| { typeA: 'Quit' }", + "| { typeA: 'Write'; fields: readonly [string] }", + "| { typeA: 'Move'; x: number; y: number }", + ]); + + // And discriminated union B to use 'typeB' as its discriminator property. + renderMapContains(t, renderMapB, 'types/eventB.ts', [ + "{ discriminator: 'typeB' }", + "| { typeB: 'Quit' }", + "| { typeB: 'Write'; fields: readonly [string] }", + "| { typeB: 'Move'; x: number; y: number }", + ]); +});