diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py index c5fbd042c..90b0299ec 100644 --- a/fastchat/serve/moderation/moderator.py +++ b/fastchat/serve/moderation/moderator.py @@ -17,6 +17,9 @@ def __init__(self): self.conv_moderation_responses: List[ Dict[str, Dict[str, Union[str, Dict[str, float]]]] ] = [] + self.text_flagged = False + self.csam_flagged = False + self.nsfw_flagged = False def _image_moderation_filter(self, image: Image) -> Tuple[bool, bool]: """Function that detects whether image violates moderation policies. @@ -34,6 +37,11 @@ def _text_moderation_filter(self, text: str) -> bool: """ raise NotImplementedError + def reset_moderation_flags(self): + self.text_flagged = False + self.csam_flagged = False + self.nsfw_flagged = False + def image_and_text_moderation_filter( self, image: Image, text: str ) -> Dict[str, Dict[str, Union[str, Dict[str, float]]]]: @@ -77,9 +85,6 @@ def __init__(self, use_remote_storage: bool = False): to the moderation API. """ super().__init__() - self.text_flagged = False - self.csam_flagged = False - self.nsfw_flagged = False def _image_moderation_request( self, image_bytes: bytes, endpoint: str, api_key: str @@ -253,6 +258,7 @@ def image_and_text_moderation_filter( } """ print("moderating text: ", text) + self.reset_moderation_flags() text_flagged_map = self.text_moderation_filter(text, model_list, do_moderation) if image is not None: