Skip to content

Commit

Permalink
Rm RunTypeEnum (#8553)
Browse files Browse the repository at this point in the history
We already support raw strings in the SDK but would like to deprecate
client-side validation of run types. This removes its usage
  • Loading branch information
hinthornw authored Aug 1, 2023
1 parent 2a26cc6 commit e83250c
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 39 deletions.
28 changes: 14 additions & 14 deletions libs/langchain/langchain/callbacks/tracers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tenacity import RetryCallState

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.callbacks.tracers.schemas import Run
from langchain.load.dump import dumpd
from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult
Expand Down Expand Up @@ -110,7 +110,7 @@ def on_llm_start(
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
run_type=RunTypeEnum.llm,
run_type="llm",
tags=tags or [],
)
self._start_trace(llm_run)
Expand All @@ -130,7 +130,7 @@ def on_llm_new_token(

run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
llm_run.events.append(
{
Expand Down Expand Up @@ -182,7 +182,7 @@ def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Non

run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
Expand Down Expand Up @@ -210,7 +210,7 @@ def on_llm_error(

run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
Expand Down Expand Up @@ -246,7 +246,7 @@ def on_chain_start(
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.chain,
run_type="chain",
tags=tags or [],
)
self._start_trace(chain_run)
Expand All @@ -259,7 +259,7 @@ def on_chain_end(
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
if chain_run is None or chain_run.run_type != "chain":
raise TracerException("No chain Run found to be traced")

chain_run.outputs = outputs
Expand All @@ -279,7 +279,7 @@ def on_chain_error(
if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
if chain_run is None or chain_run.run_type != "chain":
raise TracerException("No chain Run found to be traced")

chain_run.error = repr(error)
Expand Down Expand Up @@ -316,7 +316,7 @@ def on_tool_start(
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.tool,
run_type="tool",
tags=tags or [],
)
self._start_trace(tool_run)
Expand All @@ -327,7 +327,7 @@ def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced")

tool_run.outputs = {"output": output}
Expand All @@ -347,7 +347,7 @@ def on_tool_error(
if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced")

tool_run.error = repr(error)
Expand Down Expand Up @@ -386,7 +386,7 @@ def on_retriever_start(
child_execution_order=execution_order,
tags=tags,
child_runs=[],
run_type=RunTypeEnum.retriever,
run_type="retriever",
)
self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run)
Expand All @@ -402,7 +402,7 @@ def on_retriever_error(
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced")

retrieval_run.error = repr(error)
Expand All @@ -418,7 +418,7 @@ def on_retriever_end(
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/callbacks/tracers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from langsmith import Client

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
from langchain.callbacks.tracers.schemas import Run, TracerSession
from langchain.env import get_runtime_environment
from langchain.load.dump import dumpd
from langchain.schema.messages import BaseMessage
Expand Down Expand Up @@ -107,7 +107,7 @@ def on_chat_model_start(
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
run_type=RunTypeEnum.llm,
run_type="llm",
tags=tags,
)
self._start_trace(chat_model_run)
Expand Down
13 changes: 12 additions & 1 deletion libs/langchain/langchain/callbacks/tracers/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,27 @@
from __future__ import annotations

import datetime
import warnings
from typing import Any, Dict, List, Optional
from uuid import UUID

from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from pydantic import BaseModel, Field, root_validator

from langchain.schema import LLMResult


def RunTypeEnum() -> RunTypeEnumDep:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
" (e.g. 'llm', 'chain', 'tool').",
DeprecationWarning,
)
return RunTypeEnumDep


class TracerSessionV1Base(BaseModel):
"""Base class for TracerSessionV1."""

Expand Down
8 changes: 4 additions & 4 deletions libs/langchain/langchain/callbacks/tracers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.callbacks.tracers.schemas import Run

if TYPE_CHECKING:
from wandb import Settings as WBSettings
Expand Down Expand Up @@ -154,11 +154,11 @@ def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
:param run: The LangChain Run to convert.
:return: The converted W&B Trace Span.
"""
if run.run_type == RunTypeEnum.llm:
if run.run_type == "llm":
return self._convert_llm_run_to_wb_span(run)
elif run.run_type == RunTypeEnum.chain:
elif run.run_type == "chain":
return self._convert_chain_run_to_wb_span(run)
elif run.run_type == RunTypeEnum.tool:
elif run.run_type == "tool":
return self._convert_tool_run_to_wb_span(run)
else:
return self._convert_run_to_wb_span(run)
Expand Down
18 changes: 9 additions & 9 deletions libs/langchain/langchain/smith/evaluation/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from urllib.parse import urlparse, urlunparse

from langsmith import Client, RunEvaluator
from langsmith.schemas import Dataset, DataType, Example, RunTypeEnum
from langsmith.schemas import Dataset, DataType, Example

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks
Expand Down Expand Up @@ -341,9 +341,9 @@ def _setup_evaluation(
first_example, examples = _first_example(examples)
if isinstance(llm_or_chain_factory, BaseLanguageModel):
run_inputs, run_outputs = None, None
run_type = RunTypeEnum.llm
run_type = "llm"
else:
run_type = RunTypeEnum.chain
run_type = "chain"
if data_type in (DataType.chat, DataType.llm):
raise ValueError(
"Cannot evaluate a chain on dataset with "
Expand All @@ -370,13 +370,13 @@ def _setup_evaluation(
def _determine_input_key(
config: RunEvalConfig,
run_inputs: Optional[List[str]],
run_type: RunTypeEnum,
run_type: str,
) -> Optional[str]:
if config.input_key:
input_key = config.input_key
if run_inputs and input_key not in run_inputs:
raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
elif run_type == RunTypeEnum.llm:
elif run_type == "llm":
input_key = None
elif run_inputs and len(run_inputs) == 1:
input_key = run_inputs[0]
Expand All @@ -391,15 +391,15 @@ def _determine_input_key(
def _determine_prediction_key(
config: RunEvalConfig,
run_outputs: Optional[List[str]],
run_type: RunTypeEnum,
run_type: str,
) -> Optional[str]:
if config.prediction_key:
prediction_key = config.prediction_key
if run_outputs and prediction_key not in run_outputs:
raise ValueError(
f"Prediction key {prediction_key} not in run outputs {run_outputs}"
)
elif run_type == RunTypeEnum.llm:
elif run_type == "llm":
prediction_key = None
elif run_outputs and len(run_outputs) == 1:
prediction_key = run_outputs[0]
Expand Down Expand Up @@ -432,7 +432,7 @@ def _determine_reference_key(
def _construct_run_evaluator(
eval_config: Union[EvaluatorType, EvalConfig],
eval_llm: BaseLanguageModel,
run_type: RunTypeEnum,
run_type: str,
data_type: DataType,
example_outputs: Optional[List[str]],
reference_key: Optional[str],
Expand Down Expand Up @@ -472,7 +472,7 @@ def _construct_run_evaluator(

def _load_run_evaluators(
config: RunEvalConfig,
run_type: RunTypeEnum,
run_type: str,
data_type: DataType,
example_outputs: Optional[List[str]],
run_inputs: Optional[List[str]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, List, Optional

from langsmith import EvaluationResult, RunEvaluator
from langsmith.schemas import DataType, Example, Run, RunTypeEnum
from langsmith.schemas import DataType, Example, Run

from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
Expand Down Expand Up @@ -327,7 +327,7 @@ async def aevaluate_run(
def from_run_and_data_type(
cls,
evaluator: StringEvaluator,
run_type: RunTypeEnum,
run_type: str,
data_type: DataType,
input_key: Optional[str] = None,
prediction_key: Optional[str] = None,
Expand All @@ -343,7 +343,7 @@ def from_run_and_data_type(
Args:
evaluator (StringEvaluator): The string evaluator to use.
run_type (RunTypeEnum): The type of run being evaluated.
run_type (str): The type of run being evaluated.
Supported types are LLM and Chain.
data_type (DataType): The type of dataset used in the run.
input_key (str, optional): The key used to map the input from the run.
Expand All @@ -361,9 +361,9 @@ def from_run_and_data_type(
""" # noqa: E501

# Configure how run inputs/predictions are passed to the evaluator
if run_type == RunTypeEnum.llm:
if run_type == "llm":
run_mapper: StringRunMapper = LLMStringRunMapper()
elif run_type == RunTypeEnum.chain:
elif run_type == "chain":
run_mapper = ChainStringRunMapper(
input_key=input_key, prediction_key=prediction_key
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ToolRun,
TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base
from langchain.schema import LLMResult
from langchain.schema.messages import HumanMessage

Expand Down Expand Up @@ -589,7 +589,7 @@ def test_convert_run(
outputs=LLMResult(generations=[[]]).dict(),
serialized={},
extra={},
run_type=RunTypeEnum.llm,
run_type="llm",
)
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
Expand All @@ -603,7 +603,7 @@ def test_convert_run(
outputs={},
child_runs=[llm_run],
extra={},
run_type=RunTypeEnum.chain,
run_type="chain",
)

tool_run = Run(
Expand All @@ -618,7 +618,7 @@ def test_convert_run(
serialized={},
child_runs=[],
extra={},
run_type=RunTypeEnum.tool,
run_type="tool",
)

expected_llm_run = LLMRun(
Expand Down

0 comments on commit e83250c

Please sign in to comment.