Skip to content

Commit

Permalink
Merge pull request #173 from princeton-nlp/print-pulls-gh-token
Browse files Browse the repository at this point in the history
Fix: Allow to set GH token from env var in collect/print_pulls
  • Loading branch information
john-b-yang authored Jul 5, 2024
2 parents 86f44ee + 2a03f7d commit a8d9a39
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 15 deletions.
2 changes: 1 addition & 1 deletion swebench/collect/build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def main(pr_file: str, output: str, token: Optional[str] = None):
"""
if token is None:
# Get GitHub token from environment variable if not provided
token = os.environ["GITHUB_TOKEN"]
token = os.environ.get("GITHUB_TOKEN")

def load_repo(repo_name):
# Return repo object for a given repo name
Expand Down
19 changes: 12 additions & 7 deletions swebench/collect/print_pulls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

"""Given the `<owner/name>` of a GitHub repo, this script writes the raw information for all the repo's PRs to a single `.jsonl` file."""

from __future__ import annotations

import argparse
import json
import logging
Expand All @@ -17,21 +19,23 @@
logger = logging.getLogger(__name__)


def log_all_pulls(repo: Repo, output: str, max_pulls: int | None = None) -> None:
    """
    Iterate over the pull requests of a repository and log them to a file

    Each pull request is written as one JSON document per line, after the
    issues it resolves have been attached as its `resolved_issues` attribute.

    Args:
        repo (Repo): repository object
        output (str): output file name
        max_pulls (int, optional): maximum number of pull requests to log;
            None (the default) logs all of them
    """
    with open(output, "w") as file:
        # A non-positive limit means there is nothing to log; the output file
        # is still created (empty).
        if max_pulls is not None and max_pulls <= 0:
            return
        # Count from 1 so `n_logged` is the number of pulls written so far.
        for n_logged, pull in enumerate(repo.get_all_pulls(), start=1):
            setattr(pull, "resolved_issues", repo.extract_resolved_issues(pull))
            print(json.dumps(obj2dict(pull)), end="\n", flush=True, file=file)
            # Stop as soon as the limit is reached, before requesting another
            # pull from the iterator.
            if max_pulls is not None and n_logged >= max_pulls:
                break


def main(repo_name: str, output: str, token: Optional[str] = None, max_pulls: int|None=None):
    """
    Log the pull requests of a GitHub repository to a file

    Args:
        repo_name (str): repository in `<owner>/<name>` form
        output (str): output file name
        token (str, optional): GitHub token; when not given, the value of the
            GITHUB_TOKEN environment variable is used (None if it is unset)
        max_pulls (int, optional): passed through to `log_all_pulls`
    """
    github_token = os.environ.get("GITHUB_TOKEN") if token is None else token
    # Unpacking requires exactly one "/" in `repo_name`.
    owner, name = repo_name.split("/")
    log_all_pulls(Repo(owner, name, token=github_token), output, max_pulls=max_pulls)


if __name__ == "__main__":
    # Command-line entry point: the parsed options map one-to-one onto the
    # keyword arguments of `main` (`--max-pulls` becomes `max_pulls`).
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("repo_name", type=str, help="Name of the repository")
    cli.add_argument("output", type=str, help="Output file name")
    cli.add_argument("--token", type=str, help="GitHub token")
    cli.add_argument(
        "--max-pulls",
        type=int,
        help="Maximum number of pulls to log",
        default=None,
    )
    options = vars(cli.parse_args())
    main(**options)
16 changes: 9 additions & 7 deletions swebench/collect/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


import logging
import re
import requests
Expand All @@ -6,7 +9,7 @@
from bs4 import BeautifulSoup
from ghapi.core import GhApi
from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError
from typing import Optional
from typing import Callable, Iterator, Optional

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
Expand All @@ -30,7 +33,7 @@ def __init__(self, owner: str, name: str, token: Optional[str] = None):
self.api = GhApi(token=token)
self.repo = self.call_api(self.api.repos.get, owner=owner, repo=name)

