From 75b5d46fc0303e2197c18f1b9b7a34428dc31a72 Mon Sep 17 00:00:00 2001 From: Uri Shaked Date: Fri, 9 Mar 2018 01:47:29 +0200 Subject: [PATCH] feat: find more types using Type Inference --- src/apply-types.ts | 59 ++++++++++++++++++++++++++++++----- src/instrument.ts | 19 +++++++++-- src/integration.spec.ts | 50 ++++++++++++++++++++++++++++- src/replacement.ts | 13 +++++--- src/test-utils/transpile.ts | 20 +++++++----- src/type-collector-snippet.ts | 18 ++++++++--- 6 files changed, 153 insertions(+), 26 deletions(-) diff --git a/src/apply-types.ts b/src/apply-types.ts index 2239a01..7aee26b 100644 --- a/src/apply-types.ts +++ b/src/apply-types.ts @@ -1,10 +1,14 @@ import * as fs from 'fs'; import * as path from 'path'; +import * as ts from 'typescript'; import { IExtraOptions } from './instrument'; import { applyReplacements, Replacement } from './replacement'; +import { ISourceLocation } from './type-collector-snippet'; -export type ICollectedTypeInfo = Array<[string, number, string[], IExtraOptions]>; +export type ICollectedTypeInfo = Array< + [string, number, Array<[string | undefined, ISourceLocation | undefined]>, IExtraOptions] +>; export interface IApplyTypesOptions { /** @@ -18,20 +22,57 @@ export interface IApplyTypesOptions { * If given, all the file paths in the collected type info will be resolved relative to this directory. */ rootDir?: string; + + /** + * Options for the TypeScript compiler. + */ + tsConfig?: ts.CompilerOptions; + + tsCompilerHost?: ts.CompilerHost; +} + +function findType(program?: ts.Program, name?: string, sourcePos?: ISourceLocation) { + if (program && sourcePos) { + const [sourceName, sourceOffset] = sourcePos; + const typeChecker = program.getTypeChecker(); + let foundType: string | null = null; + function visit(node: ts.Node) { + if (node.getStart() === sourceOffset) { + const type = typeChecker.getTypeAtLocation(node); + foundType = typeChecker.typeToString(type); + } + ts.forEachChild(node, visit); + } + const sourceFile = program.getSourceFile(sourceName); + visit(sourceFile); + if (foundType) { + return foundType; + } + } + return name; } -export function applyTypesToFile(source: string, typeInfo: ICollectedTypeInfo, options: IApplyTypesOptions) { +export function applyTypesToFile( + source: string, + typeInfo: ICollectedTypeInfo, + options: IApplyTypesOptions, + program?: ts.Program, +) { const replacements = []; const prefix = options.prefix || ''; for (const [, pos, types, opts] of typeInfo) { const isOptional = source[pos - 1] === '?'; - let sortedTypes = types.sort(); + let sortedTypes = types + .map(([name, sourcePos]) => findType(program, name, sourcePos)) + .filter((t) => t) + .sort(); if (isOptional) { sortedTypes = sortedTypes.filter((t) => t !== 'undefined'); - if (sortedTypes.length === 0) { - continue; - } } + if (sortedTypes.length === 0) { + continue; + } + let suffix = ''; if (opts && opts.parens) { replacements.push(Replacement.insert(opts.parens[0], '(')); @@ -44,6 +85,10 @@ export function applyTypesToFile(source: string, typeInfo: ICollectedTypeInfo, o export function applyTypes(typeInfo: ICollectedTypeInfo, options: IApplyTypesOptions = {}) { const files: { [key: string]: typeof typeInfo } = {}; + let program: ts.Program | undefined; + if (options.tsConfig) { + program = ts.createProgram(['c:\\test.ts'], options.tsConfig, options.tsCompilerHost); + } for (const entry of typeInfo) { const file = entry[0]; if (!files[file]) { @@ -54,6 +99,6 @@ export function applyTypes(typeInfo: ICollectedTypeInfo, options: IApplyTypesOpt for (const file of Object.keys(files)) { const filePath = options.rootDir ? path.join(options.rootDir, file) : file; const source = fs.readFileSync(filePath, 'utf-8'); - fs.writeFileSync(filePath, applyTypesToFile(source, files[file], options)); + fs.writeFileSync(filePath, applyTypesToFile(source, files[file], options, program)); } } diff --git a/src/instrument.ts b/src/instrument.ts index a94171a..7ae0562 100644 --- a/src/instrument.ts +++ b/src/instrument.ts @@ -45,7 +45,7 @@ function visit(node: ts.Node, replacements: Replacement[], fileName: string) { const instrumentExpr = `$_$twiz(${params.join(',')})`; if (isShortArrow) { replacements.push(Replacement.insert(node.body.getStart(), `(${instrumentExpr},`)); - replacements.push(Replacement.insert(node.body.getEnd(), `)`)); + replacements.push(Replacement.insert(node.body.getEnd(), `)`, 10)); } else { replacements.push(Replacement.insert(node.body.getStart() + 1, `${instrumentExpr};`)); } @@ -53,6 +53,15 @@ function visit(node: ts.Node, replacements: Replacement[], fileName: string) { } } + if (ts.isCallExpression(node) && node.expression.getText() !== 'require.context') { + for (const arg of node.arguments) { + if (!ts.isStringLiteral(arg) && !ts.isNumericLiteral(arg)) { + replacements.push(Replacement.insert(arg.getStart(), '$_$twiz.track(')); + replacements.push(Replacement.insert(arg.getEnd(), `,${JSON.stringify(fileName)},${arg.getStart()})`)); + } + } + } + if ( ts.isPropertyDeclaration(node) && ts.isIdentifier(node.name) && @@ -91,8 +100,12 @@ function visit(node: ts.Node, replacements: Replacement[], fileName: string) { node.forEachChild((child) => visit(child, replacements, fileName)); } -const declaration = - 'declare function $_$twiz(name: string, value: any, pos: number, filename: string, opts: any): void;\n'; +const declaration = ` + declare function $_$twiz(name: string, value: any, pos: number, filename: string, opts: any): void; + declare namespace $_$twiz { + function track(v: T, p: number, f: string): T; + } +`; export function instrument(source: string, fileName: string) { const sourceFile = ts.createSourceFile(fileName, source, ts.ScriptTarget.Latest, true); diff --git a/src/integration.spec.ts b/src/integration.spec.ts index 5d40f8a..5e81642 100644 --- a/src/integration.spec.ts +++ b/src/integration.spec.ts @@ -2,7 +2,7 @@ import * as fs from 'fs'; import * as ts from 'typescript'; import * as vm from 'vm'; -import { transpileSource } from './test-utils/transpile'; +import { transpileSource, virtualCompilerHost } from './test-utils/transpile'; const mockFs = { readFileSync: jest.fn(fs.readFileSync), @@ -26,6 +26,10 @@ function typeWiz(input: string, typeCheck = false, options?: IApplyTypesOptions) mockFs.readFileSync.mockReturnValue(input); mockFs.writeFileSync.mockImplementationOnce(() => 0); + if (options && options.tsConfig) { + options.tsCompilerHost = virtualCompilerHost(input, 'c:/test.ts'); + } + applyTypes(collectedTypes, options); if (mockFs.writeFileSync.mock.calls.length) { @@ -173,6 +177,50 @@ describe('function parameters', () => { optional() + optional(10); `); }); + + it('should use TypeScript inference to find argument types', () => { + const input = ` + function f(a) { + } + + const arr: string[] = []; + f(arr); + `; + + expect(typeWiz(input, false, { tsConfig: {} })).toBe(` + function f(a: string[]) { + } + + const arr: string[] = []; + f(arr); + `); + }); + + it('should discover generic types using Type Inference', () => { + const input = ` + function f(a) { + return a; + } + + const promise = Promise.resolve(15); + f(promise); + `; + + expect( + typeWiz(input, false, { + tsConfig: { + target: ts.ScriptTarget.ES2015, + }, + }), + ).toBe(` + function f(a: Promise) { + return a; + } + + const promise = Promise.resolve(15); + f(promise); + `); + }); }); describe('class fields', () => { diff --git a/src/replacement.ts b/src/replacement.ts index be858ea..9a3dc7f 100644 --- a/src/replacement.ts +++ b/src/replacement.ts @@ -1,17 +1,22 @@ export class Replacement { - public static insert(pos: number, text: string) { - return new Replacement(pos, pos, text); + public static insert(pos: number, text: string, priority = 0) { + return new Replacement(pos, pos, text, priority); } public static delete(start: number, end: number) { return new Replacement(start, end, ''); } - constructor(readonly start: number, readonly end: number, readonly text = '') {} + constructor(readonly start: number, readonly end: number, readonly text = '', readonly priority = 0) {} } export function applyReplacements(source: string, replacements: Replacement[]) { - replacements = replacements.sort((r1, r2) => (r2.end !== r1.end ? r2.end - r1.end : r2.start - r1.start)); + replacements = replacements.sort( + (r1, r2) => + r2.end !== r1.end + ? r2.end - r1.end + : r1.start !== r2.start ? r2.start - r1.start : r1.priority - r2.priority, + ); for (const replacement of replacements) { source = source.slice(0, replacement.start) + replacement.text + source.slice(replacement.end); } diff --git a/src/test-utils/transpile.ts b/src/test-utils/transpile.ts index 5d19a78..d3037fc 100644 --- a/src/test-utils/transpile.ts +++ b/src/test-utils/transpile.ts @@ -1,20 +1,26 @@ import * as ts from 'typescript'; -// similar to ts.transpile(), but also does type checking and throws in case of error -export function transpileSource(input: string, filename: string) { - const compilerOptions = { - target: ts.ScriptTarget.ES2015, - }; - +export function virtualCompilerHost(input: string, filename: string, compilerOptions: ts.CompilerOptions = {}) { const host = ts.createCompilerHost(compilerOptions); const old = host.getSourceFile; - host.getSourceFile = (name: string, target: ts.ScriptTarget, ...args) => { + host.getSourceFile = (name: string, target: ts.ScriptTarget, ...args: any[]) => { if (name === filename) { return ts.createSourceFile(filename, input, target, true); } return old.call(host, name, target, ...args); }; + return host; +} + +// similar to ts.transpile(), but also does type checking and throws in case of error +export function transpileSource(input: string, filename: string) { + const compilerOptions = { + target: ts.ScriptTarget.ES2015, + }; + + const host = virtualCompilerHost(input, filename, compilerOptions); + let outputText; host.writeFile = (name: string, value: string) => { if (name.endsWith('.js')) { diff --git a/src/type-collector-snippet.ts b/src/type-collector-snippet.ts index 39262d0..c10b5af 100644 --- a/src/type-collector-snippet.ts +++ b/src/type-collector-snippet.ts @@ -1,5 +1,7 @@ class NestError extends Error {} +export type ISourceLocation = [string, number]; /* filename, offset */ + interface IKey { filename: string; pos: number; @@ -73,17 +75,18 @@ export function getTypeName(value: any, nest = 0): string | null { } const logs: { [key: string]: Set } = {}; +const trackedObjects = new WeakMap(); export function $_$twiz(name: string, value: any, pos: number, filename: string, opts: any) { + const objectDeclaration = trackedObjects.get(value); const index = JSON.stringify({ filename, pos, opts } as IKey); try { const typeName = getTypeName(value); if (!logs[index]) { logs[index] = new Set(); } - if (typeName) { - logs[index].add(typeName); - } + const typeSpec = JSON.stringify([typeName, objectDeclaration]); + logs[index].add(typeSpec); } catch (e) { if (e instanceof NestError) { // simply ignore the type @@ -98,7 +101,14 @@ export namespace $_$twiz { export const get = () => { return Object.keys(logs).map((key) => { const { filename, pos, opts } = JSON.parse(key) as IKey; - return [filename, pos, Array.from(logs[key]), opts] as [string, number, string[], any]; + const typeOptions = Array.from(logs[key]).map((v) => JSON.parse(v)); + return [filename, pos, typeOptions, opts] as [string, number, string[], any]; }); }; + export const track = (value: any, filename: string, offset: number) => { + if (value && (typeof value === 'object' || typeof value === 'function')) { + trackedObjects.set(value, [filename, offset]); + } + return value; + }; }