centralised logging
olgavrou committed Aug 17, 2023
1 parent f470b68 commit 43ac147
Showing 7 changed files with 44 additions and 33 deletions.
9 changes: 1 addition & 8 deletions prompt_selection.ipynb
@@ -347,13 +347,6 @@
"rnd_chain.metrics.to_pandas()['score'].plot(label=\"slates\")\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -372,7 +365,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.8.10"
},
"orig_nbformat": 4
},
6 changes: 5 additions & 1 deletion rl_chain.ipynb
@@ -75,6 +75,10 @@
"import rl_chain\n",
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"import logging\n",
"logger = logging.getLogger(\"rl_chain\")\n",
"logger.setLevel(logging.INFO)\n",
"\n",
"_PROMPT_TEMPLATE = \"\"\"Here is the description of a meal: {meal}.\n",
"\n",
"You have to embed this into the given text where it makes sense. Here is the given text: {text_to_personalize}.\n",
@@ -341,7 +345,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.8.10"
},
"orig_nbformat": 4
},
17 changes: 17 additions & 0 deletions rl_chain/__init__.py
@@ -14,3 +14,20 @@
Policy,
VwPolicy,
)

import logging


def configure_logger():
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)
logger.addHandler(ch)


configure_logger()
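
With the handler configured once at package import, callers only adjust the level of the shared "rl_chain" logger, exactly as the notebook change above does. A minimal sketch of how a script might quiet the package instead (standard library logging only; nothing here beyond what the diff shows):

import logging
import rl_chain  # importing the package runs configure_logger() once

# Submodules log through child loggers of "rl_chain" and propagate upward,
# so one call controls the whole package. For example, to silence the INFO
# messages emitted when models are saved/loaded:
logging.getLogger("rl_chain").setLevel(logging.WARNING)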
5 changes: 3 additions & 2 deletions rl_chain/metrics.py
@@ -1,6 +1,7 @@
import pandas as pd
from typing import Optional


class MetricsTracker:
def __init__(self, step: int):
self._history = []
@@ -20,7 +21,7 @@ def on_feedback(self, score: Optional[float]) -> None:
self._num += score or 0
self._i += 1
if self._step > 0 and self._i % self._step == 0:
self._history.append({'step': self._i, 'score': self.score})
self._history.append({"step": self._i, "score": self.score})

def to_pandas(self) -> pd.DataFrame:
return pd.DataFrame(self._history)
return pd.DataFrame(self._history)
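
For reference, the tracker appends a ("step", "score") row every `step` feedback calls, so a non-positive step (the chain's default of metrics_step=-1) records no per-step history. A small illustrative sketch, using only the methods visible in this diff:

from rl_chain.metrics import MetricsTracker

tracker = MetricsTracker(step=2)      # snapshot the running score every 2 feedbacks
for s in [1.0, 0.0, None, 1.0]:       # None counts as 0 toward the running score
    tracker.on_feedback(s)
print(tracker.to_pandas())            # DataFrame with "step" and "score" columns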
8 changes: 4 additions & 4 deletions rl_chain/model_repository.py
@@ -7,12 +7,13 @@
import glob
import logging

logger = logging.getLogger(__name__)


class ModelRepository:
def __init__(
self,
folder: Union[str, os.PathLike],
logger: logging.Logger,
with_history: bool = True,
reset: bool = False,
):
@@ -27,7 +28,6 @@ def __init__(
os.remove(self.model_path)

self.folder.mkdir(parents=True, exist_ok=True)
self.logger = logger

def get_tag(self) -> str:
return datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -37,7 +37,7 @@ def has_history(self) -> bool:

def save(self, workspace: vw.Workspace) -> None:
with open(self.model_path, "wb") as f:
self.logger.info(f"storing rl_chain model in: {self.model_path}")
logger.info(f"storing rl_chain model in: {self.model_path}")
f.write(workspace.serialize())
if self.with_history: # write history
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
@@ -48,6 +48,6 @@ def load(self, commandline: Sequence[str]) -> vw.Workspace:
with open(self.model_path, "rb") as f:
model_data = f.read()
if model_data:
self.logger.info(f"rl_chain model is loaded from: {self.model_path}")
logger.info(f"rl_chain model is loaded from: {self.model_path}")
return vw.Workspace(commandline, model_data=model_data)
return vw.Workspace(commandline)
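
Since ModelRepository now resolves its own module-level logger, callers no longer pass one in. A rough sketch of the updated construction, based only on the signature in this diff (the "models" folder name is illustrative):

from rl_chain.model_repository import ModelRepository

# The logger argument is gone; only the storage folder and the
# history/reset flags remain. Save/load INFO lines now come from
# the "rl_chain.model_repository" logger.
repo = ModelRepository("models", with_history=True, reset=False)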
4 changes: 4 additions & 0 deletions rl_chain/pick_best_chain.py
@@ -13,6 +13,10 @@
from langchain.chains.llm import LLMChain
from sentence_transformers import SentenceTransformer

import logging

logger = logging.getLogger(__name__)

# sentinel object used to distinguish between user didn't supply anything or user explicitly supplied None
SENTINEL = object()

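
The SENTINEL default exists so the code can tell "argument omitted" apart from "argument explicitly set to None". A generic sketch of that pattern; the function and parameter names here are illustrative, not taken from pick_best_chain.py:

SENTINEL = object()

def update_scorer(scorer=SENTINEL):
    if scorer is SENTINEL:
        print("argument omitted: keep the current scorer")
    elif scorer is None:
        print("scoring explicitly disabled")
    else:
        print("use the supplied scorer")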
28 changes: 10 additions & 18 deletions rl_chain/rl_chain_base.py
@@ -23,12 +23,6 @@
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)
logger.addHandler(ch)


class _BasedOn:
@@ -168,14 +162,14 @@ def __init__(
model_repo: ModelRepository,
vw_cmd: Sequence[str],
feature_embedder: Embedder,
logger: VwLogger,
*_,
**__,
vw_logger: VwLogger,
*args,
**kwargs,
):
self.model_repo = model_repo
self.workspace = self.model_repo.load(vw_cmd)
self.feature_embedder = feature_embedder
self.logger = logger
self.vw_logger = vw_logger

def predict(self, event: Event) -> Any:
text_parser = vw.TextFormatParser(self.workspace)
@@ -191,9 +185,9 @@ def learn(self, event: Event):
self.workspace.learn_one(multi_ex)

def log(self, event: Event):
if self.logger.logging_enabled():
if self.vw_logger.logging_enabled():
vw_ex = self.feature_embedder.format(event)
self.logger.log(vw_ex)
self.vw_logger.log(vw_ex)

def save(self):
self.model_repo.save()
@@ -284,7 +278,7 @@ def __init__(
vw_cmd=None,
policy=VwPolicy,
vw_logs: Optional[Union[str, os.PathLike]] = None,
metrics_step = -1,
metrics_step=-1,
*args,
**kwargs,
):
@@ -295,11 +289,11 @@
)
self.policy = policy(
model_repo=ModelRepository(
model_save_dir, logger, with_history=True, reset=reset_model
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd or [],
feature_embedder=feature_embedder,
logger=VwLogger(vw_logs),
vw_logger=VwLogger(vw_logs),
)
self.metrics = MetricsTracker(step=metrics_step)

@@ -426,9 +420,7 @@ def _call(
f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
)
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(
score=score, event=event
)
event = self._call_after_scoring_before_learning(score=score, event=event)
self.policy.learn(event=event)
self.policy.log(event=event)

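
For anyone constructing a VwPolicy directly, the practical change is that the keyword is now vw_logger rather than logger, and ModelRepository no longer takes a logger at all. A rough sketch based only on the signatures in this diff; the VwLogger import path, the "models" folder, and my_embedder are assumptions:

from rl_chain import VwPolicy                        # re-exported in rl_chain/__init__.py
from rl_chain.model_repository import ModelRepository
from rl_chain.vw_logger import VwLogger              # assumed module path

policy = VwPolicy(
    model_repo=ModelRepository("models", with_history=True, reset=False),
    vw_cmd=[],                        # empty list, as the chain passes when vw_cmd is None
    feature_embedder=my_embedder,     # any Embedder implementation (placeholder name)
    vw_logger=VwLogger(None),         # pass a file path instead of None to log VW examples
)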
