Skip to content

Commit

Permalink
Merge pull request #8 from VowpalWabbit/update_w_score
Browse files Browse the repository at this point in the history
update score to take entire response object to make it easier for user
  • Loading branch information
olgavrou authored Aug 29, 2023
2 parents d46ad01 + 48aaa27 commit 256849e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
3 changes: 2 additions & 1 deletion libs/langchain/langchain/chains/rl_chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def output_keys(self) -> List[str]:
return [self.output_key]

def update_with_delayed_score(
self, score: float, event: TEvent, force_score: bool = False
self, score: float, chain_response: Dict[str, Any], force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
Expand All @@ -415,6 +415,7 @@ def update_with_delayed_score(
)
if self.metrics:
self.metrics.on_feedback(score)
event: TEvent = chain_response["selection_metadata"]
self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0
with pytest.raises(RuntimeError):
chain.update_with_delayed_score(event=selection_metadata, score=100)
chain.update_with_delayed_score(chain_response=response, score=100)


@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
Expand All @@ -109,7 +109,7 @@ def test_update_with_delayed_score_force() -> None:
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0
chain.update_with_delayed_score(
event=selection_metadata, score=100, force_score=True
chain_response=response, score=100, force_score=True
)
assert selection_metadata.selected.score == 100.0

Expand All @@ -131,7 +131,7 @@ def test_update_with_delayed_score() -> None:
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score is None
chain.update_with_delayed_score(event=selection_metadata, score=100)
chain.update_with_delayed_score(chain_response=response, score=100)
assert selection_metadata.selected.score == 100.0


Expand Down

0 comments on commit 256849e

Please sign in to comment.