Skip to content

Commit

Permalink
Merge pull request #7 from VowpalWabbit/scorer_activate_deactivate
Browse files Browse the repository at this point in the history
activate and deactivate scorer
  • Loading branch information
olgavrou authored Aug 29, 2023
2 parents 5fb781d + c4ccaeb commit d46ad01
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 32 deletions.
77 changes: 45 additions & 32 deletions libs/langchain/langchain/chains/rl_chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def log(self, event: TEvent) -> None:
selection_scorer: Union[SelectionScorer, None]
active_policy: Policy = _NoOpPolicy()
auto_embed: bool = False
selection_scorer_activated: bool = True
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
metrics: Optional[MetricsTracker] = None
Expand Down Expand Up @@ -400,6 +401,42 @@ def output_keys(self) -> List[str]:
"""
return [self.output_key]

def update_with_delayed_score(
self, score: float, event: TEvent, force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
""" # noqa: E501
if self._can_use_selection_scorer() and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
)
if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)

def deactivate_selection_scorer(self) -> None:
"""
Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
""" # noqa: E501
self.selection_scorer_activated = False

def activate_selection_scorer(self) -> None:
"""
Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
""" # noqa: E501
self.selection_scorer_activated = True

def save_progress(self) -> None:
"""
This function should be called to save the state of the learned policy model.
""" # noqa: E501
self.active_policy.save()

def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
super()._validate_inputs(inputs)
if (
Expand All @@ -412,6 +449,12 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
they are reserved for internal use during auto reward."
)

def _can_use_selection_scorer(self) -> bool:
"""
Returns whether the chain can use the selection scorer to score responses or not.
""" # noqa: E501
return self.selection_scorer is not None and self.selection_scorer_activated

@abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
...
Expand All @@ -434,30 +477,6 @@ def _call_after_scoring_before_learning(
) -> TEvent:
...

def update_with_delayed_score(
self, score: float, event: TEvent, force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
""" # noqa: E501
if self.selection_scorer and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
)
if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)

def set_auto_embed(self, auto_embed: bool) -> None:
"""
Sets whether the chain should auto embed the inputs or not.
"""
self.auto_embed = auto_embed

def _call(
self,
inputs: Dict[str, Any],
Expand Down Expand Up @@ -494,8 +513,8 @@ def _call(

score = None
try:
if self.selection_scorer:
score = self.selection_scorer.score_response(
if self._can_use_selection_scorer():
score = self.selection_scorer.score_response( # type: ignore
inputs=next_chain_inputs, llm_response=output, event=event
)
except Exception as e:
Expand All @@ -511,12 +530,6 @@ def _call(

return {self.output_key: {"response": output, "selection_metadata": event}}

def save_progress(self) -> None:
"""
This function should be called to save the state of the learned policy model.
"""
self.active_policy.save()

@property
def _chain_type(self) -> str:
return "llm_personalizer_chain"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,41 @@ def test_calling_chain_w_reserved_inputs_throws() -> None:
User=rl_chain.BasedOn("Context"),
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
)


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_activate_and_deactivate_scorer() -> None:
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
# chain llm used for both basic prompt and for scoring
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 300.0

chain.deactivate_selection_scorer()
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score is None

chain.activate_selection_scorer()
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 300.0

0 comments on commit d46ad01

Please sign in to comment.