Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
HowieHwong committed Apr 21, 2024
1 parent 62f9d6d commit 4ae6f6b
Showing 1 changed file with 5 additions and 23 deletions.
28 changes: 5 additions & 23 deletions trustllm_pkg/trustllm/utils/longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Configure basic logging
# Route evaluator activity to a dedicated log file so long-running
# classification runs can be inspected after the fact; timestamps and
# severity are included in each record.
logging.basicConfig(filename='huggingface_evaluator.log', level=logging.INFO,
format='%(asctime)s:%(levelname)s:%(message)s')

# Default file name under which evaluation progress is checkpointed
# (used as the default for evaluate()/_evaluate_batch so partial results
# survive a crash and runs can be resumed).
PROGRESS_FILENAME = 'longformer_eval.json'

class HuggingFaceEvaluator:
"""
Expand Down Expand Up @@ -35,7 +35,7 @@ def __init__(self, model_name='LibrAI/longformer-harmful-ro', device=None, save_
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)

def evaluate(self, data, resume=False, progress_filename='longformer_eval.json'):
def evaluate(self, data, resume=False, progress_filename=PROGRESS_FILENAME):
"""
Evaluate a dataset using the loaded Hugging Face model.
Expand All @@ -62,41 +62,22 @@ def evaluate(self, data, resume=False, progress_filename='longformer_eval.json')
print('Valid data number: {}'.format(len(data)))
print('Evaluating...')






evaluated_data = self._evaluate_batch(data)


# for item in tqdm(data):
# try:
# if 'eval_res' not in item:
# text = item.get('res', '')
# eval_res = self.classifier(text)[0]["label"]
# item['eval_res'] = eval_res
# logging.info("Evaluated item: %s", item)
# except Exception as e:
# logging.error("Error evaluating item %s: %s", item, str(e))
# file_process.save_json(data, os.path.join(self.save_dir, progress_filename))
# raise

file_process.save_json(data, os.path.join(self.save_dir, progress_filename))
return evaluated_data



def _evaluate_batch(self,data, batch_size=32):
def _evaluate_batch(self,data, batch_size=32, progress_filename=PROGRESS_FILENAME):
results = []
for i in tqdm(range(0, len(data), batch_size)):
batch = data[i:i+batch_size]
texts = [item.get('res', '') for item in batch if 'eval_res' not in item]
if texts:
try:
# 使用批处理进行分类
batch_outputs = self.classifier(texts)
assert len(batch_outputs) == len(texts) # 确保输出与输入对应
assert len(batch_outputs) == len(texts)
idx = 0
for item in batch:
if 'eval_res' not in item:
Expand All @@ -106,6 +87,7 @@ def _evaluate_batch(self,data, batch_size=32):
logging.info("Processed batch from %s to %s", i, i+batch_size)
except Exception as e:
logging.error("Error processing batch %s to %s: %s", i, i+batch_size, str(e))
file_process.save_json(data, os.path.join(self.save_dir, progress_filename))
raise
else:
results.extend(batch)
Expand Down

0 comments on commit 4ae6f6b

Please sign in to comment.