From 6d38e16bc550e1d493dd72151597d2b6c19a756e Mon Sep 17 00:00:00 2001 From: esdras-santos Date: Thu, 6 Apr 2023 12:52:29 -0300 Subject: [PATCH 1/6] warplib/math functions replaced by his respective operatores --- .../mathsOperationToFunction.ts | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/passes/builtinHandler/mathsOperationToFunction.ts b/src/passes/builtinHandler/mathsOperationToFunction.ts index 457c243e5..49c17f39d 100644 --- a/src/passes/builtinHandler/mathsOperationToFunction.ts +++ b/src/passes/builtinHandler/mathsOperationToFunction.ts @@ -45,34 +45,6 @@ export class MathsOperationToFunction extends ASTMapper { visitBinaryOperation(node: BinaryOperation, ast: AST): void { this.commonVisit(node, ast); - const operatorMap: Map void> = new Map([ - ['+', () => functionaliseAdd(node, this.inUncheckedBlock, ast)], - ['-', () => functionaliseSub(node, this.inUncheckedBlock, ast)], - ['*', () => functionaliseMul(node, this.inUncheckedBlock, ast)], - ['/', () => functionaliseDiv(node, this.inUncheckedBlock, ast)], - ['%', () => functionaliseMod(node, ast)], - ['**', () => functionaliseExp(node, this.inUncheckedBlock, ast)], - ['==', () => functionaliseEq(node, ast)], - ['!=', () => functionaliseNeq(node, ast)], - ['>=', () => functionaliseGe(node, ast)], - ['>', () => functionaliseGt(node, ast)], - ['<=', () => functionaliseLe(node, ast)], - ['<', () => functionaliseLt(node, ast)], - ['&', () => functionaliseBitwiseAnd(node, ast)], - ['|', () => functionaliseBitwiseOr(node, ast)], - ['^', () => functionaliseXor(node, ast)], - ['<<', () => functionaliseShl(node, ast)], - ['>>', () => functionaliseShr(node, ast)], - ['&&', () => functionaliseAnd(node, ast)], - ['||', () => functionaliseOr(node, ast)], - ]); - - const thunk = operatorMap.get(node.operator); - if (thunk === undefined) { - throw new NotSupportedYetError(`${node.operator} not supported yet`); - } - - thunk(); } visitUnaryOperation(node: UnaryOperation, ast: AST): void { From 6f04a1dad5305fa8272566af80683be1bbca0fc0 Mon Sep 17 00:00:00 2001 From: esdras-santos Date: Thu, 6 Apr 2023 17:28:20 -0300 Subject: [PATCH 2/6] replace warplib/math functions for his respective operatores --- .../mathsOperationToFunction.ts | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/passes/builtinHandler/mathsOperationToFunction.ts b/src/passes/builtinHandler/mathsOperationToFunction.ts index 49c17f39d..b921ea312 100644 --- a/src/passes/builtinHandler/mathsOperationToFunction.ts +++ b/src/passes/builtinHandler/mathsOperationToFunction.ts @@ -11,27 +11,8 @@ import { NotSupportedYetError } from '../../utils/errors'; import { createCallToFunction } from '../../utils/functionGeneration'; import { WARPLIB_MATHS } from '../../utils/importPaths'; import { createNumberLiteral, createUint256TypeName } from '../../utils/nodeTemplates'; -import { functionaliseAdd } from '../../warplib/implementations/maths/add'; -import { functionaliseAnd } from '../../warplib/implementations/maths/and'; -import { functionaliseBitwiseAnd } from '../../warplib/implementations/maths/bitwiseAnd'; import { functionaliseBitwiseNot } from '../../warplib/implementations/maths/bitwiseNot'; -import { functionaliseBitwiseOr } from '../../warplib/implementations/maths/bitwiseOr'; -import { functionaliseDiv } from '../../warplib/implementations/maths/div'; -import { functionaliseEq } from '../../warplib/implementations/maths/eq'; -import { functionaliseExp } from '../../warplib/implementations/maths/exp'; -import { functionaliseGe } from '../../warplib/implementations/maths/ge'; -import { functionaliseGt } from '../../warplib/implementations/maths/gt'; -import { functionaliseLe } from '../../warplib/implementations/maths/le'; -import { functionaliseLt } from '../../warplib/implementations/maths/lt'; -import { functionaliseMod } from '../../warplib/implementations/maths/mod'; -import { functionaliseMul } from '../../warplib/implementations/maths/mul'; import { functionaliseNegate } from '../../warplib/implementations/maths/negate'; -import { functionaliseNeq } from '../../warplib/implementations/maths/neq'; -import { functionaliseOr } from '../../warplib/implementations/maths/or'; -import { functionaliseShl } from '../../warplib/implementations/maths/shl'; -import { functionaliseShr } from '../../warplib/implementations/maths/shr'; -import { functionaliseSub } from '../../warplib/implementations/maths/sub'; -import { functionaliseXor } from '../../warplib/implementations/maths/xor'; /* Note we also include mulmod and add mod here */ export class MathsOperationToFunction extends ASTMapper { From e07fe31ce0a75f6d1a049552584be04c7f0deb49 Mon Sep 17 00:00:00 2001 From: esdras-santos Date: Thu, 6 Apr 2023 17:29:51 -0300 Subject: [PATCH 3/6] removed input checkers for integer types --- src/passes/argBoundChecker.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/passes/argBoundChecker.ts b/src/passes/argBoundChecker.ts index a9836fe76..4dd3331b9 100644 --- a/src/passes/argBoundChecker.ts +++ b/src/passes/argBoundChecker.ts @@ -31,7 +31,7 @@ export class ArgBoundChecker extends ASTMapper { if (isExternallyVisible(node) && node.vBody !== undefined) { node.vParameters.vParameters.forEach((decl) => { const type = safeGetNodeType(decl, ast.inference); - if (checkableType(type)) { + if (checkableType(type) && !(type instanceof IntType)) { const functionCall = ast.getUtilFuncGen(node).boundChecks.inputCheck.gen(decl, type); this.insertFunctionCall(node, functionCall, ast); } From 48478ffde97f285dddc8d8ab00caaaecd84b3726 Mon Sep 17 00:00:00 2001 From: esdras-santos Date: Fri, 7 Apr 2023 13:13:45 -0300 Subject: [PATCH 4/6] inputArgChecker and redundant warplib calls removed --- src/cairoUtilFuncGen/export.ts | 1 - src/cairoUtilFuncGen/index.ts | 9 - src/cairoUtilFuncGen/inputArgCheck/export.ts | 1 - .../inputArgCheck/inputCheck.ts | 262 ----------------- src/passes/argBoundChecker.ts | 44 +-- .../references/externalReturnReceiver.ts | 10 - src/warplib/generateWarplib.ts | 39 --- src/warplib/implementations/maths/add.ts | 137 --------- src/warplib/implementations/maths/and.ts | 7 - .../implementations/maths/bitwiseAnd.ts | 7 - .../implementations/maths/bitwiseOr.ts | 7 - src/warplib/implementations/maths/div.ts | 125 -------- src/warplib/implementations/maths/eq.ts | 7 - src/warplib/implementations/maths/exp.ts | 220 -------------- src/warplib/implementations/maths/ge.ts | 40 --- src/warplib/implementations/maths/gt.ts | 40 --- src/warplib/implementations/maths/le.ts | 61 ---- src/warplib/implementations/maths/lt.ts | 43 --- src/warplib/implementations/maths/mod.ts | 55 ---- src/warplib/implementations/maths/mul.ts | 247 ---------------- src/warplib/implementations/maths/neq.ts | 7 - src/warplib/implementations/maths/or.ts | 7 - src/warplib/implementations/maths/shl.ts | 118 -------- src/warplib/implementations/maths/shr.ts | 269 ------------------ src/warplib/implementations/maths/sub.ts | 157 ---------- src/warplib/implementations/maths/xor.ts | 7 - 26 files changed, 1 insertion(+), 1926 deletions(-) delete mode 100644 src/cairoUtilFuncGen/inputArgCheck/export.ts delete mode 100644 src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts delete mode 100644 src/warplib/implementations/maths/add.ts delete mode 100644 src/warplib/implementations/maths/and.ts delete mode 100644 src/warplib/implementations/maths/bitwiseAnd.ts delete mode 100644 src/warplib/implementations/maths/bitwiseOr.ts delete mode 100644 src/warplib/implementations/maths/div.ts delete mode 100644 src/warplib/implementations/maths/eq.ts delete mode 100644 src/warplib/implementations/maths/exp.ts delete mode 100644 src/warplib/implementations/maths/ge.ts delete mode 100644 src/warplib/implementations/maths/gt.ts delete mode 100644 src/warplib/implementations/maths/le.ts delete mode 100644 src/warplib/implementations/maths/lt.ts delete mode 100644 src/warplib/implementations/maths/mod.ts delete mode 100644 src/warplib/implementations/maths/mul.ts delete mode 100644 src/warplib/implementations/maths/neq.ts delete mode 100644 src/warplib/implementations/maths/or.ts delete mode 100644 src/warplib/implementations/maths/shl.ts delete mode 100644 src/warplib/implementations/maths/shr.ts delete mode 100644 src/warplib/implementations/maths/sub.ts delete mode 100644 src/warplib/implementations/maths/xor.ts diff --git a/src/cairoUtilFuncGen/export.ts b/src/cairoUtilFuncGen/export.ts index 2f0205905..773c84489 100644 --- a/src/cairoUtilFuncGen/export.ts +++ b/src/cairoUtilFuncGen/export.ts @@ -1,7 +1,6 @@ export * from './serialisation'; export * from './base'; export * from './memory/export'; -export * from './inputArgCheck/export'; export * from './calldata/export'; export * from './utils/export'; export * from './storage/export'; diff --git a/src/cairoUtilFuncGen/index.ts b/src/cairoUtilFuncGen/index.ts index f6848e964..74bc2eea3 100644 --- a/src/cairoUtilFuncGen/index.ts +++ b/src/cairoUtilFuncGen/index.ts @@ -1,5 +1,4 @@ import { AST } from '../ast/ast'; -import { InputCheckGen } from './inputArgCheck/inputCheck'; import { MemoryArrayLiteralGen } from './memory/arrayLiteral'; import { MemoryDynArrayLengthGen } from './memory/memoryDynArrayLength'; import { MemoryMemberAccessGen } from './memory/memoryMemberAccess'; @@ -85,10 +84,6 @@ export class CairoUtilFuncGen { toStorage: StorageToStorageGen; write: StorageWriteGen; }; - boundChecks: { - inputCheck: InputCheckGen; - enums: EnumInputCheck; - }; events: { index: IndexEncode; event: EventFunction; @@ -177,10 +172,6 @@ export class CairoUtilFuncGen { toStorage: storageToStorage, write: storageWrite, }; - this.boundChecks = { - inputCheck: new InputCheckGen(ast, sourceUnit), - enums: new EnumInputCheck(ast, sourceUnit), - }; this.calldata = { dynArrayStructConstructor: externalDynArrayStructConstructor, toMemory: new CallDataToMemoryGen(ast, sourceUnit), diff --git a/src/cairoUtilFuncGen/inputArgCheck/export.ts b/src/cairoUtilFuncGen/inputArgCheck/export.ts deleted file mode 100644 index 5bebac706..000000000 --- a/src/cairoUtilFuncGen/inputArgCheck/export.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './inputCheck'; diff --git a/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts b/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts deleted file mode 100644 index 05e6747f5..000000000 --- a/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts +++ /dev/null @@ -1,262 +0,0 @@ -import assert from 'assert'; -import { - ArrayType, - BoolType, - BytesType, - DataLocation, - EnumDefinition, - Expression, - FixedBytesType, - FunctionCall, - FunctionStateMutability, - generalizeType, - IntType, - StringType, - StructDefinition, - TypeNode, - UserDefinedType, - VariableDeclaration, -} from 'solc-typed-ast'; -import { CairoFunctionDefinition, FunctionStubKind } from '../../ast/cairoNodes'; -import { printTypeNode } from '../../utils/astPrinter'; -import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; -import { createIdentifier } from '../../utils/nodeTemplates'; -import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; -import { - delegateBasedOnType, - GeneratedFunctionInfo, - locationIfComplexType, - StringIndexedFuncGen, -} from '../base'; -import { - checkableType, - getElementType, - isAddressType, - isDynamicArray, - safeGetNodeType, -} from '../../utils/nodeTypeProcessing'; -import { cloneASTNode } from '../../utils/cloning'; -import { IS_LE_FELT, NARROW_SAFE, WARPLIB_MATHS } from '../../utils/importPaths'; - -const IMPLICITS = '{range_check_ptr : felt}'; - -export class InputCheckGen extends StringIndexedFuncGen { - public gen(nodeInput: VariableDeclaration | Expression, typeToCheck: TypeNode): FunctionCall { - let functionInput; - let isUint256 = false; - if (nodeInput instanceof VariableDeclaration) { - functionInput = createIdentifier(nodeInput, this.ast); - } else { - functionInput = cloneASTNode(nodeInput, this.ast); - const inputType = safeGetNodeType(nodeInput, this.ast.inference); - this.ast.setContextRecursive(functionInput); - isUint256 = inputType instanceof IntType && inputType.nBits === 256; - } - - const funcDef = this.getOrCreateFuncDef(typeToCheck, isUint256); - return createCallToFunction(funcDef, [functionInput], this.ast); - } - - private getOrCreateFuncDef(type: TypeNode, takesUint256: boolean): CairoFunctionDefinition { - const key = type.pp(); - const value = this.generatedFunctionsDef.get(key); - if (value !== undefined) { - return value; - } - - if (type instanceof FixedBytesType) - return this.requireImport( - [...WARPLIB_MATHS, 'external_input_check_ints'], - `warp_external_input_check_int${type.size * 8}`, - ); - if (type instanceof IntType) - return this.requireImport( - [...WARPLIB_MATHS, 'external_input_check_ints'], - `warp_external_input_check_int${type.nBits}`, - ); - if (isAddressType(type)) - return this.requireImport( - [...WARPLIB_MATHS, 'external_input_check_address'], - `warp_external_input_check_address`, - ); - if (type instanceof BoolType) - return this.requireImport( - [...WARPLIB_MATHS, 'external_input_check_bool'], - `warp_external_input_check_bool`, - ); - - const funcInfo = this.getOrCreate(type, takesUint256); - const funcDef = createCairoGeneratedFunction( - funcInfo, - [ - [ - 'ref_var', - typeNameFromTypeNode(type, this.ast), - locationIfComplexType(type, DataLocation.CallData), - ], - ], - [], - this.ast, - this.sourceUnit, - { - mutability: FunctionStateMutability.Pure, - stubKind: FunctionStubKind.FunctionDefStub, - acceptsRawDArray: isDynamicArray(type), - }, - ); - this.generatedFunctionsDef.set(key, funcDef); - return funcDef; - } - - private getOrCreate(type: TypeNode, takesUint: boolean): GeneratedFunctionInfo { - const unexpectedTypeFunc = () => { - throw new NotSupportedYetError(`Input check for ${printTypeNode(type)} not defined yet.`); - }; - - return delegateBasedOnType( - type, - (type) => this.createDynArrayInputCheck(type), - (type) => this.createStaticArrayInputCheck(type), - (type, def) => this.createStructInputCheck(type, def), - unexpectedTypeFunc, - (type) => { - if (type instanceof UserDefinedType && type.definition instanceof EnumDefinition) - return this.createEnumInputCheck(type, takesUint); - return unexpectedTypeFunc(); - }, - ); - } - - private createStructInputCheck( - type: UserDefinedType, - structDef: StructDefinition, - ): GeneratedFunctionInfo { - const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - - const [inputCheckCode, funcCalls] = structDef.vMembers.reduce( - ([inputCheckCode, funcCalls], decl) => { - const memberType = safeGetNodeType(decl, this.ast.inference); - if (checkableType(memberType)) { - const memberCheckFunc = this.getOrCreateFuncDef(memberType, false); - return [ - [...inputCheckCode, `${memberCheckFunc.name}(arg.${decl.name});`], - [...funcCalls, memberCheckFunc], - ]; - } - return [inputCheckCode, funcCalls]; - }, - [new Array(), new Array()], - ); - - const funcName = `external_input_check_struct_${structDef.name}`; - const funcInfo: GeneratedFunctionInfo = { - name: funcName, - code: [ - `func ${funcName}${IMPLICITS}(arg : ${cairoType.toString()}) -> (){`, - `alloc_locals;`, - ...inputCheckCode, - `return ();`, - `}`, - ].join('\n'), - functionsCalled: funcCalls, - }; - return funcInfo; - } - - // Todo: This function can probably be made recursive for big size static arrays - private createStaticArrayInputCheck(type: ArrayType): GeneratedFunctionInfo { - assert(type.size !== undefined); - const length = narrowBigIntSafe(type.size); - - const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - const elementType = generalizeType(type.elementT)[0]; - - const auxFunc = this.getOrCreateFuncDef(elementType, false); - - const funcName = `external_input_check_static_array${this.generatedFunctionsDef.size}`; - const funcInfo: GeneratedFunctionInfo = { - name: funcName, - code: [ - `func ${funcName}${IMPLICITS}(arg : ${cairoType.toString()}) -> (){`, - `alloc_locals;`, - ...mapRange(length, (index) => { - return [`${auxFunc.name}(arg[${index}]);`]; - }), - `return ();`, - `}`, - ].join('\n'), - functionsCalled: [auxFunc], - }; - return funcInfo; - } - - // TODO: this function and EnumInputCheck single file do the same??? - // TODO: When does takesUint == true? - private createEnumInputCheck(type: UserDefinedType, takesUint = false): GeneratedFunctionInfo { - const enumDef = type.definition; - assert(enumDef instanceof EnumDefinition); - - // TODO: enum names are unique right? - const funcName = `external_input_check_enum_${enumDef.name}`; - - const importFuncs = [this.requireImport(...IS_LE_FELT)]; - if (takesUint) { - importFuncs.push(this.requireImport(...NARROW_SAFE)); - } - - const nMembers = enumDef.vMembers.length; - const funcInfo: GeneratedFunctionInfo = { - name: funcName, - code: [ - `func ${funcName}${IMPLICITS}(arg : ${takesUint ? 'Uint256' : 'felt'}) -> (){`, - takesUint - ? [ - ' let (arg_0) = narrow_safe(arg);', - ` let inRange: felt = is_le_felt(arg_0, ${nMembers - 1});`, - ].join('\n') - : ` let inRange : felt = is_le_felt(arg, ${nMembers - 1});`, - ` with_attr error_message("Error: value out-of-bounds. Values passed to must be in enum range (0, ${ - nMembers - 1 - }]."){`, - ` assert 1 = inRange;`, - ` }`, - ` return ();`, - `}`, - ].join('\n'), - functionsCalled: importFuncs, - }; - return funcInfo; - } - - private createDynArrayInputCheck( - type: ArrayType | BytesType | StringType, - ): GeneratedFunctionInfo { - const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - assert(cairoType instanceof CairoDynArray); - - const ptrType = cairoType.vPtr; - const elementType = generalizeType(getElementType(type))[0]; - - const calledFunction = this.getOrCreateFuncDef(elementType, false); - - const funcName = `external_input_check_dynamic_array${this.generatedFunctionsDef.size}`; - const funcInfo: GeneratedFunctionInfo = { - name: funcName, - code: [ - `func ${funcName}${IMPLICITS}(len: felt, ptr : ${ptrType.toString()}) -> (){`, - ` alloc_locals;`, - ` if (len == 0){`, - ` return ();`, - ` }`, - ` ${calledFunction.name}(ptr[0]);`, - ` ${funcName}(len = len - 1, ptr = ptr + ${ptrType.to.width});`, - ` return ();`, - `}`, - ].join('\n'), - functionsCalled: [calledFunction], - }; - return funcInfo; - } -} diff --git a/src/passes/argBoundChecker.ts b/src/passes/argBoundChecker.ts index 4dd3331b9..f6fd8cba6 100644 --- a/src/passes/argBoundChecker.ts +++ b/src/passes/argBoundChecker.ts @@ -1,18 +1,6 @@ import { AST } from '../ast/ast'; -import { - ContractDefinition, - ContractKind, - FunctionDefinition, - FunctionCall, - FunctionCallKind, - EnumDefinition, - IntType, -} from 'solc-typed-ast'; +import { ContractDefinition, ContractKind, FunctionDefinition, FunctionCall } from 'solc-typed-ast'; import { ASTMapper } from '../ast/mapper'; -import { isExternallyVisible } from '../utils/utils'; -import assert from 'assert'; -import { createExpressionStatement } from '../utils/nodeTemplates'; -import { checkableType, safeGetNodeType } from '../utils/nodeTypeProcessing'; export class ArgBoundChecker extends ASTMapper { // Function to add passes that should have been run before this pass addInitialPassPrerequisites(): void { @@ -28,40 +16,10 @@ export class ArgBoundChecker extends ASTMapper { } visitFunctionDefinition(node: FunctionDefinition, ast: AST): void { - if (isExternallyVisible(node) && node.vBody !== undefined) { - node.vParameters.vParameters.forEach((decl) => { - const type = safeGetNodeType(decl, ast.inference); - if (checkableType(type) && !(type instanceof IntType)) { - const functionCall = ast.getUtilFuncGen(node).boundChecks.inputCheck.gen(decl, type); - this.insertFunctionCall(node, functionCall, ast); - } - }); - } - this.commonVisit(node, ast); } - private insertFunctionCall(node: FunctionDefinition, funcCall: FunctionCall, ast: AST): void { - const body = node.vBody; - assert(body !== undefined && funcCall.vArguments !== undefined); - const expressionStatement = createExpressionStatement(ast, funcCall); - body.insertAtBeginning(expressionStatement); - ast.setContextRecursive(expressionStatement); - } - visitFunctionCall(node: FunctionCall, ast: AST): void { - if ( - node.kind === FunctionCallKind.TypeConversion && - node.vReferencedDeclaration instanceof EnumDefinition && - safeGetNodeType(node.vArguments[0], ast.inference) instanceof IntType - ) { - const enumDef = node.vReferencedDeclaration; - const enumCheckFuncCall = ast - .getUtilFuncGen(node) - .boundChecks.enums.gen(node, node.vArguments[0], enumDef, node); - const parent = node.parent; - ast.replaceNode(node, enumCheckFuncCall, parent); - } this.commonVisit(node, ast); } } diff --git a/src/passes/references/externalReturnReceiver.ts b/src/passes/references/externalReturnReceiver.ts index a313466d1..25025d14b 100644 --- a/src/passes/references/externalReturnReceiver.ts +++ b/src/passes/references/externalReturnReceiver.ts @@ -52,19 +52,9 @@ export class ExternalReturnReceiver extends ASTMapper { ast.insertStatementAfter(node, statement); node.assignments = node.assignments.map((value) => (value === decl.id ? newId : value)); }); - - node.vDeclarations.forEach((decl) => addOutputValidation(decl, ast)); } } -function addOutputValidation(decl: VariableDeclaration, ast: AST) { - const type = safeGetNodeType(decl, ast.inference); - if (!checkableType(type)) return; - const validationFunctionCall = ast.getUtilFuncGen(decl).boundChecks.inputCheck.gen(decl, type); - const validationStatement = createExpressionStatement(ast, validationFunctionCall); - ast.insertStatementAfter(decl, validationStatement); -} - function generateCopyStatement( decl: VariableDeclaration, ast: AST, diff --git a/src/warplib/generateWarplib.ts b/src/warplib/generateWarplib.ts index 5e72a3e4a..8eaabce43 100644 --- a/src/warplib/generateWarplib.ts +++ b/src/warplib/generateWarplib.ts @@ -1,18 +1,5 @@ import { generateFile, WarplibFunctionInfo } from './utils'; import { int_conversions } from './implementations/conversions/int'; -import { add, add_unsafe, add_signed, add_signed_unsafe } from './implementations/maths/add'; -import { div_signed, div_signed_unsafe } from './implementations/maths/div'; -import { exp, exp_signed, exp_signed_unsafe, exp_unsafe } from './implementations/maths/exp'; -import { ge_signed } from './implementations/maths/ge'; -import { gt_signed } from './implementations/maths/gt'; -import { le_signed } from './implementations/maths/le'; -import { lt_signed } from './implementations/maths/lt'; -import { mod_signed } from './implementations/maths/mod'; -import { mul, mul_unsafe, mul_signed, mul_signed_unsafe } from './implementations/maths/mul'; -import { negate } from './implementations/maths/negate'; -import { shl } from './implementations/maths/shl'; -import { shr, shr_signed } from './implementations/maths/shr'; -import { sub_unsafe, sub_signed, sub_signed_unsafe } from './implementations/maths/sub'; import { bitwise_not } from './implementations/maths/bitwiseNot'; import { external_input_check_ints } from './implementations/external_input_checks/externalInputChecksInts'; import path from 'path'; @@ -22,40 +9,14 @@ import { glob } from 'glob'; import { parseMultipleRawCairoFunctions } from '../utils/cairoParsing'; const warplibFunctions: WarplibFunctionInfo[] = [ - add(), - add_unsafe(), - add_signed(), - add_signed_unsafe(), // sub - handwritten - sub_unsafe(), - sub_signed(), - sub_signed_unsafe(), - mul(), - mul_unsafe(), - mul_signed(), - mul_signed_unsafe(), // div - handwritten // div_unsafe - handwritten - div_signed(), - div_signed_unsafe(), // mod - handwritten - mod_signed(), - exp(), - exp_signed(), - exp_unsafe(), - exp_signed_unsafe(), - negate(), - shl(), - shr(), - shr_signed(), // ge - handwritten - ge_signed(), // gt - handwritten - gt_signed(), // le - handwritten - le_signed(), // lt - handwritten - lt_signed(), // and - handwritten // xor - handwritten // bitwise_and - handwritten diff --git a/src/warplib/implementations/maths/add.ts b/src/warplib/implementations/maths/add.ts deleted file mode 100644 index f25fdbfe4..000000000 --- a/src/warplib/implementations/maths/add.ts +++ /dev/null @@ -1,137 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { - forAllWidths, - IntxIntFunction, - mask, - msb, - msbAndNext, - WarplibFunctionInfo, -} from '../../utils'; - -export function add(): WarplibFunctionInfo { - const fileName = 'add'; - const imports = ['use warplib::maths::le::warp_le;']; - const functions = forAllWidths((width) => { - if (width === 256) { - return [`fn warp_add256(lhs: u256, rhs: u256) -> u256{`, ` return lhs + rhs;`, `}`].join( - '\n', - ); - } else { - return [ - `fn warp_add${width}(lhs: felt252, rhs: felt252) -> felt252{`, - ` let res = lhs + rhs;`, - ` let max: felt252 = ${mask(width)};`, - ` assert(warp_le(res, max), 'Value out of bounds');`, - ` return res;`, - `}`, - ].join('\n'); - } - }); - - return { fileName, imports, functions }; -} - -export function add_unsafe(): WarplibFunctionInfo { - return { - fileName: 'add_unsafe', - imports: ['use integer::u256_overflowing_add;'], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `fn warp_add_unsafe256(lhs : u256, rhs : u256) -> u256 {`, - ` let (value, _) = u256_overflowing_add(lhs, rhs);`, - ` return value;`, - `}`, - ].join('\n'); - } else { - // TODO: Use bitwise '&' to take just the width-bits - return [ - `fn warp_add_unsafe${width}(lhs : felt252, rhs : felt252) -> felt252 {`, - ` let res = lhs + rhs;`, - ` return res;`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function add_signed(): WarplibFunctionInfo { - return { - fileName: 'add_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_add', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_add_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, - ` let (lhs_extend) = bitwise_and(lhs.high, ${msb(128)});`, - ` let (rhs_extend) = bitwise_and(rhs.high, ${msb(128)});`, - ` let (res : Uint256, carry : felt252) = uint256_add(lhs, rhs);`, - ` let carry_extend = lhs_extend + rhs_extend + carry*${msb(128)};`, - ` let (msb) = bitwise_and(res.high, ${msb(128)});`, - ` let (carry_lsb) = bitwise_and(carry_extend, ${msb(128)});`, - ` assert msb = carry_lsb;`, - ` return (res,);`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_add_signed${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt252, rhs : felt252) -> (`, - ` res : felt252){`, - `// Do the addition sign extended`, - ` let (lmsb) = bitwise_and(lhs, ${msb(width)});`, - ` let (rmsb) = bitwise_and(rhs, ${msb(width)});`, - ` let big_res = lhs + rhs + 2*(lmsb+rmsb);`, - `// Check the result is valid`, - ` let (overflowBits) = bitwise_and(big_res, ${msbAndNext(width)});`, - ` assert overflowBits * (overflowBits - ${msbAndNext(width)}) = 0;`, - `// Truncate and return`, - ` let (res) = bitwise_and(big_res, ${mask(width)});`, - ` return (res,);`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function add_signed_unsafe(): WarplibFunctionInfo { - return { - fileName: 'add_signed_unsafe', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_add', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_add_signed_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, - ` let (res : Uint256, _) = uint256_add(lhs, rhs);`, - ` return (res,);`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_add_signed_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt252, rhs : felt252) -> (res : felt252){`, - ` let (res) = bitwise_and(lhs + rhs, ${mask(width)});`, - ` return (res,);`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function functionaliseAdd(node: BinaryOperation, unsafe: boolean, ast: AST): void { - IntxIntFunction(node, 'add', 'always', true, unsafe, ast); -} diff --git a/src/warplib/implementations/maths/and.ts b/src/warplib/implementations/maths/and.ts deleted file mode 100644 index 97e21704a..000000000 --- a/src/warplib/implementations/maths/and.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { BoolxBoolFunction } from '../../utils'; - -export function functionaliseAnd(node: BinaryOperation, ast: AST): void { - BoolxBoolFunction(node, 'and_', ast); -} diff --git a/src/warplib/implementations/maths/bitwiseAnd.ts b/src/warplib/implementations/maths/bitwiseAnd.ts deleted file mode 100644 index 10b994053..000000000 --- a/src/warplib/implementations/maths/bitwiseAnd.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { IntxIntFunction } from '../../utils'; - -export function functionaliseBitwiseAnd(node: BinaryOperation, ast: AST): void { - IntxIntFunction(node, 'bitwise_and', 'only256', false, false, ast); -} diff --git a/src/warplib/implementations/maths/bitwiseOr.ts b/src/warplib/implementations/maths/bitwiseOr.ts deleted file mode 100644 index a1acdf678..000000000 --- a/src/warplib/implementations/maths/bitwiseOr.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { IntxIntFunction } from '../../utils'; - -export function functionaliseBitwiseOr(node: BinaryOperation, ast: AST): void { - IntxIntFunction(node, 'bitwise_or', 'only256', false, false, ast); -} diff --git a/src/warplib/implementations/maths/div.ts b/src/warplib/implementations/maths/div.ts deleted file mode 100644 index ec292617c..000000000 --- a/src/warplib/implementations/maths/div.ts +++ /dev/null @@ -1,125 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { mapRange } from '../../../utils/utils'; -import { forAllWidths, IntxIntFunction, mask, WarplibFunctionInfo } from '../../utils'; - -export function div_signed(): WarplibFunctionInfo { - return { - fileName: 'div_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem, uint256_eq', - 'from warplib.maths.utils import felt_to_uint256', - `from warplib.maths.int_conversions import ${mapRange( - 31, - (n) => `warp_int${8 * n + 8}_to_int256`, - ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, - `from warplib.maths.mul_signed import ${mapRange(32, (n) => `warp_mul_signed${8 * n + 8}`)}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_div_signed256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', - ` if (rhs.high == 0 and rhs.low == 0){`, - ` with_attr error_message("Division by zero error"){`, - ` assert 1 = 0;`, - ` }`, - ` }`, - ` let (is_minus_one) = uint256_eq(rhs, Uint256(${mask(128)}, ${mask(128)}));`, - ` if (is_minus_one == 1){`, - ' let (res : Uint256) = warp_mul_signed256(lhs, rhs);', - ' return (res,);', - ' }', - ' let (res : Uint256, _) = uint256_signed_div_rem(lhs, rhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_div_signed${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` if (rhs == 0){`, - ` with_attr error_message("Division by zero error"){`, - ` assert 1 = 0;`, - ` }`, - ` }`, - ` if (rhs == ${mask(width)}){`, - ` let (res : felt) = warp_mul_signed${width}(lhs, rhs);`, - ` return (res,);`, - ' }', - ` let (local lhs_256) = warp_int${width}_to_int256(lhs);`, - ` let (rhs_256) = warp_int${width}_to_int256(rhs);`, - ' let (res256, _) = uint256_signed_div_rem(lhs_256, rhs_256);', - ` let (truncated) = warp_int256_to_int${width}(res256);`, - ` return (truncated,);`, - '}', - ].join('\n'); - } - }), - }; -} - -export function div_signed_unsafe(): WarplibFunctionInfo { - return { - fileName: 'div_signed_unsafe', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem, uint256_eq', - 'from warplib.maths.utils import felt_to_uint256', - `from warplib.maths.int_conversions import ${mapRange( - 31, - (n) => `warp_int${8 * n + 8}_to_int256`, - ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, - `from warplib.maths.mul_signed_unsafe import ${mapRange( - 32, - (n) => `warp_mul_signed_unsafe${8 * n + 8}`, - )}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_div_signed_unsafe256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', - ` if (rhs.high == 0 and rhs.low == 0){`, - ` with_attr error_message("Division by zero error"){`, - ` assert 1 = 0;`, - ` }`, - ` }`, - ` let (is_minus_one) = uint256_eq(rhs, Uint256(${mask(128)}, ${mask(128)}));`, - ` if (is_minus_one == 1){`, - ' let (res : Uint256) = warp_mul_signed_unsafe256(lhs, rhs);', - ' return (res,);', - ' }', - ' let (res : Uint256, _) = uint256_signed_div_rem(lhs, rhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_div_signed_unsafe${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` if (rhs == 0){`, - ` with_attr error_message("Division by zero error"){`, - ` assert 1 = 0;`, - ` }`, - ` }`, - ` if (rhs == ${mask(width)}){`, - ` let (res : felt) = warp_mul_signed_unsafe${width}(lhs, rhs);`, - ` return (res,);`, - ' }', - ` let (local lhs_256) = warp_int${width}_to_int256(lhs);`, - ` let (rhs_256) = warp_int${width}_to_int256(rhs);`, - ' let (res256, _) = uint256_signed_div_rem(lhs_256, rhs_256);', - ` let (truncated) = warp_int256_to_int${width}(res256);`, - ` return (truncated,);`, - '}', - ].join('\n'); - } - }), - }; -} - -export function functionaliseDiv(node: BinaryOperation, unsafe: boolean, ast: AST): void { - IntxIntFunction(node, 'div', 'signedOrWide', true, unsafe, ast); -} diff --git a/src/warplib/implementations/maths/eq.ts b/src/warplib/implementations/maths/eq.ts deleted file mode 100644 index 520e6c1d1..000000000 --- a/src/warplib/implementations/maths/eq.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { Comparison } from '../../utils'; - -export function functionaliseEq(node: BinaryOperation, ast: AST): void { - Comparison(node, 'eq', 'only256', false, ast); -} diff --git a/src/warplib/implementations/maths/exp.ts b/src/warplib/implementations/maths/exp.ts deleted file mode 100644 index f3bfbb565..000000000 --- a/src/warplib/implementations/maths/exp.ts +++ /dev/null @@ -1,220 +0,0 @@ -import assert from 'assert'; -import { - BinaryOperation, - FunctionCall, - FunctionCallKind, - Identifier, - IntType, -} from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { printNode, printTypeNode } from '../../../utils/astPrinter'; -import { WARPLIB_MATHS } from '../../../utils/importPaths'; -import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; -import { mapRange, typeNameFromTypeNode } from '../../../utils/utils'; -import { forAllWidths, getIntOrFixedByteBitWidth, mask, WarplibFunctionInfo } from '../../utils'; - -export function exp() { - return createExp(false, false); -} - -export function exp_signed() { - return createExp(true, false); -} - -export function exp_unsafe() { - return createExp(false, true); -} - -export function exp_signed_unsafe() { - return createExp(true, true); -} - -function createExp(signed: boolean, unsafe: boolean): WarplibFunctionInfo { - const suffix = `${signed ? '_signed' : ''}${unsafe ? '_unsafe' : ''}`; - return { - fileName: `exp${suffix}`, - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_sub', - `from warplib.maths.mul${suffix} import ${mapRange( - 32, - (n) => `warp_mul${suffix}${8 * n + 8}`, - ).join(', ')}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func _repeated_multiplication${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : Uint256, count : felt) -> (res : Uint256){`, - ` if (count == 0){`, - ` return (Uint256(1, 0),);`, - ` }`, - ` let (x) = _repeated_multiplication${width}(op, count - 1);`, - ` let (res) = warp_mul${suffix}${width}(op, x);`, - ` return (res,);`, - `}`, - `func warp_exp${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (res : Uint256){`, - ` if (rhs == 0){`, - ` return (Uint256(1, 0),);`, - ' }', - ' if (lhs.high == 0){', - ` if (lhs.low * (lhs.low - 1) == 0){`, - ' return (lhs,);', - ` }`, - ` }`, - ...getNegativeOneShortcutCode(signed, width, false), - ` let (res) = _repeated_multiplication${width}(lhs, rhs);`, - ` return (res,);`, - `}`, - `func _repeated_multiplication_256_${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : Uint256, count : Uint256) -> (res : Uint256){`, - ` if (count.low == 0 and count.high == 0){`, - ` return (Uint256(1, 0),);`, - ` }`, - ` let (decr) = uint256_sub(count, Uint256(1, 0));`, - ` let (x) = _repeated_multiplication_256_${width}(op, decr);`, - ` let (res) = warp_mul${suffix}${width}(op, x);`, - ` return (res,);`, - `}`, - `func warp_exp_wide${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, - ` if (rhs.high == 0 and rhs.low == 0){`, - ` return (Uint256(1, 0),);`, - ' }', - ' if (lhs.high == 0 and lhs.low * (lhs.low - 1) == 0){', - ' return (lhs,);', - ` }`, - ...getNegativeOneShortcutCode(signed, width, true), - ` let (res) = _repeated_multiplication_256_${width}(lhs, rhs);`, - ` return (res,);`, - `}`, - ].join('\n'); - } else { - return [ - `func _repeated_multiplication${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : felt, count : felt) -> (res : felt){`, - ` alloc_locals;`, - ` if (count == 0){`, - ` return (1,);`, - ` }else{`, - ` let (x) = _repeated_multiplication${width}(op, count - 1);`, - ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, - ` let (res) = warp_mul${suffix}${width}(op, x);`, - ` return (res,);`, - ` }`, - `}`, - `func warp_exp${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, - ' if (rhs == 0){', - ' return (1,);', - ` }`, - ' if (lhs * (lhs-1) * (rhs-1) == 0){', - ' return (lhs,);', - ' }', - ...getNegativeOneShortcutCode(signed, width, false), - ` let (res) = _repeated_multiplication${width}(lhs, rhs);`, - ` return (res,);`, - '}', - `func _repeated_multiplication_256_${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : felt, count : Uint256) -> (res : felt){`, - ` alloc_locals;`, - ` if (count.low == 0 and count.high == 0){`, - ` return (1,);`, - ` }`, - ` let (decr) = uint256_sub(count, Uint256(1, 0));`, - ` let (x) = _repeated_multiplication_256_${width}(op, decr);`, - ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, - ` let (res) = warp_mul${suffix}${width}(op, x);`, - ` return (res,);`, - `}`, - `func warp_exp_wide${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : Uint256) -> (res : felt){`, - ' if (rhs.low == 0){', - ' if (rhs.high == 0){', - ' return (1,);', - ' }', - ` }`, - ' if (lhs * (lhs-1) == 0){', - ' return (lhs,);', - ' }', - ' if (rhs.low == 1 and rhs.high == 0){', - ' return (lhs,);', - ' }', - ...getNegativeOneShortcutCode(signed, width, true), - ` let (res) = _repeated_multiplication_256_${width}(lhs, rhs);`, - ` return (res,);`, - '}', - ].join('\n'); - } - }), - }; -} - -function getNegativeOneShortcutCode(signed: boolean, lhsWidth: number, rhsWide: boolean): string[] { - if (!signed) return []; - - if (lhsWidth < 256) { - return [ - `if ((lhs - ${mask(lhsWidth)}) == 0){`, - ` let (is_odd) = bitwise_and(${rhsWide ? 'rhs.low' : 'rhs'}, 1);`, - ` return (1 + is_odd * 0x${'f'.repeat(lhsWidth / 8 - 1)}e,);`, - `}`, - ]; - } else { - return [ - `if ((lhs.low - ${mask(128)}) == 0 and (lhs.high - ${mask(128)}) == 0){`, - ` let (is_odd) = bitwise_and(${rhsWide ? 'rhs.low' : 'rhs'}, 1);`, - ` return (Uint256(1 + is_odd * 0x${'f'.repeat(31)}e, is_odd * ${mask(128)}),);`, - `}`, - ]; - } -} - -export function functionaliseExp(node: BinaryOperation, unsafe: boolean, ast: AST) { - const lhsType = safeGetNodeType(node.vLeftExpression, ast.inference); - const rhsType = safeGetNodeType(node.vRightExpression, ast.inference); - const retType = safeGetNodeType(node, ast.inference); - assert( - retType instanceof IntType, - `${printNode(node)} has type ${printTypeNode(retType)}, which is not compatible with **`, - ); - assert( - rhsType instanceof IntType, - `${printNode(node)} has rhs-type ${rhsType.pp()}, which is not compatible with **`, - ); - const fullName = [ - 'warp_', - 'exp', - rhsType.nBits === 256 ? '_wide' : '', - retType.signed ? '_signed' : '', - unsafe ? '_unsafe' : '', - `${getIntOrFixedByteBitWidth(retType)}`, - ].join(''); - - const importName = [ - ...WARPLIB_MATHS, - `exp${retType.signed ? '_signed' : ''}${unsafe ? '_unsafe' : ''}`, - ]; - - const importedFunc = ast.registerImport( - node, - importName, - fullName, - [ - ['lhs', typeNameFromTypeNode(lhsType, ast)], - ['rhs', typeNameFromTypeNode(rhsType, ast)], - ], - [['res', typeNameFromTypeNode(retType, ast)]], - ); - - const call = new FunctionCall( - ast.reserveId(), - node.src, - node.typeString, - FunctionCallKind.FunctionCall, - new Identifier( - ast.reserveId(), - '', - `function (${node.typeString}, ${node.typeString}) returns (${node.typeString})`, - fullName, - importedFunc.id, - ), - [node.vLeftExpression, node.vRightExpression], - ); - - ast.replaceNode(node, call); -} diff --git a/src/warplib/implementations/maths/ge.ts b/src/warplib/implementations/maths/ge.ts deleted file mode 100644 index 426e9df1d..000000000 --- a/src/warplib/implementations/maths/ge.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { mapRange } from '../../../utils/utils'; -import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; - -export function ge_signed(): WarplibFunctionInfo { - return { - fileName: 'ge_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_le', - `from warplib.maths.lt_signed import ${mapRange(31, (n) => `warp_le_signed${8 * n + 8}`).join( - ', ', - )}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_ge_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', - ' let (res) = uint256_signed_le(rhs, lhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_ge_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, - ' lhs : felt, rhs : felt) -> (res : felt){', - ` let (res) = warp_le_signed${width}(rhs, lhs);`, - ` return (res,);`, - '}', - ].join('\n'); - } - }), - }; -} - -export function functionaliseGe(node: BinaryOperation, ast: AST): void { - Comparison(node, 'ge', 'signedOrWide', true, ast); -} diff --git a/src/warplib/implementations/maths/gt.ts b/src/warplib/implementations/maths/gt.ts deleted file mode 100644 index 7764c5aca..000000000 --- a/src/warplib/implementations/maths/gt.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { mapRange } from '../../../utils/utils'; -import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; - -export function gt_signed(): WarplibFunctionInfo { - return { - fileName: 'gt_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_lt', - `from warplib.maths.lt_signed import ${mapRange(31, (n) => `warp_lt_signed${8 * n + 8}`).join( - ', ', - )}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_gt_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', - ' let (res) = uint256_signed_lt(rhs, lhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_gt_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, - ' lhs : felt, rhs : felt) -> (res : felt){', - ` let (res) = warp_lt_signed${width}(rhs, lhs);`, - ` return (res,);`, - '}', - ].join('\n'); - } - }), - }; -} - -export function functionaliseGt(node: BinaryOperation, ast: AST): void { - Comparison(node, 'gt', 'signedOrWide', true, ast); -} diff --git a/src/warplib/implementations/maths/le.ts b/src/warplib/implementations/maths/le.ts deleted file mode 100644 index 9e18a1f5a..000000000 --- a/src/warplib/implementations/maths/le.ts +++ /dev/null @@ -1,61 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { forAllWidths, msb, Comparison, WarplibFunctionInfo } from '../../utils'; - -export function le_signed(): WarplibFunctionInfo { - return { - fileName: 'le_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_le', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_le_signed${width}{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){`, - ' let (res) = uint256_signed_le(lhs, rhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_le_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, - ` lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` let (lhs_msb : felt) = bitwise_and(lhs, ${msb(width)});`, - ` let (rhs_msb : felt) = bitwise_and(rhs, ${msb(width)});`, - ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, - ` if (lhs_msb == 0){`, - ` // lhs >= 0`, - ` if (rhs_msb == 0){`, - ` // rhs >= 0`, - ` let result = is_le_felt(lhs, rhs);`, - ` return (result,);`, - ` }else{`, - ` // rhs < 0`, - ` return (0,);`, - ` }`, - ` }else{`, - ` // lhs < 0`, - ` if (rhs_msb == 0){`, - ` // rhs >= 0`, - ` return (1,);`, - ` }else{`, - ` // rhs < 0`, - ` // (signed) lhs <= rhs <=> (unsigned) lhs >= rhs`, - ` let result = is_le_felt(lhs, rhs);`, - ` return (result,);`, - ` }`, - ` }`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function functionaliseLe(node: BinaryOperation, ast: AST): void { - Comparison(node, 'le', 'signedOrWide', true, ast); -} diff --git a/src/warplib/implementations/maths/lt.ts b/src/warplib/implementations/maths/lt.ts deleted file mode 100644 index b58212b8c..000000000 --- a/src/warplib/implementations/maths/lt.ts +++ /dev/null @@ -1,43 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { mapRange } from '../../../utils/utils'; -import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; - -export function lt_signed(): WarplibFunctionInfo { - return { - fileName: 'lt_signed', - imports: [ - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_lt', - 'from warplib.maths.utils import felt_to_uint256', - `from warplib.maths.le_signed import ${mapRange(31, (n) => `warp_le_signed${8 * n + 8}`).join( - ', ', - )}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_lt_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', - ' let (res) = uint256_signed_lt(lhs, rhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_lt_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, - ' lhs : felt, rhs : felt) -> (res : felt){', - ' if (lhs == rhs){', - ' return (0,);', - ' }', - ` let (res) = warp_le_signed${width}(lhs, rhs);`, - ` return (res,);`, - '}', - ].join('\n'); - } - }), - }; -} - -export function functionaliseLt(node: BinaryOperation, ast: AST): void { - Comparison(node, 'lt', 'signedOrWide', true, ast); -} diff --git a/src/warplib/implementations/maths/mod.ts b/src/warplib/implementations/maths/mod.ts deleted file mode 100644 index 6cae243d4..000000000 --- a/src/warplib/implementations/maths/mod.ts +++ /dev/null @@ -1,55 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { mapRange } from '../../../utils/utils'; -import { forAllWidths, IntxIntFunction, WarplibFunctionInfo } from '../../utils'; - -export function mod_signed(): WarplibFunctionInfo { - return { - fileName: 'mod_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem', - 'from warplib.maths.utils import felt_to_uint256', - `from warplib.maths.int_conversions import ${mapRange( - 31, - (n) => `warp_int${8 * n + 8}_to_int256`, - ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_mod_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', - ` if (rhs.high == 0 and rhs.low == 0){`, - ` with_attr error_message("Modulo by zero error"){`, - ` assert 1 = 0;`, - ` }`, - ` }`, - ' let (_, res : Uint256) = uint256_signed_div_rem(lhs, rhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_mod_signed${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` if (rhs == 0){`, - ` with_attr error_message("Modulo by zero error"){`, - ` assert 1 = 0;`, - ` }`, - ` }`, - ` let (local lhs_256) = warp_int${width}_to_int256(lhs);`, - ` let (rhs_256) = warp_int${width}_to_int256(rhs);`, - ' let (_, res256) = uint256_signed_div_rem(lhs_256, rhs_256);', - ` let (truncated) = warp_int256_to_int${width}(res256);`, - ` return (truncated,);`, - '}', - ].join('\n'); - } - }), - }; -} - -export function functionaliseMod(node: BinaryOperation, ast: AST): void { - IntxIntFunction(node, 'mod', 'signedOrWide', true, false, ast); -} diff --git a/src/warplib/implementations/maths/mul.ts b/src/warplib/implementations/maths/mul.ts deleted file mode 100644 index 738c5ba2b..000000000 --- a/src/warplib/implementations/maths/mul.ts +++ /dev/null @@ -1,247 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { mapRange } from '../../../utils/utils'; -import { - forAllWidths, - uint256, - pow2, - bound, - mask, - msb, - IntxIntFunction, - WarplibFunctionInfo, -} from '../../utils'; - -export function mul(): WarplibFunctionInfo { - return { - fileName: 'mul', - imports: [ - 'from starkware.cairo.common.uint256 import Uint256, uint256_mul', - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from warplib.maths.ge import warp_ge256', - 'from warplib.maths.utils import felt_to_uint256', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_mul256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', - ' let (result : Uint256, overflow : Uint256) = uint256_mul(lhs, rhs);', - ' assert overflow.low = 0;', - ' assert overflow.high = 0;', - ' return (result,);', - '}', - ].join('\n'); - } else if (width >= 128) { - return [ - `func warp_mul${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, - ' alloc_locals;', - ' let (l256 : Uint256) = felt_to_uint256(lhs);', - ' let (r256 : Uint256) = felt_to_uint256(rhs);', - ' let (local res : Uint256) = warp_mul256(l256, r256);', - ` let (outOfRange : felt) = warp_ge256(res, ${uint256(pow2(width))});`, - ' assert outOfRange = 0;', - ` return (res.low + ${bound(128)} * res.high,);`, - '}', - ].join('\n'); - } else { - return [ - `func warp_mul${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, - ' let res = lhs * rhs;', - ` let inRange : felt = is_le_felt(res, ${mask(width)});`, - ' assert inRange = 1;', - ' return (res,);', - '}', - ].join('\n'); - } - }), - }; -} - -export function mul_unsafe(): WarplibFunctionInfo { - return { - fileName: 'mul_unsafe', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_mul', - 'from warplib.maths.utils import felt_to_uint256', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_mul_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, - ` let (res : Uint256, _) = uint256_mul(lhs, rhs);`, - ` return (res,);`, - `}`, - ].join('\n'); - } else if (width >= 128) { - return [ - `func warp_mul_unsafe${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` let (l256 : Uint256) = felt_to_uint256(lhs);`, - ` let (r256 : Uint256) = felt_to_uint256(rhs);`, - ` let (local res : Uint256, _) = uint256_mul(l256, r256);`, - ` let (high) = bitwise_and(res.high, ${mask(width - 128)});`, - ` return (res.low + ${bound(128)} * high,);`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_mul_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, - ` let (res) = bitwise_and(lhs * rhs, ${mask(width)});`, - ` return (res,);`, - '}', - ].join('\n'); - } - }), - }; -} - -export function mul_signed(): WarplibFunctionInfo { - return { - fileName: 'mul_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_mul, uint256_cond_neg, uint256_signed_nn, uint256_neg, uint256_le', - 'from warplib.maths.utils import felt_to_uint256', - `from warplib.maths.le import warp_le`, - `from warplib.maths.mul import ${mapRange(31, (n) => `warp_mul${8 * n + 8}`).join(', ')}`, - `from warplib.maths.negate import ${mapRange(31, (n) => `warp_negate${8 * n + 8}`).join( - ', ', - )}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_mul_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : Uint256, rhs : Uint256) -> (result : Uint256){`, - ` alloc_locals;`, - ` // 1 => lhs >= 0, 0 => lhs < 0`, - ` let (lhs_nn) = uint256_signed_nn(lhs);`, - ` // 1 => rhs >= 0, 0 => rhs < 0`, - ` let (local rhs_nn) = uint256_signed_nn(rhs);`, - ` // negates if arg is 1, which is if lhs_nn is 0, which is if lhs < 0`, - ` let (lhs_abs) = uint256_cond_neg(lhs, 1 - lhs_nn);`, - ` // negates if arg is 1`, - ` let (rhs_abs) = uint256_cond_neg(rhs, 1 - rhs_nn);`, - ` let (res_abs, overflow) = uint256_mul(lhs_abs, rhs_abs);`, - ` assert overflow.low = 0;`, - ` assert overflow.high = 0;`, - ` let res_should_be_neg = lhs_nn + rhs_nn;`, - ` if (res_should_be_neg == 1){`, - ` let (in_range) = uint256_le(res_abs, Uint256(0,${msb(128)}));`, - ` assert in_range = 1;`, - ` let (negated) = uint256_neg(res_abs);`, - ` return (negated,);`, - ` }else{`, - ` let (msb) = bitwise_and(res_abs.high, ${msb(128)});`, - ` assert msb = 0;`, - ` return (res_abs,);`, - ` }`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_mul_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` let (left_msb) = bitwise_and(lhs, ${msb(width)});`, - ` if (left_msb == 0){`, - ` let (right_msb) = bitwise_and(rhs, ${msb(width)});`, - ` if (right_msb == 0){`, - ` let (res) = warp_mul${width}(lhs, rhs);`, - ` let (res_msb) = bitwise_and(res, ${msb(width)});`, - ` assert res_msb = 0;`, - ` return (res,);`, - ` }else{`, - ` let (rhs_abs) = warp_negate${width}(rhs);`, - ` let (res_abs) = warp_mul${width}(lhs, rhs_abs);`, - ` let (in_range) = warp_le(res_abs, ${msb(width)});`, - ` assert in_range = 1;`, - ` let (res) = warp_negate${width}(res_abs);`, - ` return (res,);`, - ` }`, - ` }else{`, - ` let (right_msb) = bitwise_and(rhs, ${msb(width)});`, - ` if (right_msb == 0){`, - ` let (lhs_abs) = warp_negate${width}(lhs);`, - ` let (res_abs) = warp_mul${width}(lhs_abs, rhs);`, - ` let (in_range) = warp_le(res_abs, ${msb(width)});`, - ` assert in_range = 1;`, - ` let (res) = warp_negate${width}(res_abs);`, - ` return (res,);`, - ` }else{`, - ` let (lhs_abs) = warp_negate${width}(lhs);`, - ` let (rhs_abs) = warp_negate${width}(rhs);`, - ` let (res) = warp_mul${width}(lhs_abs, rhs_abs);`, - ` let (res_msb) = bitwise_and(res, ${msb(width)});`, - ` assert res_msb = 0;`, - ` return (res,);`, - ` }`, - ` }`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function mul_signed_unsafe(): WarplibFunctionInfo { - return { - fileName: 'mul_signed_unsafe', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_mul, uint256_cond_neg, uint256_signed_nn', - `from warplib.maths.mul_unsafe import ${mapRange( - 31, - (n) => `warp_mul_unsafe${8 * n + 8}`, - ).join(', ')}`, - `from warplib.maths.negate import ${mapRange(31, (n) => `warp_negate${8 * n + 8}`).join( - ', ', - )}`, - 'from warplib.maths.utils import felt_to_uint256', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_mul_signed_unsafe256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : Uint256, rhs : Uint256) -> (result : Uint256){`, - ` alloc_locals;`, - ` let (lhs_nn) = uint256_signed_nn(lhs);`, - ` let (local rhs_nn) = uint256_signed_nn(rhs);`, - ` let (lhs_abs) = uint256_cond_neg(lhs, lhs_nn);`, - ` let (rhs_abs) = uint256_cond_neg(rhs, rhs_nn);`, - ` let (res_abs, _) = uint256_mul(lhs_abs, rhs_abs);`, - ` let (res) = uint256_cond_neg(res_abs, (lhs_nn + rhs_nn) * (2 - lhs_nn - rhs_nn));`, - ` return (res,);`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_mul_signed_unsafe${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` let (local left_msb) = bitwise_and(lhs, ${msb(width)});`, - ` let (local right_msb) = bitwise_and(rhs, ${msb(width)});`, - ` let (res) = warp_mul_unsafe${width}(lhs, rhs);`, - ` let not_neg = (left_msb + right_msb) * (${bound(width)} - left_msb - right_msb);`, - ` if (not_neg == ${msb(width)}){`, - ` let (res) = warp_negate${width}(res);`, - ` return (res,);`, - ` }else{`, - ` return (res,);`, - ` }`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function functionaliseMul(node: BinaryOperation, unsafe: boolean, ast: AST): void { - IntxIntFunction(node, 'mul', 'always', true, unsafe, ast); -} diff --git a/src/warplib/implementations/maths/neq.ts b/src/warplib/implementations/maths/neq.ts deleted file mode 100644 index 0d62da748..000000000 --- a/src/warplib/implementations/maths/neq.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { Comparison } from '../../utils'; - -export function functionaliseNeq(node: BinaryOperation, ast: AST): void { - Comparison(node, 'neq', 'only256', false, ast); -} diff --git a/src/warplib/implementations/maths/or.ts b/src/warplib/implementations/maths/or.ts deleted file mode 100644 index b8e92e419..000000000 --- a/src/warplib/implementations/maths/or.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { BoolxBoolFunction } from '../../utils'; - -export function functionaliseOr(node: BinaryOperation, ast: AST): void { - BoolxBoolFunction(node, 'or', ast); -} diff --git a/src/warplib/implementations/maths/shl.ts b/src/warplib/implementations/maths/shl.ts deleted file mode 100644 index 7ce99e90a..000000000 --- a/src/warplib/implementations/maths/shl.ts +++ /dev/null @@ -1,118 +0,0 @@ -import assert from 'assert'; -import { - BinaryOperation, - FixedBytesType, - FunctionCall, - FunctionCallKind, - Identifier, - IntType, -} from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { printNode, printTypeNode } from '../../../utils/astPrinter'; -import { WARPLIB_MATHS } from '../../../utils/importPaths'; -import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; -import { typeNameFromTypeNode } from '../../../utils/utils'; -import { forAllWidths, getIntOrFixedByteBitWidth, WarplibFunctionInfo } from '../../utils'; - -// rhs is always unsigned, and signed and unsigned shl are the same -export function shl(): WarplibFunctionInfo { - //Need to provide an implementation with 256bit rhs and <256bit lhs - return { - fileName: 'shl', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math import split_felt', - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_shl', - 'from warplib.maths.pow2 import pow2', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_shl256{range_check_ptr}(lhs : Uint256, rhs : felt) -> (result : Uint256){', - ' let (high, low) = split_felt(rhs);', - ' let (res) = uint256_shl(lhs, Uint256(low, high));', - ' return (res,);', - '}', - 'func warp_shl256_256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (result : Uint256){', - ' let (res) = uint256_shl(lhs, rhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_shl${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : felt) -> (res : felt){`, - ` // width <= rhs (shift amount) means result will be 0`, - ` let large_shift = is_le_felt(${width}, rhs);`, - ` if (large_shift == 1){`, - ` return (0,);`, - ` }else{`, - ` let preserved_width = ${width} - rhs;`, - ` let (preserved_bound) = pow2(preserved_width);`, - ` let (lhs_truncated) = bitwise_and(lhs, preserved_bound - 1);`, - ` let (multiplier) = pow2(rhs);`, - ` return (lhs_truncated * multiplier,);`, - ` }`, - `}`, - `func warp_shl${width}_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : Uint256) -> (res : felt){`, - ` if (rhs.high == 0){`, - ` let (res) = warp_shl${width}(lhs, rhs.low);`, - ` return (res,);`, - ` }else{`, - ` return (0,);`, - ` }`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function functionaliseShl(node: BinaryOperation, ast: AST): void { - const lhsType = safeGetNodeType(node.vLeftExpression, ast.inference); - const rhsType = safeGetNodeType(node.vRightExpression, ast.inference); - const retType = safeGetNodeType(node, ast.inference); - - assert( - lhsType instanceof IntType || lhsType instanceof FixedBytesType, - `lhs of << ${printNode(node)} non-int type ${printTypeNode(lhsType)}`, - ); - assert( - rhsType instanceof IntType, - `rhs of << ${printNode(node)} non-int type ${printTypeNode(rhsType)}`, - ); - - const lhsWidth = getIntOrFixedByteBitWidth(lhsType); - - const fullName = `warp_shl${lhsWidth}${rhsType.nBits === 256 ? '_256' : ''}`; - - const importedFunc = ast.registerImport( - node, - [...WARPLIB_MATHS, 'shl'], - fullName, - [ - ['lhs', typeNameFromTypeNode(lhsType, ast)], - ['rhs', typeNameFromTypeNode(rhsType, ast)], - ], - [['res', typeNameFromTypeNode(retType, ast)]], - ); - const call = new FunctionCall( - ast.reserveId(), - node.src, - node.typeString, - FunctionCallKind.FunctionCall, - new Identifier( - ast.reserveId(), - '', - `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, - fullName, - importedFunc.id, - ), - [node.vLeftExpression, node.vRightExpression], - ); - - ast.replaceNode(node, call); -} diff --git a/src/warplib/implementations/maths/shr.ts b/src/warplib/implementations/maths/shr.ts deleted file mode 100644 index ea99b2503..000000000 --- a/src/warplib/implementations/maths/shr.ts +++ /dev/null @@ -1,269 +0,0 @@ -import assert from 'assert'; -import { - BinaryOperation, - FixedBytesType, - FunctionCall, - FunctionCallKind, - Identifier, - IntType, -} from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { printNode, printTypeNode } from '../../../utils/astPrinter'; -import { WARPLIB_MATHS } from '../../../utils/importPaths'; -import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; -import { mapRange, typeNameFromTypeNode } from '../../../utils/utils'; -import { - forAllWidths, - bound, - msb, - mask, - getIntOrFixedByteBitWidth, - WarplibFunctionInfo, -} from '../../utils'; - -export function shr(): WarplibFunctionInfo { - return { - fileName: 'shr', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and, bitwise_not', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math_cmp import is_le, is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_and', - 'from warplib.maths.pow2 import pow2', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_shr256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (`, - ` result : Uint256){`, - ` let le_127 = is_le(rhs, 127);`, - ` if (le_127 == 1){`, - ` // (h', l') := (h, l) >> rhs`, - ` // p := 2^rhs`, - ` // l' = ((h & (p-1)) << (128 - rhs)) + ((l&~(p-1)) >> rhs)`, - ` // = ((h & (p-1)) << 128 >> rhs) + ((l&~(p-1)) >> rhs)`, - ` // = (h & (p-1)) * 2^128 / p + (l&~(p-1)) / p`, - ` // = (h & (p-1) * 2^128 + l&~(p-1)) / p`, - ` // h' = h >> rhs = (h - h&(p-1)) / p`, - ` let (p) = pow2(rhs);`, - ` let (low_mask) = bitwise_not(p - 1);`, - ` let (low_part) = bitwise_and(lhs.low, low_mask);`, - ` let (high_part) = bitwise_and(lhs.high, p - 1);`, - ` return (`, - ` Uint256(low=(low_part + ${bound( - 128, - )} * high_part) / p, high=(lhs.high - high_part) / p),);`, - ` }`, - ` let le_255 = is_le(rhs, 255);`, - ` if (le_255 == 1){`, - ` let (p) = pow2(rhs - 128);`, - ` let (mask) = bitwise_not(p - 1);`, - ` let (res) = bitwise_and(lhs.high, mask);`, - ` return (Uint256(res / p, 0),);`, - ` }`, - ` return (Uint256(0, 0),);`, - `}`, - `func warp_shr256_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (`, - ` result : Uint256){`, - ` if (rhs.high != 0){`, - ` return (Uint256(0, 0),);`, - ` }`, - ` let le_127 = is_le(rhs.low, 127);`, - ` if (le_127 == 1){`, - ` // (h', l') := (h, l) >> rhs`, - ` // p := 2^rhs`, - ` // l' = ((h & (p-1)) << (128 - rhs)) + ((l&~(p-1)) >> rhs)`, - ` // = ((h & (p-1)) << 128 >> rhs) + ((l&~(p-1)) >> rhs)`, - ` // = (h & (p-1)) * 2^128 / p + (l&~(p-1)) / p`, - ` // = (h & (p-1) * 2^128 + l&~(p-1)) / p`, - ` // h' = h >> rhs = (h - h&(p-1)) / p`, - ` let (p) = pow2(rhs.low);`, - ` let (low_mask) = bitwise_not(p - 1);`, - ` let (low_part) = bitwise_and(lhs.low, low_mask);`, - ` let (high_part) = bitwise_and(lhs.high, p - 1);`, - ` return (`, - ` Uint256(low=(low_part + ${bound( - 128, - )} * high_part) / p, high=(lhs.high - high_part) / p),);`, - ` }`, - ` let le_255 = is_le(rhs.low, 255);`, - ` if (le_255 == 1){`, - ` let (p) = pow2(rhs.low - 128);`, - ` let (mask) = bitwise_not(p - 1);`, - ` let (res) = bitwise_and(lhs.high, mask);`, - ` return (Uint256(res / p, 0),);`, - ` }`, - ` return (Uint256(0, 0),);`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_shr${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : felt) -> (res : felt){`, - ` let large_shift = is_le_felt(${width}, rhs);`, - ` if (large_shift == 1){`, - ` return (0,);`, - ` }else{`, - ` let preserved_width = ${width} - rhs;`, - ` let (preserved_bound) = pow2(preserved_width);`, - ` let mask = preserved_bound - 1;`, - ` let (divisor) = pow2(rhs);`, - ` let shifted_mask = mask * divisor;`, - ` let (lhs_truncated) = bitwise_and(lhs, shifted_mask);`, - ` return (lhs_truncated / divisor , );`, - ` }`, - `}`, - `func warp_shr${width}_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : Uint256) -> (res : felt){`, - ` if (rhs.high == 0){`, - ` let (res,) = warp_shr${width}(lhs, rhs.low);`, - ` return (res,);`, - ` }else{`, - ` return (0,);`, - ` }`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function shr_signed(): WarplibFunctionInfo { - return { - fileName: 'shr_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.math_cmp import is_le, is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_and', - 'from warplib.maths.pow2 import pow2', - `from warplib.maths.shr import ${mapRange(32, (n) => `warp_shr${8 * n + 8}`).join(', ')}`, - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_shr_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (res : Uint256){`, - ` alloc_locals;`, - ` let (local lhs_msb) = bitwise_and(lhs.high, ${msb(128)});`, - ` let (logical_shift) = warp_shr256(lhs, rhs);`, - ` if (lhs_msb == 0){`, - ` return (logical_shift,);`, - ` }else{`, - ` let large_shift = is_le(${width}, rhs);`, - ` if (large_shift == 1){`, - ` return (Uint256(${mask(128)}, ${mask(128)}),);`, - ` }else{`, - ` let crosses_boundary = is_le(128, rhs);`, - ` if (crosses_boundary == 1){`, - ` let (bound) = pow2(rhs-128);`, - ` let ones = bound - 1;`, - ` let (shift) = pow2(256-rhs);`, - ` return (Uint256(logical_shift.low+ones*shift, ${mask(128)}),);`, - ` }else{`, - ` let (bound) = pow2(rhs);`, - ` let ones = bound - 1;`, - ` let (shift) = pow2(128-rhs);`, - ` return (Uint256(logical_shift.low, logical_shift.high+ones*shift),);`, - ` }`, - ` }`, - ` }`, - `}`, - `func warp_shr_signed256_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, - ` if (rhs.high == 0){`, - ` let (res) = warp_shr_signed256(lhs, rhs.low);`, - ` return (res,);`, - ` }else{`, - ` let (res) = warp_shr_signed256(lhs, 256);`, - ` return (res,);`, - ` }`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_shr_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : felt) -> (res : felt){`, - ` alloc_locals;`, - ` let (local lhs_msb) = bitwise_and(lhs, ${msb(width)});`, - ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, - ` if (lhs_msb == 0){`, - ` let (res) = warp_shr${width}(lhs, rhs);`, - ` return (res,);`, - ` }else{`, - ` let large_shift = is_le_felt(${width}, rhs);`, - ` if (large_shift == 1){`, - ` return (${mask(width)},);`, - ` }else{`, - ` let (shifted) = warp_shr${width}(lhs, rhs);`, - ` let (sign_extend_bound) = pow2(rhs);`, - ` let sign_extend_value = sign_extend_bound - 1;`, - ` let (sign_extend_multiplier) = pow2(${width} - rhs);`, - ` return (shifted + sign_extend_value * sign_extend_multiplier,);`, - ` }`, - ` }`, - `}`, - `func warp_shr_signed${width}_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt, rhs : Uint256) -> (res : felt){`, - ` if (rhs.high == 0){`, - ` let (res) = warp_shr${width}(lhs, rhs.low);`, - ` return (res,);`, - ` }else{`, - ` let (res) = warp_shr${width}(lhs, ${width});`, - ` return (res,);`, - ` }`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function functionaliseShr(node: BinaryOperation, ast: AST): void { - const lhsType = safeGetNodeType(node.vLeftExpression, ast.inference); - const rhsType = safeGetNodeType(node.vRightExpression, ast.inference); - const retType = safeGetNodeType(node, ast.inference); - - assert( - lhsType instanceof IntType || lhsType instanceof FixedBytesType, - `lhs of >> ${printNode(node)} non-int type ${printTypeNode(lhsType)}`, - ); - assert( - rhsType instanceof IntType, - `rhs of >> ${printNode(node)} non-int type ${printTypeNode(rhsType)}`, - ); - - const lhsWidth = getIntOrFixedByteBitWidth(lhsType); - const signed = lhsType instanceof IntType && lhsType.signed; - - const fullName = `warp_shr${signed ? '_signed' : ''}${lhsWidth}${ - rhsType.nBits === 256 ? '_256' : '' - }`; - - const importName = [...WARPLIB_MATHS, `shr${signed ? '_signed' : ''}`]; - - const importedFunc = ast.registerImport( - node, - importName, - fullName, - [ - ['lhs', typeNameFromTypeNode(lhsType, ast)], - ['rhs', typeNameFromTypeNode(rhsType, ast)], - ], - [['res', typeNameFromTypeNode(retType, ast)]], - ); - const call = new FunctionCall( - ast.reserveId(), - node.src, - node.typeString, - FunctionCallKind.FunctionCall, - new Identifier( - ast.reserveId(), - '', - `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, - fullName, - importedFunc.id, - ), - [node.vLeftExpression, node.vRightExpression], - ); - - ast.replaceNode(node, call); -} diff --git a/src/warplib/implementations/maths/sub.ts b/src/warplib/implementations/maths/sub.ts deleted file mode 100644 index b6174c393..000000000 --- a/src/warplib/implementations/maths/sub.ts +++ /dev/null @@ -1,157 +0,0 @@ -import assert from 'assert'; -import { BinaryOperation, IntType } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { printTypeNode } from '../../../utils/astPrinter'; -import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; -import { - forAllWidths, - bound, - mask, - msb, - msbAndNext, - IntxIntFunction, - WarplibFunctionInfo, -} from '../../utils'; - -export function sub_unsafe(): WarplibFunctionInfo { - return { - fileName: 'sub_unsafe', - imports: ['use integer::u256_overflow_sub;'], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `fn warp_sub_unsafe256(lhs : u256, rhs : u256) -> u256 {`, - ` let (value, _) = u256_overflow_sub(lhs, rhs);`, - ` return value;`, - `}`, - ].join('\n'); - } else { - return [ - // TODO: Use bitwise '&' to take just the width-bits - `fn warp_sub_unsafe${width}(lhs : felt252, rhs : felt252) -> felt252 {`, - ` return lhs - rhs;`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function sub_signed(): WarplibFunctionInfo { - return { - fileName: 'sub_signed', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_add, uint256_signed_le, uint256_sub, uint256_not', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_sub_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (`, - ` res : Uint256){`, - ` // First sign extend both operands`, - ` let (left_msb : felt252) = bitwise_and(lhs.high, ${msb(128)});`, - ` let (right_msb : felt252) = bitwise_and(rhs.high, ${msb(128)});`, - ` let left_overflow : felt252 = left_msb / ${msb(128)};`, - ` let right_overflow : felt252 = right_msb / ${msb(128)};`, - ``, - ` // Now safely negate the rhs and add (l - r = l + (-r))`, - ` let (right_flipped : Uint256) = uint256_not(rhs);`, - ` let (right_neg, overflow) = uint256_add(right_flipped, Uint256(1,0));`, - ` let right_overflow_neg = overflow + 1 - right_overflow;`, - ` let (res, res_base_overflow) = uint256_add(lhs, right_neg);`, - ` let res_overflow = res_base_overflow + left_overflow + right_overflow_neg;`, - ``, - ` // Check if the result fits in the correct width`, - ` let (res_msb : felt252) = bitwise_and(res.high, ${msb(128)});`, - ` let (res_overflow_lsb : felt252) = bitwise_and(res_overflow, 1);`, - ` assert res_overflow_lsb * ${msb(128)} = res_msb;`, - ``, - ` // Narrow and return`, - ` return (res,);`, - `}`, - ].join('\n'); - } else { - return [ - `func warp_sub_signed${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt252, rhs : felt252) -> (`, - ` res : felt252){`, - ` // First sign extend both operands`, - ` let (left_msb : felt252) = bitwise_and(lhs, ${msb(width)});`, - ` let (right_msb : felt252) = bitwise_and(rhs, ${msb(width)});`, - ` let left_safe : felt252 = lhs + 2 * left_msb;`, - ` let right_safe : felt252 = rhs + 2 * right_msb;`, - ``, - ` // Now safely negate the rhs and add (l - r = l + (-r))`, - ` let right_neg : felt252 = ${bound(width + 1)} - right_safe;`, - ` let extended_res : felt252 = left_safe + right_neg;`, - ``, - ` // Check if the result fits in the correct width`, - ` let (overflowBits) = bitwise_and(extended_res, ${msbAndNext(width)});`, - ` assert overflowBits * (overflowBits - ${msbAndNext(width)}) = 0;`, - ``, - ` // Narrow and return`, - ` let (res) = bitwise_and(extended_res, ${mask(width)});`, - ` return (res,);`, - `}`, - ].join('\n'); - } - }), - }; -} - -export function sub_signed_unsafe(): WarplibFunctionInfo { - return { - fileName: 'sub_signed_unsafe', - imports: [ - 'from starkware.cairo.common.bitwise import bitwise_and', - 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', - 'from starkware.cairo.common.uint256 import Uint256, uint256_sub', - ], - functions: forAllWidths((width) => { - if (width === 256) { - return [ - 'func warp_sub_signed_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', - ' let (res) = uint256_sub(lhs, rhs);', - ' return (res,);', - '}', - ].join('\n'); - } else { - return [ - `func warp_sub_signed_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(`, - ` lhs : felt252, rhs : felt252) -> (res : felt252){`, - ` // First sign extend both operands`, - ` let (left_msb : felt252) = bitwise_and(lhs, ${msb(width)});`, - ` let (right_msb : felt252) = bitwise_and(rhs, ${msb(width)});`, - ` let left_safe : felt252 = lhs + 2 * left_msb;`, - ` let right_safe : felt252 = rhs + 2 * right_msb;`, - ``, - ` // Now safely negate the rhs and add (l - r = l + (-r))`, - ` let right_neg : felt252 = ${bound(width + 1)} - right_safe;`, - ` let extended_res : felt252 = left_safe + right_neg;`, - ``, - ` // Narrow and return`, - ` let (res) = bitwise_and(extended_res, ${mask(width)});`, - ` return (res,);`, - `}`, - ].join('\n'); - } - }), - }; -} - -//func warp_sub{range_check_ptr}(lhs : felt252, rhs : felt252) -> (res : felt252): -//func warp_sub256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256): - -export function functionaliseSub(node: BinaryOperation, unsafe: boolean, ast: AST): void { - const typeNode = safeGetNodeType(node, ast.inference); - assert( - typeNode instanceof IntType, - `Expected IntType for subtraction, got ${printTypeNode(typeNode)}`, - ); - if (unsafe) { - IntxIntFunction(node, 'sub', 'always', true, unsafe, ast); - } else { - IntxIntFunction(node, 'sub', 'signedOrWide', true, unsafe, ast); - } -} diff --git a/src/warplib/implementations/maths/xor.ts b/src/warplib/implementations/maths/xor.ts deleted file mode 100644 index 5097ad341..000000000 --- a/src/warplib/implementations/maths/xor.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { BinaryOperation } from 'solc-typed-ast'; -import { AST } from '../../../ast/ast'; -import { IntxIntFunction } from '../../utils'; - -export function functionaliseXor(node: BinaryOperation, ast: AST): void { - IntxIntFunction(node, 'xor', 'only256', false, false, ast); -} From c9cbfe4d6fea86cfc1fc25e43cfd078bd749a7b7 Mon Sep 17 00:00:00 2001 From: esdras-santos Date: Sun, 9 Apr 2023 13:37:05 -0300 Subject: [PATCH 5/6] orderNestedStruct.ts removed --- src/cairoUtilFuncGen/export.ts | 1 + src/cairoUtilFuncGen/index.ts | 9 + src/cairoUtilFuncGen/inputArgCheck/export.ts | 1 + .../inputArgCheck/inputCheck.ts | 262 +++++++++++++++++ src/freeStructWritter.ts | 3 +- src/passes/argBoundChecker.ts | 44 ++- .../mathsOperationToFunction.ts | 47 +++ src/passes/export.ts | 1 - src/passes/index.ts | 1 - src/passes/orderNestedStructs.ts | 126 -------- .../references/externalReturnReceiver.ts | 10 + src/transpiler.ts | 2 - src/warplib/generateWarplib.ts | 39 +++ src/warplib/implementations/maths/add.ts | 137 +++++++++ src/warplib/implementations/maths/and.ts | 7 + .../implementations/maths/bitwiseAnd.ts | 7 + .../implementations/maths/bitwiseOr.ts | 7 + src/warplib/implementations/maths/div.ts | 125 ++++++++ src/warplib/implementations/maths/eq.ts | 7 + src/warplib/implementations/maths/exp.ts | 220 ++++++++++++++ src/warplib/implementations/maths/ge.ts | 40 +++ src/warplib/implementations/maths/gt.ts | 40 +++ src/warplib/implementations/maths/le.ts | 61 ++++ src/warplib/implementations/maths/lt.ts | 43 +++ src/warplib/implementations/maths/mod.ts | 55 ++++ src/warplib/implementations/maths/mul.ts | 247 ++++++++++++++++ src/warplib/implementations/maths/neq.ts | 7 + src/warplib/implementations/maths/or.ts | 7 + src/warplib/implementations/maths/shl.ts | 118 ++++++++ src/warplib/implementations/maths/shr.ts | 269 ++++++++++++++++++ src/warplib/implementations/maths/sub.ts | 157 ++++++++++ src/warplib/implementations/maths/xor.ts | 7 + 32 files changed, 1974 insertions(+), 133 deletions(-) create mode 100644 src/cairoUtilFuncGen/inputArgCheck/export.ts create mode 100644 src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts delete mode 100644 src/passes/orderNestedStructs.ts create mode 100644 src/warplib/implementations/maths/add.ts create mode 100644 src/warplib/implementations/maths/and.ts create mode 100644 src/warplib/implementations/maths/bitwiseAnd.ts create mode 100644 src/warplib/implementations/maths/bitwiseOr.ts create mode 100644 src/warplib/implementations/maths/div.ts create mode 100644 src/warplib/implementations/maths/eq.ts create mode 100644 src/warplib/implementations/maths/exp.ts create mode 100644 src/warplib/implementations/maths/ge.ts create mode 100644 src/warplib/implementations/maths/gt.ts create mode 100644 src/warplib/implementations/maths/le.ts create mode 100644 src/warplib/implementations/maths/lt.ts create mode 100644 src/warplib/implementations/maths/mod.ts create mode 100644 src/warplib/implementations/maths/mul.ts create mode 100644 src/warplib/implementations/maths/neq.ts create mode 100644 src/warplib/implementations/maths/or.ts create mode 100644 src/warplib/implementations/maths/shl.ts create mode 100644 src/warplib/implementations/maths/shr.ts create mode 100644 src/warplib/implementations/maths/sub.ts create mode 100644 src/warplib/implementations/maths/xor.ts diff --git a/src/cairoUtilFuncGen/export.ts b/src/cairoUtilFuncGen/export.ts index 773c84489..2f0205905 100644 --- a/src/cairoUtilFuncGen/export.ts +++ b/src/cairoUtilFuncGen/export.ts @@ -1,6 +1,7 @@ export * from './serialisation'; export * from './base'; export * from './memory/export'; +export * from './inputArgCheck/export'; export * from './calldata/export'; export * from './utils/export'; export * from './storage/export'; diff --git a/src/cairoUtilFuncGen/index.ts b/src/cairoUtilFuncGen/index.ts index 74bc2eea3..f6848e964 100644 --- a/src/cairoUtilFuncGen/index.ts +++ b/src/cairoUtilFuncGen/index.ts @@ -1,4 +1,5 @@ import { AST } from '../ast/ast'; +import { InputCheckGen } from './inputArgCheck/inputCheck'; import { MemoryArrayLiteralGen } from './memory/arrayLiteral'; import { MemoryDynArrayLengthGen } from './memory/memoryDynArrayLength'; import { MemoryMemberAccessGen } from './memory/memoryMemberAccess'; @@ -84,6 +85,10 @@ export class CairoUtilFuncGen { toStorage: StorageToStorageGen; write: StorageWriteGen; }; + boundChecks: { + inputCheck: InputCheckGen; + enums: EnumInputCheck; + }; events: { index: IndexEncode; event: EventFunction; @@ -172,6 +177,10 @@ export class CairoUtilFuncGen { toStorage: storageToStorage, write: storageWrite, }; + this.boundChecks = { + inputCheck: new InputCheckGen(ast, sourceUnit), + enums: new EnumInputCheck(ast, sourceUnit), + }; this.calldata = { dynArrayStructConstructor: externalDynArrayStructConstructor, toMemory: new CallDataToMemoryGen(ast, sourceUnit), diff --git a/src/cairoUtilFuncGen/inputArgCheck/export.ts b/src/cairoUtilFuncGen/inputArgCheck/export.ts new file mode 100644 index 000000000..5bebac706 --- /dev/null +++ b/src/cairoUtilFuncGen/inputArgCheck/export.ts @@ -0,0 +1 @@ +export * from './inputCheck'; diff --git a/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts b/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts new file mode 100644 index 000000000..05e6747f5 --- /dev/null +++ b/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts @@ -0,0 +1,262 @@ +import assert from 'assert'; +import { + ArrayType, + BoolType, + BytesType, + DataLocation, + EnumDefinition, + Expression, + FixedBytesType, + FunctionCall, + FunctionStateMutability, + generalizeType, + IntType, + StringType, + StructDefinition, + TypeNode, + UserDefinedType, + VariableDeclaration, +} from 'solc-typed-ast'; +import { CairoFunctionDefinition, FunctionStubKind } from '../../ast/cairoNodes'; +import { printTypeNode } from '../../utils/astPrinter'; +import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; +import { NotSupportedYetError } from '../../utils/errors'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; +import { createIdentifier } from '../../utils/nodeTemplates'; +import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; +import { + delegateBasedOnType, + GeneratedFunctionInfo, + locationIfComplexType, + StringIndexedFuncGen, +} from '../base'; +import { + checkableType, + getElementType, + isAddressType, + isDynamicArray, + safeGetNodeType, +} from '../../utils/nodeTypeProcessing'; +import { cloneASTNode } from '../../utils/cloning'; +import { IS_LE_FELT, NARROW_SAFE, WARPLIB_MATHS } from '../../utils/importPaths'; + +const IMPLICITS = '{range_check_ptr : felt}'; + +export class InputCheckGen extends StringIndexedFuncGen { + public gen(nodeInput: VariableDeclaration | Expression, typeToCheck: TypeNode): FunctionCall { + let functionInput; + let isUint256 = false; + if (nodeInput instanceof VariableDeclaration) { + functionInput = createIdentifier(nodeInput, this.ast); + } else { + functionInput = cloneASTNode(nodeInput, this.ast); + const inputType = safeGetNodeType(nodeInput, this.ast.inference); + this.ast.setContextRecursive(functionInput); + isUint256 = inputType instanceof IntType && inputType.nBits === 256; + } + + const funcDef = this.getOrCreateFuncDef(typeToCheck, isUint256); + return createCallToFunction(funcDef, [functionInput], this.ast); + } + + private getOrCreateFuncDef(type: TypeNode, takesUint256: boolean): CairoFunctionDefinition { + const key = type.pp(); + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + if (type instanceof FixedBytesType) + return this.requireImport( + [...WARPLIB_MATHS, 'external_input_check_ints'], + `warp_external_input_check_int${type.size * 8}`, + ); + if (type instanceof IntType) + return this.requireImport( + [...WARPLIB_MATHS, 'external_input_check_ints'], + `warp_external_input_check_int${type.nBits}`, + ); + if (isAddressType(type)) + return this.requireImport( + [...WARPLIB_MATHS, 'external_input_check_address'], + `warp_external_input_check_address`, + ); + if (type instanceof BoolType) + return this.requireImport( + [...WARPLIB_MATHS, 'external_input_check_bool'], + `warp_external_input_check_bool`, + ); + + const funcInfo = this.getOrCreate(type, takesUint256); + const funcDef = createCairoGeneratedFunction( + funcInfo, + [ + [ + 'ref_var', + typeNameFromTypeNode(type, this.ast), + locationIfComplexType(type, DataLocation.CallData), + ], + ], + [], + this.ast, + this.sourceUnit, + { + mutability: FunctionStateMutability.Pure, + stubKind: FunctionStubKind.FunctionDefStub, + acceptsRawDArray: isDynamicArray(type), + }, + ); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } + + private getOrCreate(type: TypeNode, takesUint: boolean): GeneratedFunctionInfo { + const unexpectedTypeFunc = () => { + throw new NotSupportedYetError(`Input check for ${printTypeNode(type)} not defined yet.`); + }; + + return delegateBasedOnType( + type, + (type) => this.createDynArrayInputCheck(type), + (type) => this.createStaticArrayInputCheck(type), + (type, def) => this.createStructInputCheck(type, def), + unexpectedTypeFunc, + (type) => { + if (type instanceof UserDefinedType && type.definition instanceof EnumDefinition) + return this.createEnumInputCheck(type, takesUint); + return unexpectedTypeFunc(); + }, + ); + } + + private createStructInputCheck( + type: UserDefinedType, + structDef: StructDefinition, + ): GeneratedFunctionInfo { + const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); + + const [inputCheckCode, funcCalls] = structDef.vMembers.reduce( + ([inputCheckCode, funcCalls], decl) => { + const memberType = safeGetNodeType(decl, this.ast.inference); + if (checkableType(memberType)) { + const memberCheckFunc = this.getOrCreateFuncDef(memberType, false); + return [ + [...inputCheckCode, `${memberCheckFunc.name}(arg.${decl.name});`], + [...funcCalls, memberCheckFunc], + ]; + } + return [inputCheckCode, funcCalls]; + }, + [new Array(), new Array()], + ); + + const funcName = `external_input_check_struct_${structDef.name}`; + const funcInfo: GeneratedFunctionInfo = { + name: funcName, + code: [ + `func ${funcName}${IMPLICITS}(arg : ${cairoType.toString()}) -> (){`, + `alloc_locals;`, + ...inputCheckCode, + `return ();`, + `}`, + ].join('\n'), + functionsCalled: funcCalls, + }; + return funcInfo; + } + + // Todo: This function can probably be made recursive for big size static arrays + private createStaticArrayInputCheck(type: ArrayType): GeneratedFunctionInfo { + assert(type.size !== undefined); + const length = narrowBigIntSafe(type.size); + + const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); + const elementType = generalizeType(type.elementT)[0]; + + const auxFunc = this.getOrCreateFuncDef(elementType, false); + + const funcName = `external_input_check_static_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { + name: funcName, + code: [ + `func ${funcName}${IMPLICITS}(arg : ${cairoType.toString()}) -> (){`, + `alloc_locals;`, + ...mapRange(length, (index) => { + return [`${auxFunc.name}(arg[${index}]);`]; + }), + `return ();`, + `}`, + ].join('\n'), + functionsCalled: [auxFunc], + }; + return funcInfo; + } + + // TODO: this function and EnumInputCheck single file do the same??? + // TODO: When does takesUint == true? + private createEnumInputCheck(type: UserDefinedType, takesUint = false): GeneratedFunctionInfo { + const enumDef = type.definition; + assert(enumDef instanceof EnumDefinition); + + // TODO: enum names are unique right? + const funcName = `external_input_check_enum_${enumDef.name}`; + + const importFuncs = [this.requireImport(...IS_LE_FELT)]; + if (takesUint) { + importFuncs.push(this.requireImport(...NARROW_SAFE)); + } + + const nMembers = enumDef.vMembers.length; + const funcInfo: GeneratedFunctionInfo = { + name: funcName, + code: [ + `func ${funcName}${IMPLICITS}(arg : ${takesUint ? 'Uint256' : 'felt'}) -> (){`, + takesUint + ? [ + ' let (arg_0) = narrow_safe(arg);', + ` let inRange: felt = is_le_felt(arg_0, ${nMembers - 1});`, + ].join('\n') + : ` let inRange : felt = is_le_felt(arg, ${nMembers - 1});`, + ` with_attr error_message("Error: value out-of-bounds. Values passed to must be in enum range (0, ${ + nMembers - 1 + }]."){`, + ` assert 1 = inRange;`, + ` }`, + ` return ();`, + `}`, + ].join('\n'), + functionsCalled: importFuncs, + }; + return funcInfo; + } + + private createDynArrayInputCheck( + type: ArrayType | BytesType | StringType, + ): GeneratedFunctionInfo { + const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); + assert(cairoType instanceof CairoDynArray); + + const ptrType = cairoType.vPtr; + const elementType = generalizeType(getElementType(type))[0]; + + const calledFunction = this.getOrCreateFuncDef(elementType, false); + + const funcName = `external_input_check_dynamic_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { + name: funcName, + code: [ + `func ${funcName}${IMPLICITS}(len: felt, ptr : ${ptrType.toString()}) -> (){`, + ` alloc_locals;`, + ` if (len == 0){`, + ` return ();`, + ` }`, + ` ${calledFunction.name}(ptr[0]);`, + ` ${funcName}(len = len - 1, ptr = ptr + ${ptrType.to.width});`, + ` return ();`, + `}`, + ].join('\n'), + functionsCalled: [calledFunction], + }; + return funcInfo; + } +} diff --git a/src/freeStructWritter.ts b/src/freeStructWritter.ts index 8569f3bda..2e394db74 100644 --- a/src/freeStructWritter.ts +++ b/src/freeStructWritter.ts @@ -6,7 +6,6 @@ import { UserDefinedTypeName, } from 'solc-typed-ast'; import { AST } from './ast/ast'; -import { makeStructTree, reorderStructs } from './passes/orderNestedStructs'; /* Library calls in solidity are delegate calls @@ -20,7 +19,7 @@ import { makeStructTree, reorderStructs } from './passes/orderNestedStructs'; export function getStructs(node: SourceUnit, ast: AST): StructDefinition[] { const externalStructs = getDefinitionsToInline(node, node, new Set()); - return reorderStructs(...makeStructTree(externalStructs, ast)); + return Array.from(externalStructs.values()); } // DFS a node for definitions in a free context. diff --git a/src/passes/argBoundChecker.ts b/src/passes/argBoundChecker.ts index f6fd8cba6..a9836fe76 100644 --- a/src/passes/argBoundChecker.ts +++ b/src/passes/argBoundChecker.ts @@ -1,6 +1,18 @@ import { AST } from '../ast/ast'; -import { ContractDefinition, ContractKind, FunctionDefinition, FunctionCall } from 'solc-typed-ast'; +import { + ContractDefinition, + ContractKind, + FunctionDefinition, + FunctionCall, + FunctionCallKind, + EnumDefinition, + IntType, +} from 'solc-typed-ast'; import { ASTMapper } from '../ast/mapper'; +import { isExternallyVisible } from '../utils/utils'; +import assert from 'assert'; +import { createExpressionStatement } from '../utils/nodeTemplates'; +import { checkableType, safeGetNodeType } from '../utils/nodeTypeProcessing'; export class ArgBoundChecker extends ASTMapper { // Function to add passes that should have been run before this pass addInitialPassPrerequisites(): void { @@ -16,10 +28,40 @@ export class ArgBoundChecker extends ASTMapper { } visitFunctionDefinition(node: FunctionDefinition, ast: AST): void { + if (isExternallyVisible(node) && node.vBody !== undefined) { + node.vParameters.vParameters.forEach((decl) => { + const type = safeGetNodeType(decl, ast.inference); + if (checkableType(type)) { + const functionCall = ast.getUtilFuncGen(node).boundChecks.inputCheck.gen(decl, type); + this.insertFunctionCall(node, functionCall, ast); + } + }); + } + this.commonVisit(node, ast); } + private insertFunctionCall(node: FunctionDefinition, funcCall: FunctionCall, ast: AST): void { + const body = node.vBody; + assert(body !== undefined && funcCall.vArguments !== undefined); + const expressionStatement = createExpressionStatement(ast, funcCall); + body.insertAtBeginning(expressionStatement); + ast.setContextRecursive(expressionStatement); + } + visitFunctionCall(node: FunctionCall, ast: AST): void { + if ( + node.kind === FunctionCallKind.TypeConversion && + node.vReferencedDeclaration instanceof EnumDefinition && + safeGetNodeType(node.vArguments[0], ast.inference) instanceof IntType + ) { + const enumDef = node.vReferencedDeclaration; + const enumCheckFuncCall = ast + .getUtilFuncGen(node) + .boundChecks.enums.gen(node, node.vArguments[0], enumDef, node); + const parent = node.parent; + ast.replaceNode(node, enumCheckFuncCall, parent); + } this.commonVisit(node, ast); } } diff --git a/src/passes/builtinHandler/mathsOperationToFunction.ts b/src/passes/builtinHandler/mathsOperationToFunction.ts index b921ea312..457c243e5 100644 --- a/src/passes/builtinHandler/mathsOperationToFunction.ts +++ b/src/passes/builtinHandler/mathsOperationToFunction.ts @@ -11,8 +11,27 @@ import { NotSupportedYetError } from '../../utils/errors'; import { createCallToFunction } from '../../utils/functionGeneration'; import { WARPLIB_MATHS } from '../../utils/importPaths'; import { createNumberLiteral, createUint256TypeName } from '../../utils/nodeTemplates'; +import { functionaliseAdd } from '../../warplib/implementations/maths/add'; +import { functionaliseAnd } from '../../warplib/implementations/maths/and'; +import { functionaliseBitwiseAnd } from '../../warplib/implementations/maths/bitwiseAnd'; import { functionaliseBitwiseNot } from '../../warplib/implementations/maths/bitwiseNot'; +import { functionaliseBitwiseOr } from '../../warplib/implementations/maths/bitwiseOr'; +import { functionaliseDiv } from '../../warplib/implementations/maths/div'; +import { functionaliseEq } from '../../warplib/implementations/maths/eq'; +import { functionaliseExp } from '../../warplib/implementations/maths/exp'; +import { functionaliseGe } from '../../warplib/implementations/maths/ge'; +import { functionaliseGt } from '../../warplib/implementations/maths/gt'; +import { functionaliseLe } from '../../warplib/implementations/maths/le'; +import { functionaliseLt } from '../../warplib/implementations/maths/lt'; +import { functionaliseMod } from '../../warplib/implementations/maths/mod'; +import { functionaliseMul } from '../../warplib/implementations/maths/mul'; import { functionaliseNegate } from '../../warplib/implementations/maths/negate'; +import { functionaliseNeq } from '../../warplib/implementations/maths/neq'; +import { functionaliseOr } from '../../warplib/implementations/maths/or'; +import { functionaliseShl } from '../../warplib/implementations/maths/shl'; +import { functionaliseShr } from '../../warplib/implementations/maths/shr'; +import { functionaliseSub } from '../../warplib/implementations/maths/sub'; +import { functionaliseXor } from '../../warplib/implementations/maths/xor'; /* Note we also include mulmod and add mod here */ export class MathsOperationToFunction extends ASTMapper { @@ -26,6 +45,34 @@ export class MathsOperationToFunction extends ASTMapper { visitBinaryOperation(node: BinaryOperation, ast: AST): void { this.commonVisit(node, ast); + const operatorMap: Map void> = new Map([ + ['+', () => functionaliseAdd(node, this.inUncheckedBlock, ast)], + ['-', () => functionaliseSub(node, this.inUncheckedBlock, ast)], + ['*', () => functionaliseMul(node, this.inUncheckedBlock, ast)], + ['/', () => functionaliseDiv(node, this.inUncheckedBlock, ast)], + ['%', () => functionaliseMod(node, ast)], + ['**', () => functionaliseExp(node, this.inUncheckedBlock, ast)], + ['==', () => functionaliseEq(node, ast)], + ['!=', () => functionaliseNeq(node, ast)], + ['>=', () => functionaliseGe(node, ast)], + ['>', () => functionaliseGt(node, ast)], + ['<=', () => functionaliseLe(node, ast)], + ['<', () => functionaliseLt(node, ast)], + ['&', () => functionaliseBitwiseAnd(node, ast)], + ['|', () => functionaliseBitwiseOr(node, ast)], + ['^', () => functionaliseXor(node, ast)], + ['<<', () => functionaliseShl(node, ast)], + ['>>', () => functionaliseShr(node, ast)], + ['&&', () => functionaliseAnd(node, ast)], + ['||', () => functionaliseOr(node, ast)], + ]); + + const thunk = operatorMap.get(node.operator); + if (thunk === undefined) { + throw new NotSupportedYetError(`${node.operator} not supported yet`); + } + + thunk(); } visitUnaryOperation(node: UnaryOperation, ast: AST): void { diff --git a/src/passes/export.ts b/src/passes/export.ts index f62848442..6bef5ded1 100644 --- a/src/passes/export.ts +++ b/src/passes/export.ts @@ -33,7 +33,6 @@ export * from './literalExpressionEvaluator/export'; export * from './loopFunctionaliser/export'; export * from './namedArgsRemover'; export * from './newToDeploy'; -export * from './orderNestedStructs'; export * from './publicFunctionSplitter/export'; export * from './referencedLibraries'; export * from './references/export'; diff --git a/src/passes/index.ts b/src/passes/index.ts index a78768dc7..27b588ee0 100644 --- a/src/passes/index.ts +++ b/src/passes/index.ts @@ -31,7 +31,6 @@ export * from './literalExpressionEvaluator/literalExpressionEvaluator'; export * from './loopFunctionaliser'; export * from './namedArgsRemover'; export * from './newToDeploy'; -export * from './orderNestedStructs'; export * from './publicFunctionSplitter'; export * from './referencedLibraries'; export * from './references'; diff --git a/src/passes/orderNestedStructs.ts b/src/passes/orderNestedStructs.ts deleted file mode 100644 index 279fc6fbd..000000000 --- a/src/passes/orderNestedStructs.ts +++ /dev/null @@ -1,126 +0,0 @@ -import { - ArrayType, - ContractDefinition, - SourceUnit, - StructDefinition, - TypeNode, - UserDefinedType, -} from 'solc-typed-ast'; -import { AST } from '../ast/ast'; -import { ASTMapper } from '../ast/mapper'; -import { safeGetNodeType } from '../utils/nodeTypeProcessing'; - -export class OrderNestedStructs extends ASTMapper { - // Cairo does not permit to use struct definitions which are yet to be defined. - // For example: - // contract Warp { - // struct Top { - // Nested n; - // } - // struct Nested { ... } - // } - // When transpiled to Cairo, struct definitions must be reordered so that - // nested structs are defined first: - // struct Nested { ... } - // struct Top { - // member n : Nested; - // } - - // Function to add passes that should have been run before this pass - addInitialPassPrerequisites(): void { - const passKeys: Set = new Set([]); - passKeys.forEach((key) => this.addPassPrerequisite(key)); - } - - visitSourceUnit(node: SourceUnit, ast: AST): void { - this.reorderNestedStructs(node, ast); - this.commonVisit(node, ast); - } - - visitContractDefinition(node: ContractDefinition, ast: AST): void { - this.reorderNestedStructs(node, ast); - } - - private reorderNestedStructs(node: SourceUnit | ContractDefinition, ast: AST) { - const structs = node.vStructs; - const [roots, tree] = makeStructTree(new Set(structs), ast); - - // there are no nested structs - if (roots.size === structs.length) return; - - const newStructOrder = reorderStructs(roots, tree); - - // remove old struct definition - structs.forEach((child) => { - if (child instanceof StructDefinition) { - node.removeChild(child); - } - }); - - // insert back in new order - newStructOrder.reverse().forEach((struct) => node.insertAtBeginning(struct)); - } -} - -export function reorderStructs( - roots: Set, - tree: Map, -) { - const newOrder: StructDefinition[] = []; - const visited = new Set(); - - roots.forEach((root) => visitTree(root, tree, visited, newOrder)); - - return newOrder; -} - -// dfs through the tree -// root is always added to orderedStructs after all it's children -function visitTree( - root: StructDefinition, - tree: Map, - visited: Set, - orderedStructs: StructDefinition[], -) { - if (visited.has(root)) { - return; - } - visited.add(root); - - tree.get(root)?.forEach((nested) => visitTree(nested, tree, visited, orderedStructs)); - - orderedStructs.push(root); -} - -export function makeStructTree( - structs: Set, - ast: AST, -): [Set, Map] { - const roots = new Set(structs); - const tree = new Map(); - - structs.forEach((struct) => { - struct.vMembers.forEach((varDecl) => { - const nestedStruct = findStruct(safeGetNodeType(varDecl, ast.inference)); - // second check to avoid adding imported structs to contract definition - if (nestedStruct !== null && structs.has(nestedStruct)) { - roots.delete(nestedStruct); - tree.has(struct) ? tree.get(struct)?.push(nestedStruct) : tree.set(struct, [nestedStruct]); - } - }); - }); - - // roots are struct definition from which none other struct definition - // depends on - return [roots, tree]; -} - -function findStruct(varType: TypeNode): StructDefinition | null { - if (varType instanceof UserDefinedType && varType.definition instanceof StructDefinition) - return varType.definition; - - if (varType instanceof ArrayType && varType.size !== undefined) - return findStruct(varType.elementT); - - return null; -} diff --git a/src/passes/references/externalReturnReceiver.ts b/src/passes/references/externalReturnReceiver.ts index 25025d14b..a313466d1 100644 --- a/src/passes/references/externalReturnReceiver.ts +++ b/src/passes/references/externalReturnReceiver.ts @@ -52,9 +52,19 @@ export class ExternalReturnReceiver extends ASTMapper { ast.insertStatementAfter(node, statement); node.assignments = node.assignments.map((value) => (value === decl.id ? newId : value)); }); + + node.vDeclarations.forEach((decl) => addOutputValidation(decl, ast)); } } +function addOutputValidation(decl: VariableDeclaration, ast: AST) { + const type = safeGetNodeType(decl, ast.inference); + if (!checkableType(type)) return; + const validationFunctionCall = ast.getUtilFuncGen(decl).boundChecks.inputCheck.gen(decl, type); + const validationStatement = createExpressionStatement(ast, validationFunctionCall); + ast.insertStatementAfter(decl, validationStatement); +} + function generateCopyStatement( decl: VariableDeclaration, ast: AST, diff --git a/src/transpiler.ts b/src/transpiler.ts index 4dee78a44..39e9e161e 100644 --- a/src/transpiler.ts +++ b/src/transpiler.ts @@ -35,7 +35,6 @@ import { ModifierHandler, NamedArgsRemover, NewToDeploy, - OrderNestedStructs, PublicFunctionSplitter, PublicStateVarsGetterGenerator, ReferencedLibraries, @@ -120,7 +119,6 @@ function applyPasses( ['Req', Require], ['Ffi', FreeFunctionInliner], ['Rl', ReferencedLibraries], - ['Ons', OrderNestedStructs], ['Sa', StorageAllocator], ['Ii', InheritanceInliner], ['Ech', ExternalContractHandler], diff --git a/src/warplib/generateWarplib.ts b/src/warplib/generateWarplib.ts index 8eaabce43..5e72a3e4a 100644 --- a/src/warplib/generateWarplib.ts +++ b/src/warplib/generateWarplib.ts @@ -1,5 +1,18 @@ import { generateFile, WarplibFunctionInfo } from './utils'; import { int_conversions } from './implementations/conversions/int'; +import { add, add_unsafe, add_signed, add_signed_unsafe } from './implementations/maths/add'; +import { div_signed, div_signed_unsafe } from './implementations/maths/div'; +import { exp, exp_signed, exp_signed_unsafe, exp_unsafe } from './implementations/maths/exp'; +import { ge_signed } from './implementations/maths/ge'; +import { gt_signed } from './implementations/maths/gt'; +import { le_signed } from './implementations/maths/le'; +import { lt_signed } from './implementations/maths/lt'; +import { mod_signed } from './implementations/maths/mod'; +import { mul, mul_unsafe, mul_signed, mul_signed_unsafe } from './implementations/maths/mul'; +import { negate } from './implementations/maths/negate'; +import { shl } from './implementations/maths/shl'; +import { shr, shr_signed } from './implementations/maths/shr'; +import { sub_unsafe, sub_signed, sub_signed_unsafe } from './implementations/maths/sub'; import { bitwise_not } from './implementations/maths/bitwiseNot'; import { external_input_check_ints } from './implementations/external_input_checks/externalInputChecksInts'; import path from 'path'; @@ -9,14 +22,40 @@ import { glob } from 'glob'; import { parseMultipleRawCairoFunctions } from '../utils/cairoParsing'; const warplibFunctions: WarplibFunctionInfo[] = [ + add(), + add_unsafe(), + add_signed(), + add_signed_unsafe(), // sub - handwritten + sub_unsafe(), + sub_signed(), + sub_signed_unsafe(), + mul(), + mul_unsafe(), + mul_signed(), + mul_signed_unsafe(), // div - handwritten // div_unsafe - handwritten + div_signed(), + div_signed_unsafe(), // mod - handwritten + mod_signed(), + exp(), + exp_signed(), + exp_unsafe(), + exp_signed_unsafe(), + negate(), + shl(), + shr(), + shr_signed(), // ge - handwritten + ge_signed(), // gt - handwritten + gt_signed(), // le - handwritten + le_signed(), // lt - handwritten + lt_signed(), // and - handwritten // xor - handwritten // bitwise_and - handwritten diff --git a/src/warplib/implementations/maths/add.ts b/src/warplib/implementations/maths/add.ts new file mode 100644 index 000000000..f25fdbfe4 --- /dev/null +++ b/src/warplib/implementations/maths/add.ts @@ -0,0 +1,137 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { + forAllWidths, + IntxIntFunction, + mask, + msb, + msbAndNext, + WarplibFunctionInfo, +} from '../../utils'; + +export function add(): WarplibFunctionInfo { + const fileName = 'add'; + const imports = ['use warplib::maths::le::warp_le;']; + const functions = forAllWidths((width) => { + if (width === 256) { + return [`fn warp_add256(lhs: u256, rhs: u256) -> u256{`, ` return lhs + rhs;`, `}`].join( + '\n', + ); + } else { + return [ + `fn warp_add${width}(lhs: felt252, rhs: felt252) -> felt252{`, + ` let res = lhs + rhs;`, + ` let max: felt252 = ${mask(width)};`, + ` assert(warp_le(res, max), 'Value out of bounds');`, + ` return res;`, + `}`, + ].join('\n'); + } + }); + + return { fileName, imports, functions }; +} + +export function add_unsafe(): WarplibFunctionInfo { + return { + fileName: 'add_unsafe', + imports: ['use integer::u256_overflowing_add;'], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `fn warp_add_unsafe256(lhs : u256, rhs : u256) -> u256 {`, + ` let (value, _) = u256_overflowing_add(lhs, rhs);`, + ` return value;`, + `}`, + ].join('\n'); + } else { + // TODO: Use bitwise '&' to take just the width-bits + return [ + `fn warp_add_unsafe${width}(lhs : felt252, rhs : felt252) -> felt252 {`, + ` let res = lhs + rhs;`, + ` return res;`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function add_signed(): WarplibFunctionInfo { + return { + fileName: 'add_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_add', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_add_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, + ` let (lhs_extend) = bitwise_and(lhs.high, ${msb(128)});`, + ` let (rhs_extend) = bitwise_and(rhs.high, ${msb(128)});`, + ` let (res : Uint256, carry : felt252) = uint256_add(lhs, rhs);`, + ` let carry_extend = lhs_extend + rhs_extend + carry*${msb(128)};`, + ` let (msb) = bitwise_and(res.high, ${msb(128)});`, + ` let (carry_lsb) = bitwise_and(carry_extend, ${msb(128)});`, + ` assert msb = carry_lsb;`, + ` return (res,);`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_add_signed${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt252, rhs : felt252) -> (`, + ` res : felt252){`, + `// Do the addition sign extended`, + ` let (lmsb) = bitwise_and(lhs, ${msb(width)});`, + ` let (rmsb) = bitwise_and(rhs, ${msb(width)});`, + ` let big_res = lhs + rhs + 2*(lmsb+rmsb);`, + `// Check the result is valid`, + ` let (overflowBits) = bitwise_and(big_res, ${msbAndNext(width)});`, + ` assert overflowBits * (overflowBits - ${msbAndNext(width)}) = 0;`, + `// Truncate and return`, + ` let (res) = bitwise_and(big_res, ${mask(width)});`, + ` return (res,);`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function add_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'add_signed_unsafe', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_add', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_add_signed_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, + ` let (res : Uint256, _) = uint256_add(lhs, rhs);`, + ` return (res,);`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_add_signed_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt252, rhs : felt252) -> (res : felt252){`, + ` let (res) = bitwise_and(lhs + rhs, ${mask(width)});`, + ` return (res,);`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function functionaliseAdd(node: BinaryOperation, unsafe: boolean, ast: AST): void { + IntxIntFunction(node, 'add', 'always', true, unsafe, ast); +} diff --git a/src/warplib/implementations/maths/and.ts b/src/warplib/implementations/maths/and.ts new file mode 100644 index 000000000..97e21704a --- /dev/null +++ b/src/warplib/implementations/maths/and.ts @@ -0,0 +1,7 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { BoolxBoolFunction } from '../../utils'; + +export function functionaliseAnd(node: BinaryOperation, ast: AST): void { + BoolxBoolFunction(node, 'and_', ast); +} diff --git a/src/warplib/implementations/maths/bitwiseAnd.ts b/src/warplib/implementations/maths/bitwiseAnd.ts new file mode 100644 index 000000000..10b994053 --- /dev/null +++ b/src/warplib/implementations/maths/bitwiseAnd.ts @@ -0,0 +1,7 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { IntxIntFunction } from '../../utils'; + +export function functionaliseBitwiseAnd(node: BinaryOperation, ast: AST): void { + IntxIntFunction(node, 'bitwise_and', 'only256', false, false, ast); +} diff --git a/src/warplib/implementations/maths/bitwiseOr.ts b/src/warplib/implementations/maths/bitwiseOr.ts new file mode 100644 index 000000000..a1acdf678 --- /dev/null +++ b/src/warplib/implementations/maths/bitwiseOr.ts @@ -0,0 +1,7 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { IntxIntFunction } from '../../utils'; + +export function functionaliseBitwiseOr(node: BinaryOperation, ast: AST): void { + IntxIntFunction(node, 'bitwise_or', 'only256', false, false, ast); +} diff --git a/src/warplib/implementations/maths/div.ts b/src/warplib/implementations/maths/div.ts new file mode 100644 index 000000000..ec292617c --- /dev/null +++ b/src/warplib/implementations/maths/div.ts @@ -0,0 +1,125 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { mapRange } from '../../../utils/utils'; +import { forAllWidths, IntxIntFunction, mask, WarplibFunctionInfo } from '../../utils'; + +export function div_signed(): WarplibFunctionInfo { + return { + fileName: 'div_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem, uint256_eq', + 'from warplib.maths.utils import felt_to_uint256', + `from warplib.maths.int_conversions import ${mapRange( + 31, + (n) => `warp_int${8 * n + 8}_to_int256`, + ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, + `from warplib.maths.mul_signed import ${mapRange(32, (n) => `warp_mul_signed${8 * n + 8}`)}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_div_signed256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', + ` if (rhs.high == 0 and rhs.low == 0){`, + ` with_attr error_message("Division by zero error"){`, + ` assert 1 = 0;`, + ` }`, + ` }`, + ` let (is_minus_one) = uint256_eq(rhs, Uint256(${mask(128)}, ${mask(128)}));`, + ` if (is_minus_one == 1){`, + ' let (res : Uint256) = warp_mul_signed256(lhs, rhs);', + ' return (res,);', + ' }', + ' let (res : Uint256, _) = uint256_signed_div_rem(lhs, rhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_div_signed${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` if (rhs == 0){`, + ` with_attr error_message("Division by zero error"){`, + ` assert 1 = 0;`, + ` }`, + ` }`, + ` if (rhs == ${mask(width)}){`, + ` let (res : felt) = warp_mul_signed${width}(lhs, rhs);`, + ` return (res,);`, + ' }', + ` let (local lhs_256) = warp_int${width}_to_int256(lhs);`, + ` let (rhs_256) = warp_int${width}_to_int256(rhs);`, + ' let (res256, _) = uint256_signed_div_rem(lhs_256, rhs_256);', + ` let (truncated) = warp_int256_to_int${width}(res256);`, + ` return (truncated,);`, + '}', + ].join('\n'); + } + }), + }; +} + +export function div_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'div_signed_unsafe', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem, uint256_eq', + 'from warplib.maths.utils import felt_to_uint256', + `from warplib.maths.int_conversions import ${mapRange( + 31, + (n) => `warp_int${8 * n + 8}_to_int256`, + ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, + `from warplib.maths.mul_signed_unsafe import ${mapRange( + 32, + (n) => `warp_mul_signed_unsafe${8 * n + 8}`, + )}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_div_signed_unsafe256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', + ` if (rhs.high == 0 and rhs.low == 0){`, + ` with_attr error_message("Division by zero error"){`, + ` assert 1 = 0;`, + ` }`, + ` }`, + ` let (is_minus_one) = uint256_eq(rhs, Uint256(${mask(128)}, ${mask(128)}));`, + ` if (is_minus_one == 1){`, + ' let (res : Uint256) = warp_mul_signed_unsafe256(lhs, rhs);', + ' return (res,);', + ' }', + ' let (res : Uint256, _) = uint256_signed_div_rem(lhs, rhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_div_signed_unsafe${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` if (rhs == 0){`, + ` with_attr error_message("Division by zero error"){`, + ` assert 1 = 0;`, + ` }`, + ` }`, + ` if (rhs == ${mask(width)}){`, + ` let (res : felt) = warp_mul_signed_unsafe${width}(lhs, rhs);`, + ` return (res,);`, + ' }', + ` let (local lhs_256) = warp_int${width}_to_int256(lhs);`, + ` let (rhs_256) = warp_int${width}_to_int256(rhs);`, + ' let (res256, _) = uint256_signed_div_rem(lhs_256, rhs_256);', + ` let (truncated) = warp_int256_to_int${width}(res256);`, + ` return (truncated,);`, + '}', + ].join('\n'); + } + }), + }; +} + +export function functionaliseDiv(node: BinaryOperation, unsafe: boolean, ast: AST): void { + IntxIntFunction(node, 'div', 'signedOrWide', true, unsafe, ast); +} diff --git a/src/warplib/implementations/maths/eq.ts b/src/warplib/implementations/maths/eq.ts new file mode 100644 index 000000000..520e6c1d1 --- /dev/null +++ b/src/warplib/implementations/maths/eq.ts @@ -0,0 +1,7 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { Comparison } from '../../utils'; + +export function functionaliseEq(node: BinaryOperation, ast: AST): void { + Comparison(node, 'eq', 'only256', false, ast); +} diff --git a/src/warplib/implementations/maths/exp.ts b/src/warplib/implementations/maths/exp.ts new file mode 100644 index 000000000..f3bfbb565 --- /dev/null +++ b/src/warplib/implementations/maths/exp.ts @@ -0,0 +1,220 @@ +import assert from 'assert'; +import { + BinaryOperation, + FunctionCall, + FunctionCallKind, + Identifier, + IntType, +} from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { printNode, printTypeNode } from '../../../utils/astPrinter'; +import { WARPLIB_MATHS } from '../../../utils/importPaths'; +import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; +import { mapRange, typeNameFromTypeNode } from '../../../utils/utils'; +import { forAllWidths, getIntOrFixedByteBitWidth, mask, WarplibFunctionInfo } from '../../utils'; + +export function exp() { + return createExp(false, false); +} + +export function exp_signed() { + return createExp(true, false); +} + +export function exp_unsafe() { + return createExp(false, true); +} + +export function exp_signed_unsafe() { + return createExp(true, true); +} + +function createExp(signed: boolean, unsafe: boolean): WarplibFunctionInfo { + const suffix = `${signed ? '_signed' : ''}${unsafe ? '_unsafe' : ''}`; + return { + fileName: `exp${suffix}`, + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_sub', + `from warplib.maths.mul${suffix} import ${mapRange( + 32, + (n) => `warp_mul${suffix}${8 * n + 8}`, + ).join(', ')}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func _repeated_multiplication${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : Uint256, count : felt) -> (res : Uint256){`, + ` if (count == 0){`, + ` return (Uint256(1, 0),);`, + ` }`, + ` let (x) = _repeated_multiplication${width}(op, count - 1);`, + ` let (res) = warp_mul${suffix}${width}(op, x);`, + ` return (res,);`, + `}`, + `func warp_exp${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (res : Uint256){`, + ` if (rhs == 0){`, + ` return (Uint256(1, 0),);`, + ' }', + ' if (lhs.high == 0){', + ` if (lhs.low * (lhs.low - 1) == 0){`, + ' return (lhs,);', + ` }`, + ` }`, + ...getNegativeOneShortcutCode(signed, width, false), + ` let (res) = _repeated_multiplication${width}(lhs, rhs);`, + ` return (res,);`, + `}`, + `func _repeated_multiplication_256_${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : Uint256, count : Uint256) -> (res : Uint256){`, + ` if (count.low == 0 and count.high == 0){`, + ` return (Uint256(1, 0),);`, + ` }`, + ` let (decr) = uint256_sub(count, Uint256(1, 0));`, + ` let (x) = _repeated_multiplication_256_${width}(op, decr);`, + ` let (res) = warp_mul${suffix}${width}(op, x);`, + ` return (res,);`, + `}`, + `func warp_exp_wide${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, + ` if (rhs.high == 0 and rhs.low == 0){`, + ` return (Uint256(1, 0),);`, + ' }', + ' if (lhs.high == 0 and lhs.low * (lhs.low - 1) == 0){', + ' return (lhs,);', + ` }`, + ...getNegativeOneShortcutCode(signed, width, true), + ` let (res) = _repeated_multiplication_256_${width}(lhs, rhs);`, + ` return (res,);`, + `}`, + ].join('\n'); + } else { + return [ + `func _repeated_multiplication${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : felt, count : felt) -> (res : felt){`, + ` alloc_locals;`, + ` if (count == 0){`, + ` return (1,);`, + ` }else{`, + ` let (x) = _repeated_multiplication${width}(op, count - 1);`, + ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, + ` let (res) = warp_mul${suffix}${width}(op, x);`, + ` return (res,);`, + ` }`, + `}`, + `func warp_exp${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, + ' if (rhs == 0){', + ' return (1,);', + ` }`, + ' if (lhs * (lhs-1) * (rhs-1) == 0){', + ' return (lhs,);', + ' }', + ...getNegativeOneShortcutCode(signed, width, false), + ` let (res) = _repeated_multiplication${width}(lhs, rhs);`, + ` return (res,);`, + '}', + `func _repeated_multiplication_256_${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : felt, count : Uint256) -> (res : felt){`, + ` alloc_locals;`, + ` if (count.low == 0 and count.high == 0){`, + ` return (1,);`, + ` }`, + ` let (decr) = uint256_sub(count, Uint256(1, 0));`, + ` let (x) = _repeated_multiplication_256_${width}(op, decr);`, + ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, + ` let (res) = warp_mul${suffix}${width}(op, x);`, + ` return (res,);`, + `}`, + `func warp_exp_wide${suffix}${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : Uint256) -> (res : felt){`, + ' if (rhs.low == 0){', + ' if (rhs.high == 0){', + ' return (1,);', + ' }', + ` }`, + ' if (lhs * (lhs-1) == 0){', + ' return (lhs,);', + ' }', + ' if (rhs.low == 1 and rhs.high == 0){', + ' return (lhs,);', + ' }', + ...getNegativeOneShortcutCode(signed, width, true), + ` let (res) = _repeated_multiplication_256_${width}(lhs, rhs);`, + ` return (res,);`, + '}', + ].join('\n'); + } + }), + }; +} + +function getNegativeOneShortcutCode(signed: boolean, lhsWidth: number, rhsWide: boolean): string[] { + if (!signed) return []; + + if (lhsWidth < 256) { + return [ + `if ((lhs - ${mask(lhsWidth)}) == 0){`, + ` let (is_odd) = bitwise_and(${rhsWide ? 'rhs.low' : 'rhs'}, 1);`, + ` return (1 + is_odd * 0x${'f'.repeat(lhsWidth / 8 - 1)}e,);`, + `}`, + ]; + } else { + return [ + `if ((lhs.low - ${mask(128)}) == 0 and (lhs.high - ${mask(128)}) == 0){`, + ` let (is_odd) = bitwise_and(${rhsWide ? 'rhs.low' : 'rhs'}, 1);`, + ` return (Uint256(1 + is_odd * 0x${'f'.repeat(31)}e, is_odd * ${mask(128)}),);`, + `}`, + ]; + } +} + +export function functionaliseExp(node: BinaryOperation, unsafe: boolean, ast: AST) { + const lhsType = safeGetNodeType(node.vLeftExpression, ast.inference); + const rhsType = safeGetNodeType(node.vRightExpression, ast.inference); + const retType = safeGetNodeType(node, ast.inference); + assert( + retType instanceof IntType, + `${printNode(node)} has type ${printTypeNode(retType)}, which is not compatible with **`, + ); + assert( + rhsType instanceof IntType, + `${printNode(node)} has rhs-type ${rhsType.pp()}, which is not compatible with **`, + ); + const fullName = [ + 'warp_', + 'exp', + rhsType.nBits === 256 ? '_wide' : '', + retType.signed ? '_signed' : '', + unsafe ? '_unsafe' : '', + `${getIntOrFixedByteBitWidth(retType)}`, + ].join(''); + + const importName = [ + ...WARPLIB_MATHS, + `exp${retType.signed ? '_signed' : ''}${unsafe ? '_unsafe' : ''}`, + ]; + + const importedFunc = ast.registerImport( + node, + importName, + fullName, + [ + ['lhs', typeNameFromTypeNode(lhsType, ast)], + ['rhs', typeNameFromTypeNode(rhsType, ast)], + ], + [['res', typeNameFromTypeNode(retType, ast)]], + ); + + const call = new FunctionCall( + ast.reserveId(), + node.src, + node.typeString, + FunctionCallKind.FunctionCall, + new Identifier( + ast.reserveId(), + '', + `function (${node.typeString}, ${node.typeString}) returns (${node.typeString})`, + fullName, + importedFunc.id, + ), + [node.vLeftExpression, node.vRightExpression], + ); + + ast.replaceNode(node, call); +} diff --git a/src/warplib/implementations/maths/ge.ts b/src/warplib/implementations/maths/ge.ts new file mode 100644 index 000000000..426e9df1d --- /dev/null +++ b/src/warplib/implementations/maths/ge.ts @@ -0,0 +1,40 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { mapRange } from '../../../utils/utils'; +import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; + +export function ge_signed(): WarplibFunctionInfo { + return { + fileName: 'ge_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_le', + `from warplib.maths.lt_signed import ${mapRange(31, (n) => `warp_le_signed${8 * n + 8}`).join( + ', ', + )}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_ge_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', + ' let (res) = uint256_signed_le(rhs, lhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_ge_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, + ' lhs : felt, rhs : felt) -> (res : felt){', + ` let (res) = warp_le_signed${width}(rhs, lhs);`, + ` return (res,);`, + '}', + ].join('\n'); + } + }), + }; +} + +export function functionaliseGe(node: BinaryOperation, ast: AST): void { + Comparison(node, 'ge', 'signedOrWide', true, ast); +} diff --git a/src/warplib/implementations/maths/gt.ts b/src/warplib/implementations/maths/gt.ts new file mode 100644 index 000000000..7764c5aca --- /dev/null +++ b/src/warplib/implementations/maths/gt.ts @@ -0,0 +1,40 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { mapRange } from '../../../utils/utils'; +import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; + +export function gt_signed(): WarplibFunctionInfo { + return { + fileName: 'gt_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_lt', + `from warplib.maths.lt_signed import ${mapRange(31, (n) => `warp_lt_signed${8 * n + 8}`).join( + ', ', + )}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_gt_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', + ' let (res) = uint256_signed_lt(rhs, lhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_gt_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, + ' lhs : felt, rhs : felt) -> (res : felt){', + ` let (res) = warp_lt_signed${width}(rhs, lhs);`, + ` return (res,);`, + '}', + ].join('\n'); + } + }), + }; +} + +export function functionaliseGt(node: BinaryOperation, ast: AST): void { + Comparison(node, 'gt', 'signedOrWide', true, ast); +} diff --git a/src/warplib/implementations/maths/le.ts b/src/warplib/implementations/maths/le.ts new file mode 100644 index 000000000..9e18a1f5a --- /dev/null +++ b/src/warplib/implementations/maths/le.ts @@ -0,0 +1,61 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { forAllWidths, msb, Comparison, WarplibFunctionInfo } from '../../utils'; + +export function le_signed(): WarplibFunctionInfo { + return { + fileName: 'le_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_le', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_le_signed${width}{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){`, + ' let (res) = uint256_signed_le(lhs, rhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_le_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, + ` lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` let (lhs_msb : felt) = bitwise_and(lhs, ${msb(width)});`, + ` let (rhs_msb : felt) = bitwise_and(rhs, ${msb(width)});`, + ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, + ` if (lhs_msb == 0){`, + ` // lhs >= 0`, + ` if (rhs_msb == 0){`, + ` // rhs >= 0`, + ` let result = is_le_felt(lhs, rhs);`, + ` return (result,);`, + ` }else{`, + ` // rhs < 0`, + ` return (0,);`, + ` }`, + ` }else{`, + ` // lhs < 0`, + ` if (rhs_msb == 0){`, + ` // rhs >= 0`, + ` return (1,);`, + ` }else{`, + ` // rhs < 0`, + ` // (signed) lhs <= rhs <=> (unsigned) lhs >= rhs`, + ` let result = is_le_felt(lhs, rhs);`, + ` return (result,);`, + ` }`, + ` }`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function functionaliseLe(node: BinaryOperation, ast: AST): void { + Comparison(node, 'le', 'signedOrWide', true, ast); +} diff --git a/src/warplib/implementations/maths/lt.ts b/src/warplib/implementations/maths/lt.ts new file mode 100644 index 000000000..b58212b8c --- /dev/null +++ b/src/warplib/implementations/maths/lt.ts @@ -0,0 +1,43 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { mapRange } from '../../../utils/utils'; +import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; + +export function lt_signed(): WarplibFunctionInfo { + return { + fileName: 'lt_signed', + imports: [ + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_lt', + 'from warplib.maths.utils import felt_to_uint256', + `from warplib.maths.le_signed import ${mapRange(31, (n) => `warp_le_signed${8 * n + 8}`).join( + ', ', + )}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_lt_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', + ' let (res) = uint256_signed_lt(lhs, rhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_lt_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, + ' lhs : felt, rhs : felt) -> (res : felt){', + ' if (lhs == rhs){', + ' return (0,);', + ' }', + ` let (res) = warp_le_signed${width}(lhs, rhs);`, + ` return (res,);`, + '}', + ].join('\n'); + } + }), + }; +} + +export function functionaliseLt(node: BinaryOperation, ast: AST): void { + Comparison(node, 'lt', 'signedOrWide', true, ast); +} diff --git a/src/warplib/implementations/maths/mod.ts b/src/warplib/implementations/maths/mod.ts new file mode 100644 index 000000000..6cae243d4 --- /dev/null +++ b/src/warplib/implementations/maths/mod.ts @@ -0,0 +1,55 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { mapRange } from '../../../utils/utils'; +import { forAllWidths, IntxIntFunction, WarplibFunctionInfo } from '../../utils'; + +export function mod_signed(): WarplibFunctionInfo { + return { + fileName: 'mod_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem', + 'from warplib.maths.utils import felt_to_uint256', + `from warplib.maths.int_conversions import ${mapRange( + 31, + (n) => `warp_int${8 * n + 8}_to_int256`, + ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_mod_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', + ` if (rhs.high == 0 and rhs.low == 0){`, + ` with_attr error_message("Modulo by zero error"){`, + ` assert 1 = 0;`, + ` }`, + ` }`, + ' let (_, res : Uint256) = uint256_signed_div_rem(lhs, rhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_mod_signed${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` if (rhs == 0){`, + ` with_attr error_message("Modulo by zero error"){`, + ` assert 1 = 0;`, + ` }`, + ` }`, + ` let (local lhs_256) = warp_int${width}_to_int256(lhs);`, + ` let (rhs_256) = warp_int${width}_to_int256(rhs);`, + ' let (_, res256) = uint256_signed_div_rem(lhs_256, rhs_256);', + ` let (truncated) = warp_int256_to_int${width}(res256);`, + ` return (truncated,);`, + '}', + ].join('\n'); + } + }), + }; +} + +export function functionaliseMod(node: BinaryOperation, ast: AST): void { + IntxIntFunction(node, 'mod', 'signedOrWide', true, false, ast); +} diff --git a/src/warplib/implementations/maths/mul.ts b/src/warplib/implementations/maths/mul.ts new file mode 100644 index 000000000..738c5ba2b --- /dev/null +++ b/src/warplib/implementations/maths/mul.ts @@ -0,0 +1,247 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { mapRange } from '../../../utils/utils'; +import { + forAllWidths, + uint256, + pow2, + bound, + mask, + msb, + IntxIntFunction, + WarplibFunctionInfo, +} from '../../utils'; + +export function mul(): WarplibFunctionInfo { + return { + fileName: 'mul', + imports: [ + 'from starkware.cairo.common.uint256 import Uint256, uint256_mul', + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from warplib.maths.ge import warp_ge256', + 'from warplib.maths.utils import felt_to_uint256', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_mul256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', + ' let (result : Uint256, overflow : Uint256) = uint256_mul(lhs, rhs);', + ' assert overflow.low = 0;', + ' assert overflow.high = 0;', + ' return (result,);', + '}', + ].join('\n'); + } else if (width >= 128) { + return [ + `func warp_mul${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, + ' alloc_locals;', + ' let (l256 : Uint256) = felt_to_uint256(lhs);', + ' let (r256 : Uint256) = felt_to_uint256(rhs);', + ' let (local res : Uint256) = warp_mul256(l256, r256);', + ` let (outOfRange : felt) = warp_ge256(res, ${uint256(pow2(width))});`, + ' assert outOfRange = 0;', + ` return (res.low + ${bound(128)} * res.high,);`, + '}', + ].join('\n'); + } else { + return [ + `func warp_mul${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, + ' let res = lhs * rhs;', + ` let inRange : felt = is_le_felt(res, ${mask(width)});`, + ' assert inRange = 1;', + ' return (res,);', + '}', + ].join('\n'); + } + }), + }; +} + +export function mul_unsafe(): WarplibFunctionInfo { + return { + fileName: 'mul_unsafe', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_mul', + 'from warplib.maths.utils import felt_to_uint256', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_mul_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, + ` let (res : Uint256, _) = uint256_mul(lhs, rhs);`, + ` return (res,);`, + `}`, + ].join('\n'); + } else if (width >= 128) { + return [ + `func warp_mul_unsafe${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` let (l256 : Uint256) = felt_to_uint256(lhs);`, + ` let (r256 : Uint256) = felt_to_uint256(rhs);`, + ` let (local res : Uint256, _) = uint256_mul(l256, r256);`, + ` let (high) = bitwise_and(res.high, ${mask(width - 128)});`, + ` return (res.low + ${bound(128)} * high,);`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_mul_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, + ` let (res) = bitwise_and(lhs * rhs, ${mask(width)});`, + ` return (res,);`, + '}', + ].join('\n'); + } + }), + }; +} + +export function mul_signed(): WarplibFunctionInfo { + return { + fileName: 'mul_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_mul, uint256_cond_neg, uint256_signed_nn, uint256_neg, uint256_le', + 'from warplib.maths.utils import felt_to_uint256', + `from warplib.maths.le import warp_le`, + `from warplib.maths.mul import ${mapRange(31, (n) => `warp_mul${8 * n + 8}`).join(', ')}`, + `from warplib.maths.negate import ${mapRange(31, (n) => `warp_negate${8 * n + 8}`).join( + ', ', + )}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_mul_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : Uint256, rhs : Uint256) -> (result : Uint256){`, + ` alloc_locals;`, + ` // 1 => lhs >= 0, 0 => lhs < 0`, + ` let (lhs_nn) = uint256_signed_nn(lhs);`, + ` // 1 => rhs >= 0, 0 => rhs < 0`, + ` let (local rhs_nn) = uint256_signed_nn(rhs);`, + ` // negates if arg is 1, which is if lhs_nn is 0, which is if lhs < 0`, + ` let (lhs_abs) = uint256_cond_neg(lhs, 1 - lhs_nn);`, + ` // negates if arg is 1`, + ` let (rhs_abs) = uint256_cond_neg(rhs, 1 - rhs_nn);`, + ` let (res_abs, overflow) = uint256_mul(lhs_abs, rhs_abs);`, + ` assert overflow.low = 0;`, + ` assert overflow.high = 0;`, + ` let res_should_be_neg = lhs_nn + rhs_nn;`, + ` if (res_should_be_neg == 1){`, + ` let (in_range) = uint256_le(res_abs, Uint256(0,${msb(128)}));`, + ` assert in_range = 1;`, + ` let (negated) = uint256_neg(res_abs);`, + ` return (negated,);`, + ` }else{`, + ` let (msb) = bitwise_and(res_abs.high, ${msb(128)});`, + ` assert msb = 0;`, + ` return (res_abs,);`, + ` }`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_mul_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` let (left_msb) = bitwise_and(lhs, ${msb(width)});`, + ` if (left_msb == 0){`, + ` let (right_msb) = bitwise_and(rhs, ${msb(width)});`, + ` if (right_msb == 0){`, + ` let (res) = warp_mul${width}(lhs, rhs);`, + ` let (res_msb) = bitwise_and(res, ${msb(width)});`, + ` assert res_msb = 0;`, + ` return (res,);`, + ` }else{`, + ` let (rhs_abs) = warp_negate${width}(rhs);`, + ` let (res_abs) = warp_mul${width}(lhs, rhs_abs);`, + ` let (in_range) = warp_le(res_abs, ${msb(width)});`, + ` assert in_range = 1;`, + ` let (res) = warp_negate${width}(res_abs);`, + ` return (res,);`, + ` }`, + ` }else{`, + ` let (right_msb) = bitwise_and(rhs, ${msb(width)});`, + ` if (right_msb == 0){`, + ` let (lhs_abs) = warp_negate${width}(lhs);`, + ` let (res_abs) = warp_mul${width}(lhs_abs, rhs);`, + ` let (in_range) = warp_le(res_abs, ${msb(width)});`, + ` assert in_range = 1;`, + ` let (res) = warp_negate${width}(res_abs);`, + ` return (res,);`, + ` }else{`, + ` let (lhs_abs) = warp_negate${width}(lhs);`, + ` let (rhs_abs) = warp_negate${width}(rhs);`, + ` let (res) = warp_mul${width}(lhs_abs, rhs_abs);`, + ` let (res_msb) = bitwise_and(res, ${msb(width)});`, + ` assert res_msb = 0;`, + ` return (res,);`, + ` }`, + ` }`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function mul_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'mul_signed_unsafe', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_mul, uint256_cond_neg, uint256_signed_nn', + `from warplib.maths.mul_unsafe import ${mapRange( + 31, + (n) => `warp_mul_unsafe${8 * n + 8}`, + ).join(', ')}`, + `from warplib.maths.negate import ${mapRange(31, (n) => `warp_negate${8 * n + 8}`).join( + ', ', + )}`, + 'from warplib.maths.utils import felt_to_uint256', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_mul_signed_unsafe256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : Uint256, rhs : Uint256) -> (result : Uint256){`, + ` alloc_locals;`, + ` let (lhs_nn) = uint256_signed_nn(lhs);`, + ` let (local rhs_nn) = uint256_signed_nn(rhs);`, + ` let (lhs_abs) = uint256_cond_neg(lhs, lhs_nn);`, + ` let (rhs_abs) = uint256_cond_neg(rhs, rhs_nn);`, + ` let (res_abs, _) = uint256_mul(lhs_abs, rhs_abs);`, + ` let (res) = uint256_cond_neg(res_abs, (lhs_nn + rhs_nn) * (2 - lhs_nn - rhs_nn));`, + ` return (res,);`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_mul_signed_unsafe${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` let (local left_msb) = bitwise_and(lhs, ${msb(width)});`, + ` let (local right_msb) = bitwise_and(rhs, ${msb(width)});`, + ` let (res) = warp_mul_unsafe${width}(lhs, rhs);`, + ` let not_neg = (left_msb + right_msb) * (${bound(width)} - left_msb - right_msb);`, + ` if (not_neg == ${msb(width)}){`, + ` let (res) = warp_negate${width}(res);`, + ` return (res,);`, + ` }else{`, + ` return (res,);`, + ` }`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function functionaliseMul(node: BinaryOperation, unsafe: boolean, ast: AST): void { + IntxIntFunction(node, 'mul', 'always', true, unsafe, ast); +} diff --git a/src/warplib/implementations/maths/neq.ts b/src/warplib/implementations/maths/neq.ts new file mode 100644 index 000000000..0d62da748 --- /dev/null +++ b/src/warplib/implementations/maths/neq.ts @@ -0,0 +1,7 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { Comparison } from '../../utils'; + +export function functionaliseNeq(node: BinaryOperation, ast: AST): void { + Comparison(node, 'neq', 'only256', false, ast); +} diff --git a/src/warplib/implementations/maths/or.ts b/src/warplib/implementations/maths/or.ts new file mode 100644 index 000000000..b8e92e419 --- /dev/null +++ b/src/warplib/implementations/maths/or.ts @@ -0,0 +1,7 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { BoolxBoolFunction } from '../../utils'; + +export function functionaliseOr(node: BinaryOperation, ast: AST): void { + BoolxBoolFunction(node, 'or', ast); +} diff --git a/src/warplib/implementations/maths/shl.ts b/src/warplib/implementations/maths/shl.ts new file mode 100644 index 000000000..7ce99e90a --- /dev/null +++ b/src/warplib/implementations/maths/shl.ts @@ -0,0 +1,118 @@ +import assert from 'assert'; +import { + BinaryOperation, + FixedBytesType, + FunctionCall, + FunctionCallKind, + Identifier, + IntType, +} from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { printNode, printTypeNode } from '../../../utils/astPrinter'; +import { WARPLIB_MATHS } from '../../../utils/importPaths'; +import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; +import { typeNameFromTypeNode } from '../../../utils/utils'; +import { forAllWidths, getIntOrFixedByteBitWidth, WarplibFunctionInfo } from '../../utils'; + +// rhs is always unsigned, and signed and unsigned shl are the same +export function shl(): WarplibFunctionInfo { + //Need to provide an implementation with 256bit rhs and <256bit lhs + return { + fileName: 'shl', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math import split_felt', + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_shl', + 'from warplib.maths.pow2 import pow2', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_shl256{range_check_ptr}(lhs : Uint256, rhs : felt) -> (result : Uint256){', + ' let (high, low) = split_felt(rhs);', + ' let (res) = uint256_shl(lhs, Uint256(low, high));', + ' return (res,);', + '}', + 'func warp_shl256_256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (result : Uint256){', + ' let (res) = uint256_shl(lhs, rhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_shl${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : felt) -> (res : felt){`, + ` // width <= rhs (shift amount) means result will be 0`, + ` let large_shift = is_le_felt(${width}, rhs);`, + ` if (large_shift == 1){`, + ` return (0,);`, + ` }else{`, + ` let preserved_width = ${width} - rhs;`, + ` let (preserved_bound) = pow2(preserved_width);`, + ` let (lhs_truncated) = bitwise_and(lhs, preserved_bound - 1);`, + ` let (multiplier) = pow2(rhs);`, + ` return (lhs_truncated * multiplier,);`, + ` }`, + `}`, + `func warp_shl${width}_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : Uint256) -> (res : felt){`, + ` if (rhs.high == 0){`, + ` let (res) = warp_shl${width}(lhs, rhs.low);`, + ` return (res,);`, + ` }else{`, + ` return (0,);`, + ` }`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function functionaliseShl(node: BinaryOperation, ast: AST): void { + const lhsType = safeGetNodeType(node.vLeftExpression, ast.inference); + const rhsType = safeGetNodeType(node.vRightExpression, ast.inference); + const retType = safeGetNodeType(node, ast.inference); + + assert( + lhsType instanceof IntType || lhsType instanceof FixedBytesType, + `lhs of << ${printNode(node)} non-int type ${printTypeNode(lhsType)}`, + ); + assert( + rhsType instanceof IntType, + `rhs of << ${printNode(node)} non-int type ${printTypeNode(rhsType)}`, + ); + + const lhsWidth = getIntOrFixedByteBitWidth(lhsType); + + const fullName = `warp_shl${lhsWidth}${rhsType.nBits === 256 ? '_256' : ''}`; + + const importedFunc = ast.registerImport( + node, + [...WARPLIB_MATHS, 'shl'], + fullName, + [ + ['lhs', typeNameFromTypeNode(lhsType, ast)], + ['rhs', typeNameFromTypeNode(rhsType, ast)], + ], + [['res', typeNameFromTypeNode(retType, ast)]], + ); + const call = new FunctionCall( + ast.reserveId(), + node.src, + node.typeString, + FunctionCallKind.FunctionCall, + new Identifier( + ast.reserveId(), + '', + `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, + fullName, + importedFunc.id, + ), + [node.vLeftExpression, node.vRightExpression], + ); + + ast.replaceNode(node, call); +} diff --git a/src/warplib/implementations/maths/shr.ts b/src/warplib/implementations/maths/shr.ts new file mode 100644 index 000000000..ea99b2503 --- /dev/null +++ b/src/warplib/implementations/maths/shr.ts @@ -0,0 +1,269 @@ +import assert from 'assert'; +import { + BinaryOperation, + FixedBytesType, + FunctionCall, + FunctionCallKind, + Identifier, + IntType, +} from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { printNode, printTypeNode } from '../../../utils/astPrinter'; +import { WARPLIB_MATHS } from '../../../utils/importPaths'; +import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; +import { mapRange, typeNameFromTypeNode } from '../../../utils/utils'; +import { + forAllWidths, + bound, + msb, + mask, + getIntOrFixedByteBitWidth, + WarplibFunctionInfo, +} from '../../utils'; + +export function shr(): WarplibFunctionInfo { + return { + fileName: 'shr', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and, bitwise_not', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math_cmp import is_le, is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_and', + 'from warplib.maths.pow2 import pow2', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_shr256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (`, + ` result : Uint256){`, + ` let le_127 = is_le(rhs, 127);`, + ` if (le_127 == 1){`, + ` // (h', l') := (h, l) >> rhs`, + ` // p := 2^rhs`, + ` // l' = ((h & (p-1)) << (128 - rhs)) + ((l&~(p-1)) >> rhs)`, + ` // = ((h & (p-1)) << 128 >> rhs) + ((l&~(p-1)) >> rhs)`, + ` // = (h & (p-1)) * 2^128 / p + (l&~(p-1)) / p`, + ` // = (h & (p-1) * 2^128 + l&~(p-1)) / p`, + ` // h' = h >> rhs = (h - h&(p-1)) / p`, + ` let (p) = pow2(rhs);`, + ` let (low_mask) = bitwise_not(p - 1);`, + ` let (low_part) = bitwise_and(lhs.low, low_mask);`, + ` let (high_part) = bitwise_and(lhs.high, p - 1);`, + ` return (`, + ` Uint256(low=(low_part + ${bound( + 128, + )} * high_part) / p, high=(lhs.high - high_part) / p),);`, + ` }`, + ` let le_255 = is_le(rhs, 255);`, + ` if (le_255 == 1){`, + ` let (p) = pow2(rhs - 128);`, + ` let (mask) = bitwise_not(p - 1);`, + ` let (res) = bitwise_and(lhs.high, mask);`, + ` return (Uint256(res / p, 0),);`, + ` }`, + ` return (Uint256(0, 0),);`, + `}`, + `func warp_shr256_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (`, + ` result : Uint256){`, + ` if (rhs.high != 0){`, + ` return (Uint256(0, 0),);`, + ` }`, + ` let le_127 = is_le(rhs.low, 127);`, + ` if (le_127 == 1){`, + ` // (h', l') := (h, l) >> rhs`, + ` // p := 2^rhs`, + ` // l' = ((h & (p-1)) << (128 - rhs)) + ((l&~(p-1)) >> rhs)`, + ` // = ((h & (p-1)) << 128 >> rhs) + ((l&~(p-1)) >> rhs)`, + ` // = (h & (p-1)) * 2^128 / p + (l&~(p-1)) / p`, + ` // = (h & (p-1) * 2^128 + l&~(p-1)) / p`, + ` // h' = h >> rhs = (h - h&(p-1)) / p`, + ` let (p) = pow2(rhs.low);`, + ` let (low_mask) = bitwise_not(p - 1);`, + ` let (low_part) = bitwise_and(lhs.low, low_mask);`, + ` let (high_part) = bitwise_and(lhs.high, p - 1);`, + ` return (`, + ` Uint256(low=(low_part + ${bound( + 128, + )} * high_part) / p, high=(lhs.high - high_part) / p),);`, + ` }`, + ` let le_255 = is_le(rhs.low, 255);`, + ` if (le_255 == 1){`, + ` let (p) = pow2(rhs.low - 128);`, + ` let (mask) = bitwise_not(p - 1);`, + ` let (res) = bitwise_and(lhs.high, mask);`, + ` return (Uint256(res / p, 0),);`, + ` }`, + ` return (Uint256(0, 0),);`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_shr${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : felt) -> (res : felt){`, + ` let large_shift = is_le_felt(${width}, rhs);`, + ` if (large_shift == 1){`, + ` return (0,);`, + ` }else{`, + ` let preserved_width = ${width} - rhs;`, + ` let (preserved_bound) = pow2(preserved_width);`, + ` let mask = preserved_bound - 1;`, + ` let (divisor) = pow2(rhs);`, + ` let shifted_mask = mask * divisor;`, + ` let (lhs_truncated) = bitwise_and(lhs, shifted_mask);`, + ` return (lhs_truncated / divisor , );`, + ` }`, + `}`, + `func warp_shr${width}_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : Uint256) -> (res : felt){`, + ` if (rhs.high == 0){`, + ` let (res,) = warp_shr${width}(lhs, rhs.low);`, + ` return (res,);`, + ` }else{`, + ` return (0,);`, + ` }`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function shr_signed(): WarplibFunctionInfo { + return { + fileName: 'shr_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.math_cmp import is_le, is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_and', + 'from warplib.maths.pow2 import pow2', + `from warplib.maths.shr import ${mapRange(32, (n) => `warp_shr${8 * n + 8}`).join(', ')}`, + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_shr_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (res : Uint256){`, + ` alloc_locals;`, + ` let (local lhs_msb) = bitwise_and(lhs.high, ${msb(128)});`, + ` let (logical_shift) = warp_shr256(lhs, rhs);`, + ` if (lhs_msb == 0){`, + ` return (logical_shift,);`, + ` }else{`, + ` let large_shift = is_le(${width}, rhs);`, + ` if (large_shift == 1){`, + ` return (Uint256(${mask(128)}, ${mask(128)}),);`, + ` }else{`, + ` let crosses_boundary = is_le(128, rhs);`, + ` if (crosses_boundary == 1){`, + ` let (bound) = pow2(rhs-128);`, + ` let ones = bound - 1;`, + ` let (shift) = pow2(256-rhs);`, + ` return (Uint256(logical_shift.low+ones*shift, ${mask(128)}),);`, + ` }else{`, + ` let (bound) = pow2(rhs);`, + ` let ones = bound - 1;`, + ` let (shift) = pow2(128-rhs);`, + ` return (Uint256(logical_shift.low, logical_shift.high+ones*shift),);`, + ` }`, + ` }`, + ` }`, + `}`, + `func warp_shr_signed256_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, + ` if (rhs.high == 0){`, + ` let (res) = warp_shr_signed256(lhs, rhs.low);`, + ` return (res,);`, + ` }else{`, + ` let (res) = warp_shr_signed256(lhs, 256);`, + ` return (res,);`, + ` }`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_shr_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : felt) -> (res : felt){`, + ` alloc_locals;`, + ` let (local lhs_msb) = bitwise_and(lhs, ${msb(width)});`, + ` local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr;`, + ` if (lhs_msb == 0){`, + ` let (res) = warp_shr${width}(lhs, rhs);`, + ` return (res,);`, + ` }else{`, + ` let large_shift = is_le_felt(${width}, rhs);`, + ` if (large_shift == 1){`, + ` return (${mask(width)},);`, + ` }else{`, + ` let (shifted) = warp_shr${width}(lhs, rhs);`, + ` let (sign_extend_bound) = pow2(rhs);`, + ` let sign_extend_value = sign_extend_bound - 1;`, + ` let (sign_extend_multiplier) = pow2(${width} - rhs);`, + ` return (shifted + sign_extend_value * sign_extend_multiplier,);`, + ` }`, + ` }`, + `}`, + `func warp_shr_signed${width}_256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt, rhs : Uint256) -> (res : felt){`, + ` if (rhs.high == 0){`, + ` let (res) = warp_shr${width}(lhs, rhs.low);`, + ` return (res,);`, + ` }else{`, + ` let (res) = warp_shr${width}(lhs, ${width});`, + ` return (res,);`, + ` }`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function functionaliseShr(node: BinaryOperation, ast: AST): void { + const lhsType = safeGetNodeType(node.vLeftExpression, ast.inference); + const rhsType = safeGetNodeType(node.vRightExpression, ast.inference); + const retType = safeGetNodeType(node, ast.inference); + + assert( + lhsType instanceof IntType || lhsType instanceof FixedBytesType, + `lhs of >> ${printNode(node)} non-int type ${printTypeNode(lhsType)}`, + ); + assert( + rhsType instanceof IntType, + `rhs of >> ${printNode(node)} non-int type ${printTypeNode(rhsType)}`, + ); + + const lhsWidth = getIntOrFixedByteBitWidth(lhsType); + const signed = lhsType instanceof IntType && lhsType.signed; + + const fullName = `warp_shr${signed ? '_signed' : ''}${lhsWidth}${ + rhsType.nBits === 256 ? '_256' : '' + }`; + + const importName = [...WARPLIB_MATHS, `shr${signed ? '_signed' : ''}`]; + + const importedFunc = ast.registerImport( + node, + importName, + fullName, + [ + ['lhs', typeNameFromTypeNode(lhsType, ast)], + ['rhs', typeNameFromTypeNode(rhsType, ast)], + ], + [['res', typeNameFromTypeNode(retType, ast)]], + ); + const call = new FunctionCall( + ast.reserveId(), + node.src, + node.typeString, + FunctionCallKind.FunctionCall, + new Identifier( + ast.reserveId(), + '', + `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, + fullName, + importedFunc.id, + ), + [node.vLeftExpression, node.vRightExpression], + ); + + ast.replaceNode(node, call); +} diff --git a/src/warplib/implementations/maths/sub.ts b/src/warplib/implementations/maths/sub.ts new file mode 100644 index 000000000..b6174c393 --- /dev/null +++ b/src/warplib/implementations/maths/sub.ts @@ -0,0 +1,157 @@ +import assert from 'assert'; +import { BinaryOperation, IntType } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { printTypeNode } from '../../../utils/astPrinter'; +import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; +import { + forAllWidths, + bound, + mask, + msb, + msbAndNext, + IntxIntFunction, + WarplibFunctionInfo, +} from '../../utils'; + +export function sub_unsafe(): WarplibFunctionInfo { + return { + fileName: 'sub_unsafe', + imports: ['use integer::u256_overflow_sub;'], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `fn warp_sub_unsafe256(lhs : u256, rhs : u256) -> u256 {`, + ` let (value, _) = u256_overflow_sub(lhs, rhs);`, + ` return value;`, + `}`, + ].join('\n'); + } else { + return [ + // TODO: Use bitwise '&' to take just the width-bits + `fn warp_sub_unsafe${width}(lhs : felt252, rhs : felt252) -> felt252 {`, + ` return lhs - rhs;`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function sub_signed(): WarplibFunctionInfo { + return { + fileName: 'sub_signed', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_add, uint256_signed_le, uint256_sub, uint256_not', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_sub_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (`, + ` res : Uint256){`, + ` // First sign extend both operands`, + ` let (left_msb : felt252) = bitwise_and(lhs.high, ${msb(128)});`, + ` let (right_msb : felt252) = bitwise_and(rhs.high, ${msb(128)});`, + ` let left_overflow : felt252 = left_msb / ${msb(128)};`, + ` let right_overflow : felt252 = right_msb / ${msb(128)};`, + ``, + ` // Now safely negate the rhs and add (l - r = l + (-r))`, + ` let (right_flipped : Uint256) = uint256_not(rhs);`, + ` let (right_neg, overflow) = uint256_add(right_flipped, Uint256(1,0));`, + ` let right_overflow_neg = overflow + 1 - right_overflow;`, + ` let (res, res_base_overflow) = uint256_add(lhs, right_neg);`, + ` let res_overflow = res_base_overflow + left_overflow + right_overflow_neg;`, + ``, + ` // Check if the result fits in the correct width`, + ` let (res_msb : felt252) = bitwise_and(res.high, ${msb(128)});`, + ` let (res_overflow_lsb : felt252) = bitwise_and(res_overflow, 1);`, + ` assert res_overflow_lsb * ${msb(128)} = res_msb;`, + ``, + ` // Narrow and return`, + ` return (res,);`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_sub_signed${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt252, rhs : felt252) -> (`, + ` res : felt252){`, + ` // First sign extend both operands`, + ` let (left_msb : felt252) = bitwise_and(lhs, ${msb(width)});`, + ` let (right_msb : felt252) = bitwise_and(rhs, ${msb(width)});`, + ` let left_safe : felt252 = lhs + 2 * left_msb;`, + ` let right_safe : felt252 = rhs + 2 * right_msb;`, + ``, + ` // Now safely negate the rhs and add (l - r = l + (-r))`, + ` let right_neg : felt252 = ${bound(width + 1)} - right_safe;`, + ` let extended_res : felt252 = left_safe + right_neg;`, + ``, + ` // Check if the result fits in the correct width`, + ` let (overflowBits) = bitwise_and(extended_res, ${msbAndNext(width)});`, + ` assert overflowBits * (overflowBits - ${msbAndNext(width)}) = 0;`, + ``, + ` // Narrow and return`, + ` let (res) = bitwise_and(extended_res, ${mask(width)});`, + ` return (res,);`, + `}`, + ].join('\n'); + } + }), + }; +} + +export function sub_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'sub_signed_unsafe', + imports: [ + 'from starkware.cairo.common.bitwise import bitwise_and', + 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', + 'from starkware.cairo.common.uint256 import Uint256, uint256_sub', + ], + functions: forAllWidths((width) => { + if (width === 256) { + return [ + 'func warp_sub_signed_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', + ' let (res) = uint256_sub(lhs, rhs);', + ' return (res,);', + '}', + ].join('\n'); + } else { + return [ + `func warp_sub_signed_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(`, + ` lhs : felt252, rhs : felt252) -> (res : felt252){`, + ` // First sign extend both operands`, + ` let (left_msb : felt252) = bitwise_and(lhs, ${msb(width)});`, + ` let (right_msb : felt252) = bitwise_and(rhs, ${msb(width)});`, + ` let left_safe : felt252 = lhs + 2 * left_msb;`, + ` let right_safe : felt252 = rhs + 2 * right_msb;`, + ``, + ` // Now safely negate the rhs and add (l - r = l + (-r))`, + ` let right_neg : felt252 = ${bound(width + 1)} - right_safe;`, + ` let extended_res : felt252 = left_safe + right_neg;`, + ``, + ` // Narrow and return`, + ` let (res) = bitwise_and(extended_res, ${mask(width)});`, + ` return (res,);`, + `}`, + ].join('\n'); + } + }), + }; +} + +//func warp_sub{range_check_ptr}(lhs : felt252, rhs : felt252) -> (res : felt252): +//func warp_sub256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256): + +export function functionaliseSub(node: BinaryOperation, unsafe: boolean, ast: AST): void { + const typeNode = safeGetNodeType(node, ast.inference); + assert( + typeNode instanceof IntType, + `Expected IntType for subtraction, got ${printTypeNode(typeNode)}`, + ); + if (unsafe) { + IntxIntFunction(node, 'sub', 'always', true, unsafe, ast); + } else { + IntxIntFunction(node, 'sub', 'signedOrWide', true, unsafe, ast); + } +} diff --git a/src/warplib/implementations/maths/xor.ts b/src/warplib/implementations/maths/xor.ts new file mode 100644 index 000000000..5097ad341 --- /dev/null +++ b/src/warplib/implementations/maths/xor.ts @@ -0,0 +1,7 @@ +import { BinaryOperation } from 'solc-typed-ast'; +import { AST } from '../../../ast/ast'; +import { IntxIntFunction } from '../../utils'; + +export function functionaliseXor(node: BinaryOperation, ast: AST): void { + IntxIntFunction(node, 'xor', 'only256', false, false, ast); +} From fbdb8a848fef971c80a4aaaa3a2d0fb6015131cd Mon Sep 17 00:00:00 2001 From: esdras-santos Date: Mon, 10 Apr 2023 11:36:56 -0300 Subject: [PATCH 6/6] lint fixed --- src/cairoWriter/writers/cairoContractWriter.ts | 2 +- src/freeStructWritter.ts | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/cairoWriter/writers/cairoContractWriter.ts b/src/cairoWriter/writers/cairoContractWriter.ts index d6bc781c3..00d72d475 100644 --- a/src/cairoWriter/writers/cairoContractWriter.ts +++ b/src/cairoWriter/writers/cairoContractWriter.ts @@ -111,7 +111,7 @@ export class CairoContractWriter extends CairoASTNodeWriter { }) .join('\n'); - const freeStructs = getStructs(sourceUnit, this.ast); + const freeStructs = getStructs(sourceUnit); const structs = [...freeStructs, ...sourceUnit.vStructs, ...(node?.vStructs || [])] .map((v) => writer.write(v)) .join('\n\n'); diff --git a/src/freeStructWritter.ts b/src/freeStructWritter.ts index 2e394db74..8d48c7887 100644 --- a/src/freeStructWritter.ts +++ b/src/freeStructWritter.ts @@ -5,7 +5,6 @@ import { StructDefinition, UserDefinedTypeName, } from 'solc-typed-ast'; -import { AST } from './ast/ast'; /* Library calls in solidity are delegate calls @@ -17,7 +16,7 @@ import { AST } from './ast/ast'; function which do that. */ -export function getStructs(node: SourceUnit, ast: AST): StructDefinition[] { +export function getStructs(node: SourceUnit): StructDefinition[] { const externalStructs = getDefinitionsToInline(node, node, new Set()); return Array.from(externalStructs.values()); }