diff --git a/.github/workflows/pytest-gpu.yaml b/.github/workflows/pytest-gpu.yaml index 3722bf91ff..bbc8a3afe2 100644 --- a/.github/workflows/pytest-gpu.yaml +++ b/.github/workflows/pytest-gpu.yaml @@ -58,5 +58,19 @@ jobs: id: tests run: | set -ex - export PR_NUMBER=$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH") - python .github/mcp/mcp_pytest.py --image '${{ inputs.container }}' --pr_number $PR_NUMBER --pytest_markers '${{ inputs.pytest-markers }}' --pytest_command '${{ inputs.pytest-command }}' --timeout ${{ inputs.mcloud-timeout }} + + PR_NUMBER="$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH")" + REF_ARGS="" + + # Use the PR number if it exists, commit SHA for protected branches and the branch name otherwise + if [ -z "$PR_NUMBER" ]; then + if [[ "$GITHUB_REF" =~ "refs/heads/dev" || "$GITHUB_REF" =~ "refs/heads/main" || "$GITHUB_REF" =~ "refs/heads/release" ]]; then + REF_ARGS="--git_commit $GITHUB_SHA" + else + REF_ARGS="--git_branch $GITHUB_REF_NAME" + fi + else + REF_ARGS="--pr_number $PR_NUMBER" + fi + + python .github/mcp/mcp_pytest.py --image '${{ inputs.container }}' --pytest_markers '${{ inputs.pytest-markers }}' --pytest_command '${{ inputs.pytest-command }}' --timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}