Skip to content

Commit

Permalink
Merge pull request #460 from julep-ai/f/workflow-tests
Browse files Browse the repository at this point in the history
feat: Add test for evaluate step
  • Loading branch information
whiterabbit1983 authored Aug 19, 2024
2 parents 729a0aa + aead2ab commit 72da581
Show file tree
Hide file tree
Showing 197 changed files with 6,532 additions and 1,412 deletions.
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/demo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

from temporalio import activity

from ..env import testing
Expand All @@ -12,6 +14,6 @@ async def mock_demo_activity(a: int, b: int) -> int:
return a + b


demo_activity = activity.defn(name="demo_activity")(
demo_activity: Callable[[int, int], int] = activity.defn(name="demo_activity")(
demo_activity if not testing else mock_demo_activity
)
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from beartype import beartype
from temporalio import activity

from ..clients import cozo
from ..clients import embed as embedder
from ..clients.cozo import get_cozo_client
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload
Expand All @@ -28,7 +28,7 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
doc_id=payload.doc_id,
snippet_indices=indices,
embeddings=embeddings,
client=cozo_client or get_cozo_client(),
client=cozo_client or cozo.get_cozo_client(),
)


Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/activities/logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import TextIO

logger = logging.getLogger(__name__)
h = logging.StreamHandler()
fmt = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
logger: logging.Logger = logging.getLogger(__name__)
h: logging.StreamHandler[TextIO] = logging.StreamHandler()
fmt: logging.Formatter = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
h.setFormatter(fmt)
logger.addHandler(h)
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@


# TODO: remove stubs
def entries_summarization_query(*args, **kwargs):
def entries_summarization_query(*args, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


def get_toplevel_entries_query(*args, **kwargs):
def get_toplevel_entries_query(*args, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

from .evaluate_step import evaluate_step
from .if_else_step import if_else_step
from .log_step import log_step
from .prompt_step import prompt_step
from .raise_complete_async import raise_complete_async
from .return_step import return_step
from .switch_step import switch_step
from .tool_call_step import tool_call_step
from .transition_step import transition_step
from .wait_for_input_step import wait_for_input_step
from .yield_step import yield_step
27 changes: 16 additions & 11 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from typing import Any
import logging

from beartype import beartype
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import EvaluateStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


@beartype
async def evaluate_step(
context: StepContext[EvaluateStep],
) -> StepOutcome[dict[str, Any]]:
exprs = context.definition.arguments
output = simple_eval_dict(exprs, values=context.model_dump())
async def evaluate_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, EvaluateStep)

exprs = context.current_step.evaluate
output = simple_eval_dict(exprs, values=context.model_dump())

result = StepOutcome(output=output)
return result

return StepOutcome(output=output)
except BaseException as e:
logging.error(f"Error in evaluate_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported evaluate_step directly
Expand Down
42 changes: 29 additions & 13 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import (
IfElseWorkflowStep,
)
from ...autogen.openapi_model import IfElseWorkflowStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing


@activity.defn
@beartype
async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict:
raise NotImplementedError()
# context_data: dict = context.model_dump()
async def if_else_step(context: StepContext) -> StepOutcome:
# NOTE: This activity only decides which branch to take, so we just evaluate the `if` expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, IfElseWorkflowStep)

expr: str = context.current_step.if_
output = simple_eval(expr, names=context.model_dump())
output: bool = bool(output)

result = StepOutcome(output=output)
return result

# next_workflow = (
# context.definition.then
# if simple_eval(context.definition.if_, names=context_data)
# else context.definition.else_
# )
except BaseException as e:
logging.error(f"Error in if_else_step: {e}")
return StepOutcome(error=str(e))

# return {"goto_workflow": next_workflow}

# Note: This is here just for clarity. We could have just imported if_else_step directly
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = if_else_step

if_else_step = activity.defn(name="if_else_step")(
if_else_step if not testing else mock_if_else_step
)
37 changes: 37 additions & 0 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import LogStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing


@beartype
async def log_step(context: StepContext) -> StepOutcome:
    """Evaluate the current step's ``log`` expression and return its value.

    The expression is evaluated with ``simple_eval`` using
    ``context.model_dump()`` as the set of available names.

    Args:
        context: Step context whose ``current_step`` must be a ``LogStep``.

    Returns:
        ``StepOutcome(output=...)`` holding the evaluated value on success,
        or ``StepOutcome(error=...)`` holding the stringified exception when
        validation or evaluation raises.
    """
    # NOTE: This activity is only for logging, so we just evaluate the expression
    # Hence, it's a local activity and SHOULD NOT fail
    try:
        step = context.current_step

        # Explicit check instead of a bare `assert`: asserts vanish under -O and
        # `str(AssertionError())` is "", which would produce an empty error string.
        if not isinstance(step, LogStep):
            raise TypeError(f"Expected LogStep, got {type(step).__name__}")

        expr: str = step.log
        output = simple_eval(expr, names=context.model_dump())

        return StepOutcome(output=output)

    # Catch Exception, not BaseException, so task cancellation
    # (asyncio.CancelledError), KeyboardInterrupt and SystemExit still propagate.
    except Exception as e:
        logging.exception("Error in log_step: %s", e)
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported log_step directly
# They do the same thing, so we don't need to mock the log_step function
mock_log_step = log_step

# Register the function as the Temporal activity named "log_step". When `testing`
# is truthy the mock alias is selected, which is the very same function object.
log_step = activity.defn(name="log_step")(log_step if not testing else mock_log_step)
20 changes: 7 additions & 13 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,25 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
InputChatMLMessage,
PromptStep,
)
from ...autogen.openapi_model import InputChatMLMessage
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template


@activity.defn
@beartype
async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
context_data: dict = context.model_dump()

# Render template messages
prompt = (
[InputChatMLMessage(content=context.definition.prompt)]
if isinstance(context.definition.prompt, str)
else context.definition.prompt
[InputChatMLMessage(content=context.current_step.prompt)]
if isinstance(context.current_step.prompt, str)
else context.current_step.prompt
)

template_messages: list[InputChatMLMessage] = prompt
Expand All @@ -47,7 +41,7 @@ async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
for m in messages
]

