Skip to content

Commit

Permalink
Add finish flag, and rename flush of wandb callback
Browse files Browse the repository at this point in the history
Added a finish flag to flush to allow wandb runs to be more easily managed by the callback usage as opposed to extracting wandb specific behavior

Renamed the flush command to a more generic name for trackers
  • Loading branch information
ash0ts committed Mar 2, 2023
1 parent eb4f1c5 commit 64146a0
Showing 1 changed file with 32 additions and 7 deletions.
39 changes: 32 additions & 7 deletions langchain/callbacks/wandb_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ def __init__(
name=self.name,
notes=self.notes,
)
wandb.termwarn(
"""The wandb callback is currently in beta and is subject to change based on updates to `langchain`.
Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.
""",
repeat=False,
)
self.callback_columns = []
self.action_records = []
self.complexity_metrics = complexity_metrics
Expand Down Expand Up @@ -529,7 +535,6 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
}
)
resp.update(self.get_custom_callback_meta())

self.on_agent_action_records.append(resp)
self.action_records.append(resp)
self.run.log(resp)
Expand Down Expand Up @@ -588,10 +593,11 @@ def _create_session_analysis_df(self):
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
return session_analysis_df

def flush_and_reset_session(
def flush_tracker(
self,
langchain_asset=None,
reset: bool = True,
finish: bool = False,
job_type: Optional[str] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
Expand All @@ -602,6 +608,25 @@ def flush_and_reset_session(
visualize: Optional[bool] = None,
complexity_metrics: Optional[bool] = None,
):
"""Flush the tracker and reset the session.
Args:
langchain_asset: The langchain asset to save.
reset: Whether to reset the session.
finish: Whether to finish the run.
job_type: The job type.
project: The project.
entity: The entity.
tags: The tags.
group: The group.
name: The name.
notes: The notes.
visualize: Whether to visualize.
complexity_metrics: Whether to compute complexity metrics.
Returns:
None
"""
action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
session_analysis_table = wandb.Table(
dataframe=self._create_session_analysis_df()
Expand Down Expand Up @@ -630,9 +655,10 @@ def flush_and_reset_session(
pass
self.run.log_artifact(model_artifact)

if reset:
if finish or reset:
self.run.finish()
self.temp_dir.cleanup()
if reset:
self.__init__(
job_type=job_type if job_type else self.job_type,
project=project if project else self.project,
Expand Down Expand Up @@ -670,7 +696,7 @@ def main():

# SCENARIO 1 - LLM
llm_result = llm.generate(["Tell me a joke", "Tell me a poem"] * 3)
wandb_callback.flush_and_reset_session(llm, name="simple_sequential")
wandb_callback.flush_tracker(llm, name="simple_sequential")

# SCENARIO 2 - Chain
template = """You are a playwright. Given the title of play, it is your job to write a synopsis for that title.
Expand Down Expand Up @@ -698,7 +724,7 @@ def main():
{"input": "the best in class mlops tooling"},
]
overall_chain.apply(test_prompts)
wandb_callback.flush_and_reset_session(overall_chain, name="agent")
wandb_callback.flush_tracker(overall_chain, name="agent")

# SCENARIO 3 - Agent with Tools
tools = load_tools(["serpapi", "llm-math"], llm=llm, callback_manager=manager)
Expand All @@ -712,8 +738,7 @@ def main():
agent.run(
"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?"
)
wandb_callback.flush_and_reset_session(agent, reset=False)
wandb_callback.run.finish()
wandb_callback.flush_tracker(agent, reset=False, finish=True)


if __name__ == "__main__":
Expand Down

0 comments on commit 64146a0

Please sign in to comment.