-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
kernel_prompt_template.py
151 lines (123 loc) · 7.05 KB
/
kernel_prompt_template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Copyright (c) Microsoft. All rights reserved.
import logging
from html import escape
from typing import TYPE_CHECKING, Any

from pydantic import PrivateAttr, field_validator

from semantic_kernel.exceptions import TemplateRenderException
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.prompt_template.const import KERNEL_TEMPLATE_FORMAT_NAME
from semantic_kernel.prompt_template.input_variable import InputVariable
from semantic_kernel.prompt_template.prompt_template_base import PromptTemplateBase
from semantic_kernel.template_engine.blocks.block import Block
from semantic_kernel.template_engine.blocks.code_block import CodeBlock
from semantic_kernel.template_engine.blocks.named_arg_block import NamedArgBlock
from semantic_kernel.template_engine.blocks.var_block import VarBlock
from semantic_kernel.template_engine.template_tokenizer import TemplateTokenizer

if TYPE_CHECKING:
    from semantic_kernel.kernel import Kernel
    from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig

# Module-level logger, named after this module so handlers/filters can target it.
logger: logging.Logger = logging.getLogger(name=__name__)

class KernelPromptTemplate(PromptTemplateBase):
    """Create a Kernel prompt template.

    The template text is tokenized once in ``model_post_init`` into text,
    variable (``{{$name}}``) and code (``{{plugin.function ...}}``) blocks;
    ``render`` reuses those blocks on every call.
    """

    # Blocks produced by tokenizing ``prompt_template_config.template``.
    _blocks: list[Block] = PrivateAttr(default_factory=list)

    @field_validator("prompt_template_config")
    @classmethod
    def validate_template_format(cls, v: "PromptTemplateConfig") -> "PromptTemplateConfig":
        """Validate that the config uses the kernel template format.

        Args:
            v: The prompt template config being validated.

        Returns:
            PromptTemplateConfig: ``v`` unchanged.

        Raises:
            ValueError: If ``v.template_format`` is not ``KERNEL_TEMPLATE_FORMAT_NAME``.
        """
        if v.template_format != KERNEL_TEMPLATE_FORMAT_NAME:
            # Use the constant so the message cannot drift from the check above.
            raise ValueError(
                f"Invalid prompt template format: {v.template_format}. Expected: {KERNEL_TEMPLATE_FORMAT_NAME}"
            )
        return v

    def model_post_init(self, _: Any) -> None:
        """Tokenize the template and register every variable it references.

        Variables already declared in ``prompt_template_config.input_variables``
        are kept as they are. Names are compared case-insensitively, so a
        discovered variable that matches a declared one (ignoring case) is
        not added again.
        """
        self._blocks = self.extract_blocks()

        # Lower-cased names already known; discovered variables with these names are skipped.
        seen = {iv.name.lower() for iv in self.prompt_template_config.input_variables}

        for block in self._blocks:
            if isinstance(block, VarBlock):
                # Standalone variable, e.g. "{{$a}}".
                self._add_if_missing(block.name, seen)
            elif isinstance(block, CodeBlock):
                for sub_block in block.tokens:
                    if isinstance(sub_block, VarBlock):
                        # Positional variable argument to a function call, e.g. "{{p.bar $b}}".
                        self._add_if_missing(sub_block.name, seen)
                    elif isinstance(sub_block, NamedArgBlock) and sub_block.variable:
                        # Named argument bound to a variable, e.g. "{{p.bar b=$b}}".
                        self._add_if_missing(sub_block.variable.name, seen)

    def extract_blocks(self) -> list[Block]:
        """Given the prompt template, extract all the blocks (text, variables, function calls).

        Returns:
            list[Block]: The tokenized blocks in template order, or an empty list
            when the template is empty/None.
        """
        template = self.prompt_template_config.template
        logger.debug("Extracting blocks from template: %s", template)
        if not template:
            return []
        return TemplateTokenizer.tokenize(template)

    def _add_if_missing(self, variable_name: str, seen: set[str]) -> None:
        """Append ``variable_name`` as an input variable unless it is already known.

        Args:
            variable_name: Variable name as written in the template; empty names are ignored.
            seen: Lower-cased names already registered. Updated in place.
        """
        if not variable_name:
            return
        key = variable_name.lower()  # case-insensitive de-duplication
        if key in seen:
            return
        seen.add(key)
        self.prompt_template_config.input_variables.append(InputVariable(name=variable_name))

    async def render(self, kernel: "Kernel", arguments: "KernelArguments | None" = None) -> str:
        """Render the prompt template.

        Using the prompt template, replace the variables with their values
        and execute the functions replacing their reference with the
        function result.

        Args:
            kernel ("Kernel"): The kernel to use for functions.
            arguments ("KernelArguments | None"): The arguments to use for rendering. (Default value = None)

        Returns:
            str: The prompt template ready to be used for an AI request
        """
        return await self.render_blocks(self._blocks, kernel, arguments)

    async def render_blocks(
        self, blocks: list[Block], kernel: "Kernel", arguments: "KernelArguments | None" = None
    ) -> str:
        """Given a list of blocks render each block and compose the final result.

        Text/variable blocks are rendered directly. Code blocks are awaited, and
        their output is HTML-escaped unless dangerous function output is allowed.
        Blocks that are neither kind contribute nothing.

        Args:
            blocks (list[Block]): Template blocks generated by ExtractBlocks
            kernel ("Kernel"): The kernel to use for functions
            arguments ("KernelArguments | None"): The arguments to use for rendering (Default value = None)

        Returns:
            str: The prompt template ready to be used for an AI request

        Raises:
            TemplateRenderException: If rendering any code block fails.
        """
        # Imported locally, as in the original (presumably to avoid an import cycle).
        from semantic_kernel.template_engine.protocols.code_renderer import CodeRenderer
        from semantic_kernel.template_engine.protocols.text_renderer import TextRenderer

        logger.debug("Rendering list of %d blocks", len(blocks))
        rendered_blocks: list[str] = []
        # `is None` instead of `or`: a caller-supplied KernelArguments holding no variables
        # can be falsy (it is dict-like), and `or` would silently swap it for a fresh
        # instance, losing any non-item state it carries (e.g. execution settings -- confirm).
        arguments = self._get_trusted_arguments(arguments if arguments is not None else KernelArguments())
        allow_unsafe_function_output = self._get_allow_dangerously_set_function_output()

        for block in blocks:
            if isinstance(block, TextRenderer):
                rendered_blocks.append(block.render(kernel, arguments))
            elif isinstance(block, CodeRenderer):
                try:
                    rendered = await block.render_code(kernel, arguments)
                except Exception as exc:
                    logger.error("Error rendering code block: %s", exc)
                    raise TemplateRenderException(f"Error rendering code block: {exc}") from exc
                # Function output is HTML-escaped unless the template explicitly allows it raw.
                rendered_blocks.append(rendered if allow_unsafe_function_output else escape(rendered))

        prompt = "".join(rendered_blocks)
        logger.debug("Rendered prompt: %s", prompt)
        return prompt

    @staticmethod
    def quick_render(template: str, arguments: dict[str, Any]) -> str:
        """Quick render a Kernel prompt template, only supports text and variable blocks.

        Args:
            template: The template to render
            arguments: The arguments to use for rendering

        Returns:
            str: The prompt template ready to be used for an AI request

        Raises:
            ValueError: If the template contains any code block.
        """
        from semantic_kernel import Kernel
        from semantic_kernel.template_engine.protocols.code_renderer import CodeRenderer

        blocks = TemplateTokenizer.tokenize(template)
        if any(isinstance(block, CodeRenderer) for block in blocks):
            raise ValueError("Quick render does not support code blocks.")
        # block.render() takes a kernel; a default Kernel() suffices since code blocks were rejected above.
        kernel = Kernel()
        return "".join(block.render(kernel, arguments) for block in blocks)  # type: ignore