settings: dict = context.definition.settings.model_dump()
settings: dict = context.current_step.settings.model_dump()
# Get settings and run llm
response = await litellm.acompletion(
messages=messages,
Expand Down
37 changes: 37 additions & 0 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import ReturnStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing


async def return_step(context: StepContext) -> StepOutcome:
    """Evaluate the current step's ``return_`` expressions and return the result.

    The expression mapping is passed to ``simple_eval_dict`` together with
    ``context.model_dump()`` as the values to evaluate against.

    Args:
        context: Step context whose ``current_step`` must be a ``ReturnStep``.

    Returns:
        ``StepOutcome(output=...)`` holding the evaluated mapping on success,
        or ``StepOutcome(error=...)`` holding the stringified exception when
        validation or evaluation raises.
    """
    # NOTE: This activity is only for returning immediately, so we just evaluate the expression
    # Hence, it's a local activity and SHOULD NOT fail
    try:
        step = context.current_step

        # Explicit check instead of a bare `assert`: asserts vanish under -O and
        # `str(AssertionError())` is "", which would produce an empty error string.
        if not isinstance(step, ReturnStep):
            raise TypeError(f"Expected ReturnStep, got {type(step).__name__}")

        exprs: dict[str, str] = step.return_
        output = simple_eval_dict(exprs, values=context.model_dump())

        return StepOutcome(output=output)

    # Catch Exception, not BaseException, so task cancellation
    # (asyncio.CancelledError), KeyboardInterrupt and SystemExit still propagate.
    except Exception as e:
        # The message previously named log_step (copy-paste slip).
        logging.exception("Error in return_step: %s", e)
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported return_step directly
# They do the same thing, so we don't need to mock the return_step function
mock_return_step = return_step

# Register the function as the Temporal activity named "return_step". When `testing`
# is truthy the mock alias is selected, which is the very same function object.
return_step = activity.defn(name="return_step")(
return_step if not testing else mock_return_step
)
48 changes: 48 additions & 0 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import SwitchStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing


@beartype
async def switch_step(context: StepContext) -> StepOutcome:
    """Find the first ``switch`` case whose expression evaluates truthy.

    Each case expression is evaluated in order with ``simple_eval`` using
    ``context.model_dump()`` as the set of available names; evaluation stops
    at the first truthy result.

    Args:
        context: Step context whose ``current_step`` must be a ``SwitchStep``.

    Returns:
        ``StepOutcome(output=i)`` where ``i`` is the zero-based index of the
        first truthy case, or ``-1`` when no case is truthy. If validation or
        evaluation raises, ``StepOutcome(error=...)`` holding the stringified
        exception is returned instead.
    """
    # NOTE: This activity only picks the matching case, so we just evaluate the expressions
    # Hence, it's a local activity and SHOULD NOT fail
    try:
        step = context.current_step

        # Explicit check instead of a bare `assert`: asserts vanish under -O and
        # `str(AssertionError())` is "", which would produce an empty error string.
        if not isinstance(step, SwitchStep):
            raise TypeError(f"Expected SwitchStep, got {type(step).__name__}")

        cases: list[str] = [c.case for c in step.switch]

        # The evaluation namespace does not change between cases; build it once.
        names = context.model_dump()

        # Assume that none of the cases evaluate to truthy
        matched_index: int = -1

        for i, case in enumerate(cases):
            if simple_eval(case, names=names):
                matched_index = i
                break

        return StepOutcome(output=matched_index)

    # Catch Exception, not BaseException, so task cancellation
    # (asyncio.CancelledError), KeyboardInterrupt and SystemExit still propagate.
    except Exception as e:
        logging.exception("Error in switch_step: %s", e)
        return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported switch_step directly
# They do the same thing, so we don't need to mock the switch_step function
mock_switch_step = switch_step

# Register the function as the Temporal activity named "switch_step". When `testing`
# is truthy the mock alias is selected, which is the very same function object.
switch_step = activity.defn(name="switch_step")(
switch_step if not testing else mock_switch_step
)
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
@beartype
async def tool_call_step(context: StepContext) -> dict:
raise NotImplementedError()
# assert isinstance(context.definition, ToolCallStep)
# assert isinstance(context.current_step, ToolCallStep)

# context.definition.tool_id
# context.definition.arguments
# context.current_step.tool_id
# context.current_step.arguments
# # get tool by id
# # call tool

Expand Down
25 changes: 25 additions & 0 deletions agents-api/agents_api/activities/task_steps/wait_for_input_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


async def wait_for_input_step(context: StepContext) -> StepOutcome:
    """Evaluate the current step's ``wait_for_input`` expressions.

    The expressions are handed to ``simple_eval_dict`` along with
    ``context.model_dump()`` as the values to evaluate against, and the
    result is wrapped in a ``StepOutcome``.

    Args:
        context: Step context whose ``current_step`` must be a
            ``WaitForInputStep`` (enforced with an assertion).

    Returns:
        ``StepOutcome(output=...)`` holding the evaluated result.
    """
    step = context.current_step
    assert isinstance(step, WaitForInputStep)

    evaluated = simple_eval_dict(step.wait_for_input, values=context.model_dump())
    return StepOutcome(output=evaluated)


# Note: This is here just for clarity. We could have just imported wait_for_input_step directly
# They do the same thing, so we don't need to mock the wait_for_input_step function
mock_wait_for_input_step = wait_for_input_step

# Register the function as the Temporal activity named "wait_for_input_step". When
# `testing` is truthy the mock alias is selected, which is the very same function object.
wait_for_input_step = activity.defn(name="wait_for_input_step")(
wait_for_input_step if not testing else mock_wait_for_input_step
)
Loading

0 comments on commit 72da581

Please sign in to comment.