Merge pull request #1 from argilla-io/cron-mmlu-translations
New cron for c4ai
dvsrepo authored Jul 9, 2024
2 parents 9395bfd + fa5deea commit 8243889
Showing 3 changed files with 41 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml
@@ -6,7 +6,7 @@ on:
      - "main"
  workflow_dispatch: # Enables manual trigger without additional inputs
  schedule:
    - cron: "0 */2 * * *" # Every 2 hours, on the hour
    - cron: "0 * * * *" # Every hour, on the hour

jobs:
  build:
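Not part of the commit, but as a quick sanity check on the new expression: the third-party croniter package can be used to confirm that "0 * * * *" fires at the top of every hour.

from datetime import datetime

from croniter import croniter  # third-party package, not a dependency of this repo

itr = croniter("0 * * * *", datetime(2024, 7, 9, 12, 30))
print(itr.get_next(datetime))  # 2024-07-09 13:00:00
print(itr.get_next(datetime))  # 2024-07-09 14:00:00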
124 changes: 39 additions & 85 deletions main.py
@@ -1,5 +1,5 @@
import os
from datasets import Dataset
from datasets import Dataset, load_dataset, concatenate_datasets
import argilla as rg

# Required environment variables
@@ -8,90 +8,44 @@
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
SOURCE_DATASET = os.getenv("SOURCE_DATASET")
PARSED_RESULTS_DATASET = os.getenv("HF_DATASET_RESULTS")

# Optional environment variables
REQUIRED_RESPONSES = int(os.getenv("REQUIRED_RESPONSES", "2"))
RESULTS_DATASET = os.getenv("RESULTS_DATASET", f"{SOURCE_DATASET}-results")
SOURCE_WORKSPACE = os.getenv("SOURCE_WORKSPACE", "admin")
RESULTS_WORKSPACE = os.getenv("RESULTS_WORKSPACE", "results")
DELETE_SOURCE_RECORDS = os.getenv("DELETE_SOURCE_RECORDS", "False").lower() == 'true'


def completed_with_overlap(records, required_responses):
    """
    Filters records to find those with responses equal to or greater than the required amount.
    """
    completed = [r for r in records if len(r.responses) >= required_responses]
    return completed

def build_parsed_results(dataset):
    """
    Constructs a new dataset from the original, extracting the relevant fields
    and adding fields for the parsed results.
    """
    questions = [(question.name, question.type) for question in dataset.questions]
    results = []
    for record in dataset:
        result = {
            "fields": dict(record.fields),
            "metadata": dict(record.metadata),
            "num_responses": len(record.responses),
            "user_ids": [str(response.user_id) for response in record.responses]
        }
        for question, _ in questions:
            result[question] = []
        for response in record.responses:
            for question, kind in questions:
                if question in response.values:
                    value = response.values[question].value
                    if value is not None:
                        if kind == 'span':
                            # Span answers are lists of objects; convert each to a dict
                            result[question].append([dict(v) for v in value])
                        else:
                            result[question].append(value)
        results.append(result)
    return Dataset.from_list(results)
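
# For illustration only (not part of the diff): given a record with a single
# hypothetical rating question named "quality" answered by two annotators,
# build_parsed_results would emit a row shaped roughly like this:
#
# {
#     "fields": {"text": "..."},
#     "metadata": {"split": "..."},
#     "num_responses": 2,
#     "user_ids": ["<uuid-1>", "<uuid-2>"],
#     "quality": [4, 5],  # one entry per submitted response
# }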

rg.init(api_url=ARGILLA_API_URL, api_key=ARGILLA_API_KEY)

# Ensure the results workspace exists
if RESULTS_WORKSPACE not in [workspace.name for workspace in rg.Workspace.list()]:
    rg.Workspace.create(RESULTS_WORKSPACE)

dataset = rg.FeedbackDataset.from_argilla(SOURCE_DATASET, workspace=SOURCE_WORKSPACE)
print(f"Current dataset size: {len(dataset)}")

submitted_so_far = dataset.filter_by(response_status="submitted")
print(f"Submitted: {len(submitted_so_far)}")

completed_remote_records = completed_with_overlap(submitted_so_far.records, REQUIRED_RESPONSES)
print(f"Completed so far: {len(completed_remote_records)}")

if completed_remote_records:
    local_submitted = submitted_so_far.pull()
    completed_local_records = completed_with_overlap(local_submitted.records, REQUIRED_RESPONSES)
    try:
        # Append to the results dataset if it already exists
        results = rg.FeedbackDataset.from_argilla(RESULTS_DATASET, workspace=RESULTS_WORKSPACE)
        results.add_records(completed_local_records)
        if DELETE_SOURCE_RECORDS:
            print("Deleting source records")
            dataset.delete_records(list(completed_remote_records))
    except Exception:
        # Otherwise create the results dataset and push the completed records to it
        rg_dataset = rg.FeedbackDataset(
            fields=local_submitted.fields,
            questions=local_submitted.questions,
            metadata_properties=local_submitted.metadata_properties,
            guidelines=local_submitted.guidelines
        )
        rg_dataset.add_records(completed_local_records)
        results = rg_dataset.push_to_argilla(RESULTS_DATASET, workspace=RESULTS_WORKSPACE)
        if DELETE_SOURCE_RECORDS:
            dataset.delete_records(list(completed_remote_records))

    parsed_results_dataset = build_parsed_results(results)
    print(f"Pushing dataset to {PARSED_RESULTS_DATASET}...")
    parsed_results_dataset.push_to_hub(PARSED_RESULTS_DATASET, token=HF_TOKEN)
    print(f"Pushed dataset to {PARSED_RESULTS_DATASET}")
#print(f"Pushing dataset to {RAW_DATASET}....")
#results.push_to_huggingface(RAW_DATASET, token=HF_TOKEN)
#print(f"Pushed dataset to {RAW_DATASET}")
client = rg.Argilla(
    api_url=ARGILLA_API_URL,
    api_key=ARGILLA_API_KEY
)

ds = client.datasets(SOURCE_DATASET, workspace=SOURCE_WORKSPACE)

# Get submitted records (at least 1 user response)
status_filter = rg.Filter(("response.status", "==", "submitted"))
submitted = ds.records(query=rg.Query(filter=status_filter))
to_delete = list(submitted)
print(f"Number of records to delete: {len(to_delete)}")
# Fetch the same records again, this time as plain dicts
submitted = ds.records(query=rg.Query(filter=status_filter))
record_list = submitted.to_list(flatten=False)
print(f"Number of records to persist: {len(record_list)}")

if len(record_list) > 0:
    hf_ds = Dataset.from_list(record_list)
    # The vectors column has to be dropped, otherwise the push fails
    hf_ds = hf_ds.remove_columns(["vectors"])

    # Load the existing HF dataset
    previous_hf_ds = load_dataset(HF_DATASET_RESULTS, split="train")
    print(f"Current HF dataset size: {len(previous_hf_ds)}")

    # Append the newly submitted records
    concatenated = concatenate_datasets([previous_hf_ds, hf_ds])
    print(f"New HF dataset size: {len(concatenated)}")
    concatenated.push_to_hub(HF_DATASET_RESULTS, private=True)
    print(f"Pushed dataset to {HF_DATASET_RESULTS}")

print("Deleting records")
# This won't be needed with rc3: just ds.delete(to_delete)
count = 0
for r in to_delete:
    ds.records.delete([r])
    count += 1
print(f"Deleted: {count}")
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
argilla==1.27.0
argilla==2.0.0rc2
datasets
huggingface_hub
jinja2
