forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nightly_hotpatch.py
218 lines (182 loc) · 7.53 KB
/
nightly_hotpatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#!/usr/bin/env python3
import argparse
import os
import shutil
import subprocess
import sys
import tempfile
import urllib.error
import urllib.request
from typing import cast, List, NoReturn, Optional
def parse_arguments() -> argparse.Namespace:
    """
    Define the command-line interface of this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``PR_NUMBER``
        (int), ``directory`` (str or None, default None) and ``strip``
        (int, default 1).
    """
    about_text = (
        "Download and apply a Pull Request (PR) patch from the PyTorch GitHub repository "
        "to your local PyTorch installation.\n\n"
        "Best Practice: Since this script involves hot-patching PyTorch, it's recommended to use "
        "a disposable environment like a Docker container or a dedicated Python virtual environment (venv). "
        "This ensures that if the patching fails, you can easily recover by resetting the environment."
    )
    example_text = (
        "Example:\n"
        " python nightly_hotpatch.py 12345\n"
        " python nightly_hotpatch.py 12345 --directory /path/to/pytorch --strip 1\n\n"
        "These commands will download the patch for PR #12345 and apply it to your local "
        "PyTorch installation."
    )
    cli = argparse.ArgumentParser(
        description=about_text,
        epilog=example_text,
        # Keep the hand-written line breaks in the description/epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (flag names, add_argument keyword options), registered in this order.
    option_table = (
        (
            ("PR_NUMBER",),
            {
                "type": int,
                "help": "The number of the Pull Request (PR) from the PyTorch GitHub repository to download and apply as a patch.",
            },
        ),
        (
            ("--directory", "-d"),
            {
                "type": str,
                "default": None,
                "help": "Optional. Specify the target directory to apply the patch. "
                "If not provided, the script will use the PyTorch installation path.",
            },
        ),
        (
            ("--strip", "-p"),
            {
                "type": int,
                "default": 1,
                "help": "Optional. Specify the strip count to remove leading directories from file paths in the patch. Default is 1.",
            },
        ),
    )
    for names, settings in option_table:
        cli.add_argument(*names, **settings)

    return cli.parse_args()
def get_pytorch_path() -> str:
    """
    Find where the ``torch`` package is installed and return its parent directory.

    Both the package directory and its parent are printed for the user.

    Returns:
        str: The directory that contains the ``torch`` package directory.

    Exits:
        Via ``handle_import_error()`` if ``torch`` cannot be imported.
    """
    try:
        import torch
    except ImportError:
        handle_import_error()

    # torch.__path__ is a package path list; its first entry is the package dir.
    package_dir: str = cast(List[str], torch.__path__)[0]
    containing_dir: str = os.path.dirname(package_dir)
    print(f"PyTorch is installed at: {package_dir}")
    print(f"Parent directory for patching: {containing_dir}")
    return containing_dir
def handle_import_error() -> NoReturn:
    """
    Tell the user that PyTorch is missing and stop the program.

    Exits:
        NoReturn: Always terminates with exit status 1.
    """
    notice = "Error: PyTorch is not installed in the current Python environment."
    print(notice)
    raise SystemExit(1)
def download_patch(
    pr_number: int, repo_url: str, download_dir: str, timeout: float = 60.0
) -> str:
    """
    Download the diff of a pull request and save it as a patch file.

    The diff is fetched from ``{repo_url}/pull/{pr_number}.diff`` and written
    to ``{download_dir}/pr-{pr_number}.patch``.

    Args:
        pr_number (int): The pull request number.
        repo_url (str): Base URL of the repository hosting the PR, without a
            trailing slash (the URL is built by appending ``/pull/...``).
        download_dir (str): Existing directory in which to store the patch.
        timeout (float): Seconds to wait on each blocking network operation
            before giving up. Defaults to 60.

    Returns:
        str: The path to the downloaded patch file.

    Exits:
        With status 1 if the request fails (HTTP or network error), the file
        cannot be written, or the downloaded patch is empty.
    """
    patch_url = f"{repo_url}/pull/{pr_number}.diff"
    patch_file = os.path.join(download_dir, f"pr-{pr_number}.patch")
    print(f"Downloading PR #{pr_number} patch from {patch_url}...")
    try:
        # Without a timeout a stalled connection would hang the script forever.
        with urllib.request.urlopen(patch_url, timeout=timeout) as response, open(
            patch_file, "wb"
        ) as out_file:
            shutil.copyfileobj(response, out_file)
    except urllib.error.HTTPError as e:
        # Must precede URLError: HTTPError is a subclass of URLError.
        print(f"HTTP Error: {e.code} when downloading patch for PR #{pr_number}")
        sys.exit(1)
    except urllib.error.URLError as e:
        print(
            f"Network error while downloading patch for PR #{pr_number}: {e.reason}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"An error occurred while downloading the patch: {e}")
        sys.exit(1)

    # The file always exists once the write above succeeds; an empty download
    # is the failure mode worth catching, since `patch` cannot apply it.
    if os.path.getsize(patch_file) == 0:
        print(f"Failed to download patch for PR #{pr_number}: the patch is empty")
        sys.exit(1)
    print(f"Patch downloaded to {patch_file}")
    return patch_file
def apply_patch(patch_file: str, target_dir: Optional[str], strip_count: int) -> None:
    """
    Apply a patch file to a directory tree with the external ``patch`` utility.

    Args:
        patch_file (str): Path to the patch file. A relative path is resolved
            against the current working directory before ``patch`` runs.
        target_dir (Optional[str]): Directory to apply the patch in. If None
            (or empty), the PyTorch installation path returned by
            ``get_pytorch_path()`` is used.
        strip_count (int): Number of leading path components to strip from
            file names in the patch (passed to ``patch`` as ``-p<N>``).

    Exits:
        With status 1 if the ``patch`` utility is not found, ``patch`` returns
        a non-zero status, or any other error occurs while launching it.
    """
    if target_dir:
        print(f"Applying patch in directory: {target_dir}")
    else:
        print("No target directory specified. Using PyTorch installation path.")
        # Resolve the documented default instead of passing a bogus "-dNone".
        target_dir = get_pytorch_path()
    print(f"Applying patch with strip count: {strip_count}")

    # `patch -d` changes directory before it opens the -i file, so a relative
    # patch path would be looked up inside target_dir; make it absolute.
    patch_command = [
        "patch",
        "-d",
        target_dir,
        f"-p{strip_count}",
        "-i",
        os.path.abspath(patch_file),
    ]
    print(f"Running command: {' '.join(patch_command)}")
    try:
        result = subprocess.run(patch_command, capture_output=True, text=True)
    except FileNotFoundError:
        print("Error: The 'patch' utility is not installed or not found in PATH.")
        sys.exit(1)
    except Exception as e:
        print(f"An error occurred while applying the patch: {e}")
        sys.exit(1)

    if result.returncode != 0:
        print("Failed to apply patch.")
        print("Patch output:")
        print(result.stdout)
        print(result.stderr)
        sys.exit(1)
    print("Patch applied successfully.")
def main() -> None:
    """
    Entry point: parse the CLI, pick the patch target, then download and apply the PR diff.

    Steps:
        1. Read the PR number, the optional ``--directory`` and ``--strip`` options.
        2. If a directory was given, require it to exist (exit 1 otherwise);
           if not, use the parent directory of the installed ``torch`` package.
        3. Download the PR diff into a temporary directory that is removed
           when the ``with`` block ends.
        4. Apply the diff to the chosen directory with the requested strip count.
    """
    options = parse_arguments()
    requested_dir = options.directory

    if not requested_dir:
        target_dir = get_pytorch_path()
    elif os.path.isdir(requested_dir):
        target_dir = requested_dir
        print(f"Using custom target directory: {target_dir}")
    else:
        print(
            f"Error: The specified target directory '{requested_dir}' does not exist."
        )
        sys.exit(1)

    repo_url = "https://github.com/pytorch/pytorch"
    with tempfile.TemporaryDirectory() as scratch_dir:
        downloaded = download_patch(options.PR_NUMBER, repo_url, scratch_dir)
        apply_patch(downloaded, target_dir, options.strip)


if __name__ == "__main__":
    main()