Skip to content

Commit

Permalink
feat: generate nested functions for expression lambdas (#1062)
Browse files Browse the repository at this point in the history
### Summary of Changes

For expression lambdas, we now also generate nested functions instead of
Python lambdas. This is
1. more readable,
2. better for memoization, since we use the source code of a callable
for lookup and Python is not able to return *only* the source code of a
lambda.
  • Loading branch information
lars-reimann authored Apr 18, 2024
1 parent e45a4c9 commit f79fd61
Show file tree
Hide file tree
Showing 18 changed files with 104 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ import {
SdsClassMember,
SdsDeclaration,
SdsExpression,
SdsExpressionLambda,
SdsLambda,
SdsModule,
SdsParameter,
SdsParameterList,
Expand Down Expand Up @@ -100,7 +102,7 @@ import { NamedTupleType } from '../typing/model.js';
import { getOutermostContainerOfType } from '../helpers/astUtils.js';

export const CODEGEN_PREFIX = '__gen_';
const BLOCK_LAMBDA_PREFIX = `${CODEGEN_PREFIX}block_lambda_`;
const LAMBDA_PREFIX = `${CODEGEN_PREFIX}lambda_`;
const BLOCK_LAMBDA_RESULT_PREFIX = `${CODEGEN_PREFIX}block_lambda_result_`;
const YIELD_PREFIX = `${CODEGEN_PREFIX}yield_`;

Expand Down Expand Up @@ -593,23 +595,31 @@ export class SafeDsPythonGenerator {
frame: GenerationInfoFrame,
generateLambda: boolean,
): CompositeGeneratorNode {
const blockLambdaCode: CompositeGeneratorNode[] = [];
const result: CompositeGeneratorNode[] = [];
if (isSdsAssignment(statement)) {
if (statement.expression) {
for (const lambda of AstUtils.streamAllContents(statement.expression).filter(isSdsBlockLambda)) {
blockLambdaCode.push(this.generateBlockLambda(lambda, frame));
for (const node of AstUtils.streamAllContents(statement.expression)) {
if (isSdsBlockLambda(node)) {
result.push(this.generateBlockLambda(node, frame));
} else if (isSdsExpressionLambda(node)) {
result.push(this.generateExpressionLambda(node, frame));
}
}
}
blockLambdaCode.push(this.generateAssignment(statement, frame, generateLambda));
return joinTracedToNode(statement)(blockLambdaCode, (stmt) => stmt, {
result.push(this.generateAssignment(statement, frame, generateLambda));
return joinTracedToNode(statement)(result, (stmt) => stmt, {
separator: NL,
})!;
} else if (isSdsExpressionStatement(statement)) {
for (const lambda of AstUtils.streamAllContents(statement.expression).filter(isSdsBlockLambda)) {
blockLambdaCode.push(this.generateBlockLambda(lambda, frame));
for (const node of AstUtils.streamAllContents(statement.expression)) {
if (isSdsBlockLambda(node)) {
result.push(this.generateBlockLambda(node, frame));
} else if (isSdsExpressionLambda(node)) {
result.push(this.generateExpressionLambda(node, frame));
}
}
blockLambdaCode.push(this.generateExpression(statement.expression, frame));
return joinTracedToNode(statement)(blockLambdaCode, (stmt) => stmt, {
result.push(this.generateExpression(statement.expression, frame));
return joinTracedToNode(statement)(result, (stmt) => stmt, {
separator: NL,
})!;
}
Expand Down Expand Up @@ -701,7 +711,7 @@ export class SafeDsPythonGenerator {
)}`,
);
}
return expandTracedToNode(blockLambda)`def ${frame.getUniqueLambdaBlockName(
return expandTracedToNode(blockLambda)`def ${frame.getUniqueLambdaName(
blockLambda,
)}(${this.generateParameters(blockLambda.parameterList, frame)}):`
.appendNewLine()
Expand All @@ -711,6 +721,17 @@ export class SafeDsPythonGenerator {
});
}

private generateExpressionLambda(node: SdsExpressionLambda, frame: GenerationInfoFrame): CompositeGeneratorNode {
const name = frame.getUniqueLambdaName(node);
const parameters = this.generateParameters(node.parameterList, frame);
const result = this.generateExpression(node.result, frame);

return expandTracedToNode(node)`
def ${name}(${parameters}):
return ${result}
`;
}

private generateExpression(expression: SdsExpression, frame: GenerationInfoFrame): CompositeGeneratorNode {
if (isSdsTemplateStringPart(expression)) {
if (isSdsTemplateStringStart(expression)) {
Expand Down Expand Up @@ -765,7 +786,7 @@ export class SafeDsPythonGenerator {
{ separator: ', ' },
)}]`;
} else if (isSdsBlockLambda(expression)) {
return traceToNode(expression)(frame.getUniqueLambdaBlockName(expression));
return traceToNode(expression)(frame.getUniqueLambdaName(expression));
} else if (isSdsCall(expression)) {
const callable = this.nodeMapper.callToCallable(expression);
const receiver = this.generateExpression(expression.receiver, frame);
Expand Down Expand Up @@ -807,10 +828,7 @@ export class SafeDsPythonGenerator {
return call;
}
} else if (isSdsExpressionLambda(expression)) {
return expandTracedToNode(expression)`lambda ${this.generateParameters(
expression.parameterList,
frame,
)}: ${this.generateExpression(expression.result, frame)}`;
return traceToNode(expression)(frame.getUniqueLambdaName(expression));
} else if (isSdsInfixOperation(expression)) {
const leftOperand = this.generateExpression(expression.leftOperand, frame);
const rightOperand = this.generateExpression(expression.rightOperand, frame);
Expand Down Expand Up @@ -1260,7 +1278,7 @@ interface ImportData {
}

class GenerationInfoFrame {
private readonly blockLambdaManager: IdManager<SdsBlockLambda>;
private readonly lambdaManager: IdManager<SdsLambda>;
private readonly importSet: Map<String, ImportData>;
private readonly utilitySet: Set<UtilityFunction>;
private readonly typeVariableSet: Set<string>;
Expand All @@ -1276,7 +1294,7 @@ class GenerationInfoFrame {
targetPlaceholder: string | undefined = undefined,
disableRunnerIntegration: boolean = false,
) {
this.blockLambdaManager = new IdManager<SdsBlockLambda>();
this.lambdaManager = new IdManager();
this.importSet = importSet;
this.utilitySet = utilitySet;
this.typeVariableSet = typeVariableSet;
Expand Down Expand Up @@ -1312,8 +1330,8 @@ class GenerationInfoFrame {
}
}

getUniqueLambdaBlockName(lambda: SdsBlockLambda): string {
return `${BLOCK_LAMBDA_PREFIX}${this.blockLambdaManager.assignId(lambda)}`;
getUniqueLambdaName(lambda: SdsLambda): string {
return `${LAMBDA_PREFIX}${this.lambdaManager.assignId(lambda)}`;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# Segments ---------------------------------------------------------------------

def test(param1, param_2, param_3=0):
f1(lambda a, b, c=0: 1)
def __gen_block_lambda_0(a, b, c=0):
def __gen_lambda_0(a, b, c=0):
return 1
f1(__gen_lambda_0)
def __gen_lambda_1(a, b, c=0):
pass
f2(__gen_block_lambda_0)
f2(__gen_lambda_1)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def f2(l):
# Pipelines --------------------------------------------------------------------

def test():
def __gen_block_lambda_0(a, b):
def __gen_lambda_0(a, b):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
f1(__gen_block_lambda_0)
def __gen_block_lambda_1(a, b):
f1(__gen_lambda_0)
def __gen_lambda_1(a, b):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_e = g()
return __gen_block_lambda_result_d, __gen_block_lambda_result_e
f2(__gen_block_lambda_1)
f2(__gen_lambda_1)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
# Pipelines --------------------------------------------------------------------

def test():
def __gen_block_lambda_0(a, b=2):
def __gen_lambda_0(a, b=2):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
f1(__gen_block_lambda_0)
def __gen_block_lambda_1(a, b):
f1(__gen_lambda_0)
def __gen_lambda_1(a, b):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
f1(__gen_block_lambda_1)
def __gen_block_lambda_2():
f1(__gen_lambda_1)
def __gen_lambda_2():
pass
f2(__gen_block_lambda_2)
def __gen_block_lambda_3(a, b=2):
f2(__gen_lambda_2)
def __gen_lambda_3(a, b=2):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
g2(f3(__gen_block_lambda_3))
def __gen_block_lambda_4(a, b=2):
g2(f3(__gen_lambda_3))
def __gen_lambda_4(a, b=2):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
c = f3(__gen_block_lambda_4)
c = f3(__gen_lambda_4)
g2(c)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@
# Pipelines --------------------------------------------------------------------

def test():
f(lambda a, b=2: 1)
f(lambda a, b: 1)
def __gen_lambda_0(a, b=2):
return 1
f(__gen_lambda_0)
def __gen_lambda_1(a, b):
return 1
f(__gen_lambda_1)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def testPipeline():
impureFileWrite2 = iFileWrite()
pureValueForImpure2 = noPartialEvalInt(2)
pureValueForImpure3 = 3
def __gen_block_lambda_0():
def __gen_lambda_0():
i1(1)
__gen_block_lambda_result_r = 1
return __gen_block_lambda_result_r
fp(__gen_block_lambda_0)
fp(__gen_lambda_0)
i1(1)
impureA1 = i1(pureValueForImpure2)
impureA2 = i1(noPartialEvalInt(3))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ def testPipeline():
lStrMulti = 'multi\nline'
boolean1 = True
value1 = g(True, -1.0, 1, None, 'multi\nline')
def __gen_block_lambda_0():
def __gen_lambda_0():
i = 1
i2 = 3
j = 6
j2 = 4
__gen_block_lambda_result_z = 7
return __gen_block_lambda_result_z
o = (f(__gen_block_lambda_0)) + (f(lambda : 2))
def __gen_lambda_1():
return 2
o = (f(__gen_lambda_0)) + (f(__gen_lambda_1))
mapKey = 'key'
mapValue = 'value'
mapResult = g2({'key': 'value'})
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,30 @@ def test():
__gen_null_safe_call(j, lambda: 'abc'.j(123))
__gen_null_safe_call(k, lambda: k(456, 1.23))
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.readFile", readFile, [], [safeds_runner.file_mtime('a.txt')]))
f(l(lambda a: segment_a(a)))
f(l(lambda a: (3) * (segment_a(a))))
f(l(lambda a: safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])))
f(l(lambda a: (3) * (safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])], []))))
def __gen_block_lambda_0(a):
def __gen_lambda_0(a):
return segment_a(a)
f(l(__gen_lambda_0))
def __gen_lambda_1(a):
return (3) * (segment_a(a))
f(l(__gen_lambda_1))
def __gen_lambda_2(a):
return safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])
f(l(__gen_lambda_2))
def __gen_lambda_3(a):
return (3) * (safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])], []))
f(l(__gen_lambda_3))
def __gen_lambda_4(a):
__gen_block_lambda_result_result = segment_a(a)
return __gen_block_lambda_result_result
f(l(__gen_block_lambda_0))
def __gen_block_lambda_1(a):
f(l(__gen_lambda_4))
def __gen_lambda_5(a):
__gen_block_lambda_result_result = safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [segment_a(a)], [])
return __gen_block_lambda_result_result
f(l(__gen_block_lambda_1))
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [lambda a: (3) * (a)], []))
def __gen_block_lambda_2(a):
f(l(__gen_lambda_5))
def __gen_lambda_6(a):
return (3) * (a)
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [__gen_lambda_6], []))
def __gen_lambda_7(a):
__gen_block_lambda_result_result = (3) * (safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [a], []))
return __gen_block_lambda_result_result
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [__gen_block_lambda_2], []))
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [__gen_lambda_7], []))
Loading

0 comments on commit f79fd61

Please sign in to comment.