Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generate nested functions for expression lambdas #1062

Merged
merged 3 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ import {
SdsClassMember,
SdsDeclaration,
SdsExpression,
SdsExpressionLambda,
SdsLambda,
SdsModule,
SdsParameter,
SdsParameterList,
Expand Down Expand Up @@ -100,7 +102,7 @@ import { NamedTupleType } from '../typing/model.js';
import { getOutermostContainerOfType } from '../helpers/astUtils.js';

export const CODEGEN_PREFIX = '__gen_';
const BLOCK_LAMBDA_PREFIX = `${CODEGEN_PREFIX}block_lambda_`;
const LAMBDA_PREFIX = `${CODEGEN_PREFIX}lambda_`;
const BLOCK_LAMBDA_RESULT_PREFIX = `${CODEGEN_PREFIX}block_lambda_result_`;
const YIELD_PREFIX = `${CODEGEN_PREFIX}yield_`;

Expand Down Expand Up @@ -593,23 +595,31 @@ export class SafeDsPythonGenerator {
frame: GenerationInfoFrame,
generateLambda: boolean,
): CompositeGeneratorNode {
const blockLambdaCode: CompositeGeneratorNode[] = [];
const result: CompositeGeneratorNode[] = [];
if (isSdsAssignment(statement)) {
if (statement.expression) {
for (const lambda of AstUtils.streamAllContents(statement.expression).filter(isSdsBlockLambda)) {
blockLambdaCode.push(this.generateBlockLambda(lambda, frame));
for (const node of AstUtils.streamAllContents(statement.expression)) {
if (isSdsBlockLambda(node)) {
result.push(this.generateBlockLambda(node, frame));
} else if (isSdsExpressionLambda(node)) {
result.push(this.generateExpressionLambda(node, frame));
}
}
}
blockLambdaCode.push(this.generateAssignment(statement, frame, generateLambda));
return joinTracedToNode(statement)(blockLambdaCode, (stmt) => stmt, {
result.push(this.generateAssignment(statement, frame, generateLambda));
return joinTracedToNode(statement)(result, (stmt) => stmt, {
separator: NL,
})!;
} else if (isSdsExpressionStatement(statement)) {
for (const lambda of AstUtils.streamAllContents(statement.expression).filter(isSdsBlockLambda)) {
blockLambdaCode.push(this.generateBlockLambda(lambda, frame));
for (const node of AstUtils.streamAllContents(statement.expression)) {
if (isSdsBlockLambda(node)) {
result.push(this.generateBlockLambda(node, frame));
} else if (isSdsExpressionLambda(node)) {
result.push(this.generateExpressionLambda(node, frame));
}
}
blockLambdaCode.push(this.generateExpression(statement.expression, frame));
return joinTracedToNode(statement)(blockLambdaCode, (stmt) => stmt, {
result.push(this.generateExpression(statement.expression, frame));
return joinTracedToNode(statement)(result, (stmt) => stmt, {
separator: NL,
})!;
}
Expand Down Expand Up @@ -701,7 +711,7 @@ export class SafeDsPythonGenerator {
)}`,
);
}
return expandTracedToNode(blockLambda)`def ${frame.getUniqueLambdaBlockName(
return expandTracedToNode(blockLambda)`def ${frame.getUniqueLambdaName(
blockLambda,
)}(${this.generateParameters(blockLambda.parameterList, frame)}):`
.appendNewLine()
Expand All @@ -711,6 +721,17 @@ export class SafeDsPythonGenerator {
});
}

private generateExpressionLambda(node: SdsExpressionLambda, frame: GenerationInfoFrame): CompositeGeneratorNode {
const name = frame.getUniqueLambdaName(node);
const parameters = this.generateParameters(node.parameterList, frame);
const result = this.generateExpression(node.result, frame);

return expandTracedToNode(node)`
def ${name}(${parameters}):
return ${result}
`;
}

private generateExpression(expression: SdsExpression, frame: GenerationInfoFrame): CompositeGeneratorNode {
if (isSdsTemplateStringPart(expression)) {
if (isSdsTemplateStringStart(expression)) {
Expand Down Expand Up @@ -765,7 +786,7 @@ export class SafeDsPythonGenerator {
{ separator: ', ' },
)}]`;
} else if (isSdsBlockLambda(expression)) {
return traceToNode(expression)(frame.getUniqueLambdaBlockName(expression));
return traceToNode(expression)(frame.getUniqueLambdaName(expression));
} else if (isSdsCall(expression)) {
const callable = this.nodeMapper.callToCallable(expression);
const receiver = this.generateExpression(expression.receiver, frame);
Expand Down Expand Up @@ -807,10 +828,7 @@ export class SafeDsPythonGenerator {
return call;
}
} else if (isSdsExpressionLambda(expression)) {
return expandTracedToNode(expression)`lambda ${this.generateParameters(
expression.parameterList,
frame,
)}: ${this.generateExpression(expression.result, frame)}`;
return traceToNode(expression)(frame.getUniqueLambdaName(expression));
} else if (isSdsInfixOperation(expression)) {
const leftOperand = this.generateExpression(expression.leftOperand, frame);
const rightOperand = this.generateExpression(expression.rightOperand, frame);
Expand Down Expand Up @@ -1260,7 +1278,7 @@ interface ImportData {
}

class GenerationInfoFrame {
private readonly blockLambdaManager: IdManager<SdsBlockLambda>;
private readonly lambdaManager: IdManager<SdsLambda>;
private readonly importSet: Map<String, ImportData>;
private readonly utilitySet: Set<UtilityFunction>;
private readonly typeVariableSet: Set<string>;
Expand All @@ -1276,7 +1294,7 @@ class GenerationInfoFrame {
targetPlaceholder: string | undefined = undefined,
disableRunnerIntegration: boolean = false,
) {
this.blockLambdaManager = new IdManager<SdsBlockLambda>();
this.lambdaManager = new IdManager();
this.importSet = importSet;
this.utilitySet = utilitySet;
this.typeVariableSet = typeVariableSet;
Expand Down Expand Up @@ -1312,8 +1330,8 @@ class GenerationInfoFrame {
}
}

getUniqueLambdaBlockName(lambda: SdsBlockLambda): string {
return `${BLOCK_LAMBDA_PREFIX}${this.blockLambdaManager.assignId(lambda)}`;
getUniqueLambdaName(lambda: SdsLambda): string {
return `${LAMBDA_PREFIX}${this.lambdaManager.assignId(lambda)}`;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# Segments ---------------------------------------------------------------------

def test(param1, param_2, param_3=0):
f1(lambda a, b, c=0: 1)
def __gen_block_lambda_0(a, b, c=0):
def __gen_lambda_0(a, b, c=0):
return 1
f1(__gen_lambda_0)
def __gen_lambda_1(a, b, c=0):
pass
f2(__gen_block_lambda_0)
f2(__gen_lambda_1)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def f2(l):
# Pipelines --------------------------------------------------------------------

def test():
def __gen_block_lambda_0(a, b):
def __gen_lambda_0(a, b):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
f1(__gen_block_lambda_0)
def __gen_block_lambda_1(a, b):
f1(__gen_lambda_0)
def __gen_lambda_1(a, b):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_e = g()
return __gen_block_lambda_result_d, __gen_block_lambda_result_e
f2(__gen_block_lambda_1)
f2(__gen_lambda_1)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
# Pipelines --------------------------------------------------------------------

def test():
def __gen_block_lambda_0(a, b=2):
def __gen_lambda_0(a, b=2):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
f1(__gen_block_lambda_0)
def __gen_block_lambda_1(a, b):
f1(__gen_lambda_0)
def __gen_lambda_1(a, b):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
f1(__gen_block_lambda_1)
def __gen_block_lambda_2():
f1(__gen_lambda_1)
def __gen_lambda_2():
pass
f2(__gen_block_lambda_2)
def __gen_block_lambda_3(a, b=2):
f2(__gen_lambda_2)
def __gen_lambda_3(a, b=2):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
g2(f3(__gen_block_lambda_3))
def __gen_block_lambda_4(a, b=2):
g2(f3(__gen_lambda_3))
def __gen_lambda_4(a, b=2):
__gen_block_lambda_result_d = g()
return __gen_block_lambda_result_d
c = f3(__gen_block_lambda_4)
c = f3(__gen_lambda_4)
g2(c)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@
# Pipelines --------------------------------------------------------------------

def test():
f(lambda a, b=2: 1)
f(lambda a, b: 1)
def __gen_lambda_0(a, b=2):
return 1
f(__gen_lambda_0)
def __gen_lambda_1(a, b):
return 1
f(__gen_lambda_1)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def testPipeline():
impureFileWrite2 = iFileWrite()
pureValueForImpure2 = noPartialEvalInt(2)
pureValueForImpure3 = 3
def __gen_block_lambda_0():
def __gen_lambda_0():
i1(1)
__gen_block_lambda_result_r = 1
return __gen_block_lambda_result_r
fp(__gen_block_lambda_0)
fp(__gen_lambda_0)
i1(1)
impureA1 = i1(pureValueForImpure2)
impureA2 = i1(noPartialEvalInt(3))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ def testPipeline():
lStrMulti = 'multi\nline'
boolean1 = True
value1 = g(True, -1.0, 1, None, 'multi\nline')
def __gen_block_lambda_0():
def __gen_lambda_0():
i = 1
i2 = 3
j = 6
j2 = 4
__gen_block_lambda_result_z = 7
return __gen_block_lambda_result_z
o = (f(__gen_block_lambda_0)) + (f(lambda : 2))
def __gen_lambda_1():
return 2
o = (f(__gen_lambda_0)) + (f(__gen_lambda_1))
mapKey = 'key'
mapValue = 'value'
mapResult = g2({'key': 'value'})
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,30 @@ def test():
__gen_null_safe_call(j, lambda: 'abc'.j(123))
__gen_null_safe_call(k, lambda: k(456, 1.23))
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.readFile", readFile, [], [safeds_runner.file_mtime('a.txt')]))
f(l(lambda a: segment_a(a)))
f(l(lambda a: (3) * (segment_a(a))))
f(l(lambda a: safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])))
f(l(lambda a: (3) * (safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])], []))))
def __gen_block_lambda_0(a):
def __gen_lambda_0(a):
return segment_a(a)
f(l(__gen_lambda_0))
def __gen_lambda_1(a):
return (3) * (segment_a(a))
f(l(__gen_lambda_1))
def __gen_lambda_2(a):
return safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])
f(l(__gen_lambda_2))
def __gen_lambda_3(a):
return (3) * (safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [(3) * (segment_a(a))], [])], []))
f(l(__gen_lambda_3))
def __gen_lambda_4(a):
__gen_block_lambda_result_result = segment_a(a)
return __gen_block_lambda_result_result
f(l(__gen_block_lambda_0))
def __gen_block_lambda_1(a):
f(l(__gen_lambda_4))
def __gen_lambda_5(a):
__gen_block_lambda_result_result = safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [segment_a(a)], [])
return __gen_block_lambda_result_result
f(l(__gen_block_lambda_1))
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [lambda a: (3) * (a)], []))
def __gen_block_lambda_2(a):
f(l(__gen_lambda_5))
def __gen_lambda_6(a):
return (3) * (a)
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [__gen_lambda_6], []))
def __gen_lambda_7(a):
__gen_block_lambda_result_result = (3) * (safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.m", m, [a], []))
return __gen_block_lambda_result_result
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [__gen_block_lambda_2], []))
f(safeds_runner.memoized_static_call("tests.generator.runnerIntegration.expressions.calls.main.l", l, [__gen_lambda_7], []))
Loading