def call_api(self, func: callable, **kwargs) -> dict:
def call_api(self, func: Callable, **kwargs) -> dict|None:
"""
API call wrapper with rate limit handling (checks every 5 minutes if rate limit is reset)
Expand Down Expand Up @@ -103,12 +106,12 @@ def extract_resolved_issues(self, pull: dict) -> list[str]:

def get_all_loop(
self,
func: callable,
func: Callable,
per_page: int = 100,
num_pages: Optional[int] = None,
quiet: bool = False,
**kwargs,
) -> list:
) -> Iterator:
"""
Return all values from a paginated API endpoint.
Expand Down Expand Up @@ -195,7 +198,7 @@ def get_all_pulls(
direction: str = "asc",
sort: str = "created",
state: str = "closed",
quiet: str = False,
quiet: bool = False,
) -> list:
"""
Wrapper for API call to get all PRs from repo
Expand Down Expand Up @@ -347,7 +350,7 @@ def extract_patches(pull: dict, repo: Repo) -> tuple[str, str]:
### MARK: Repo Specific Parsing Functions ###
def extract_problem_statement_and_hints_django(
pull: dict, repo: Repo
) -> tuple[str, str]:
) -> tuple[str, list[str]]:
"""
Get problem statement and hints from issues associated with a pull request
Expand Down Expand Up @@ -390,7 +393,6 @@ def extract_problem_statement_and_hints_django(
# Get all comments before first commit
comments_html = soup.find("div", {"id": "changelog"})
div_blocks = comments_html.find_all("div", class_="change")
comments = []
# Loop through each div block
for div_block in div_blocks:
# Find the comment text and timestamp
Expand Down
27 changes: 27 additions & 0 deletions tests/test_collect_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import subprocess


def test_collect_smoke_test():
    """Check that the `print_pulls` CLI starts and prints its help text."""
    import sys

    # Run the interpreter that is executing the tests rather than whatever
    # `python` resolves to on PATH, so the subprocess sees the same packages.
    cmd = [sys.executable, "-m", "swebench.collect.print_pulls", "--help"]
    # Decode the captured streams so a failing run shows readable output.
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    print(result.stdout)
    print(result.stderr)
    assert result.returncode == 0


def test_collect_one(tmp_path):
    """Run the `print_pulls` CLI on a real repository, limited to one pull.

    NOTE(review): this appears to query the GitHub API for
    `pvlib/pvlib-python`, so it presumably needs network access -- confirm.
    """
    import sys

    # Run the interpreter that is executing the tests rather than whatever
    # `python` resolves to on PATH, so the subprocess sees the same packages.
    cmd = [
        sys.executable,
        "-m",
        "swebench.collect.print_pulls",
        "pvlib/pvlib-python",
        str(tmp_path / "out.txt"),
        "--max-pulls",
        "1",
    ]
    print(" ".join(cmd))
    # Decode the captured streams so a failing run shows readable output.
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    print(result.stdout)
    print(result.stderr)
    assert result.returncode == 0


def test_collect_ds(tmp_path):
    """Run the `build_dataset` CLI on the bundled pvlib pull-request file.

    The input path is relative, so the test is expected to be run from the
    repository root.
    """
    import sys

    # Run the interpreter that is executing the tests rather than whatever
    # `python` resolves to on PATH, so the subprocess sees the same packages.
    cmd = [
        sys.executable,
        "-m",
        "swebench.collect.build_dataset",
        "tests/test_data/pvlib.jsonl",
        str(tmp_path / "out.jsonl"),
    ]
    print(" ".join(cmd))
    # Decode the captured streams so a failing run shows readable output.
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    print(result.stdout)
    print(result.stderr)
    assert result.returncode == 0
Loading

0 comments on commit a8d9a39

Please sign in to comment.