-
Notifications
You must be signed in to change notification settings - Fork 0
/
blackd.py
194 lines (167 loc) · 6.52 KB
/
blackd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import asyncio
from concurrent.futures import Executor, ProcessPoolExecutor
from datetime import datetime
from functools import partial
import logging
from multiprocessing import freeze_support
from typing import Set, Tuple
from aiohttp import web
import aiohttp_cors
import black
import click
from _black_version import version as __version__
# This is used internally by tests to shut down the server prematurely
# NOTE(review): this Event is created at import time. On Python versions before
# 3.10, asyncio.Event binds to the event loop that is current when it is
# constructed, so it may not be usable from a different loop — confirm how the
# tests that use it obtain their loop.
_stop_signal = asyncio.Event()
# Request headers
# Protocol version requested by the client; `handle` only accepts "1" (default).
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
# Integer line length forwarded to black's FileMode.
LINE_LENGTH_HEADER = "X-Line-Length"
# Either "pyi" or a comma-separated list of target Python versions.
PYTHON_VARIANT_HEADER = "X-Python-Variant"
# Any non-empty value disables string normalization.
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
# "fast" requests fast mode; anything else (default "safe") is treated as safe.
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
# Any non-empty value makes the response a diff instead of the formatted source.
DIFF_HEADER = "X-Diff"
# Every request header blackd understands; `make_app` lists them as allowed
# headers for cross-origin (CORS) requests.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]
# Response headers
# Reports the black version that handled the request.
BLACK_VERSION_HEADER = "X-Black-Version"
class InvalidVariantHeader(Exception):
    """Signal that the X-Python-Variant request header could not be parsed.

    The first positional argument is a short, human-readable reason; ``handle``
    embeds it in the body of the 400 response it returns.
    """
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # CLI entry for the blackd server: configure INFO-level logging, build the
    # aiohttp application, announce the listening address, then serve until
    # a signal stops the server.
    # (Intentionally no docstring: click would render it as --help text.)
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    # print=None suppresses aiohttp's own startup banner; the line above
    # already reports where the server is listening.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
def make_app() -> web.Application:
    """Build the aiohttp application that serves blackd.

    A single ``POST /`` route is registered and wrapped with CORS support, so
    browser clients may send the blackd request headers (``BLACK_HEADERS``)
    plus ``Content-Type``.

    Formatting runs on a ``ProcessPoolExecutor`` owned by this application.
    The executor is shut down when the application is cleaned up, so each
    application instance releases its worker processes instead of leaking
    them.

    Returns:
        The configured ``web.Application``.
    """
    app = web.Application()
    # black's formatting is CPU-bound, so it runs in worker processes rather
    # than on the event-loop thread.
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # Release the worker processes once the application stops serving.
        executor.shutdown(wait=True)

    app.on_cleanup.append(_shutdown_executor)

    cors = aiohttp_cors.setup(app)
    resource = cors.add(app.router.add_resource("/"))
    cors.add(
        resource.add_route("POST", partial(handle, executor=executor)),
        {
            "*": aiohttp_cors.ResourceOptions(
                # Any origin may send the blackd headers; all response
                # headers (including X-Black-Version) are exposed to it.
                allow_headers=(*BLACK_HEADERS, "Content-Type"),
                expose_headers="*",
            )
        },
    )
    return app
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the request body with black.

    Optional request headers configure the run:

    * ``X-Protocol-Version``: must be ``"1"`` (the default); otherwise 501.
    * ``X-Line-Length``: integer line length (default
      ``black.DEFAULT_LINE_LENGTH``); a non-integer value yields 400.
    * ``X-Python-Variant``: ``"pyi"`` or a comma-separated version list, parsed
      by ``parse_python_variant_header``; an invalid value yields 400.
    * ``X-Skip-String-Normalization``: any non-empty value disables string
      normalization.
    * ``X-Fast-Or-Safe``: ``"fast"`` passes ``fast=True`` to
      ``black.format_file_contents``; anything else passes ``fast=False``.
    * ``X-Diff``: any non-empty value returns a diff instead of the
      formatted source.

    Args:
        request: The incoming aiohttp request. Its body is the source to format.
        executor: Pool on which the formatting and diffing run.

    Returns:
        * 200 with the formatted source (or diff). The request's content type
          is echoed, and the charset is the request's charset or utf8.
        * 204 if black made no changes.
        * 400 for a malformed header, a body that cannot be decoded with the
          declared charset, or source that black rejects.
        * 500 for any other unexpected error (logged with a traceback).

        Responses produced after header validation carry the
        ``X-Black-Version`` header.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(status=400, text="Invalid line length header value")

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Any non-empty header value enables the flag (truthiness of the str).
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
        )

        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset (LookupError) or bytes that are invalid in
            # the declared charset is a client error, not a server failure.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Unable to decode request body: {e}",
            )

        then = datetime.utcnow()
        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            # utcnow() returns naive UTC times; "+0000" labels them as UTC in
            # the diff's file-name lines.
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the X-Python-Variant request header.

    ``value`` is either ``"pyi"`` (format as a stub file, no target versions)
    or a comma-separated list of versions. Each entry may carry a ``py``
    prefix, and surrounding whitespace is ignored. Accepted forms include
    ``3.7``, ``py3.7``, ``37``, ``py37``, ``3`` and ``2.7``. A bare major
    version defaults to the lowest minor accepted here: 2.7 for Python 2,
    3.3 for Python 3.

    Args:
        value: The raw header value.

    Returns:
        ``(is_pyi, target_versions)``.

    Raises:
        InvalidVariantHeader: If any entry is empty, malformed, or names an
            unsupported version. ``args[0]`` holds the reason.
    """
    if value == "pyi":
        return True, set()

    versions: Set[black.TargetVersion] = set()
    for version in value.split(","):
        # Tolerate whitespace around entries, e.g. "py3.6, py3.7".
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        try:
            if "." in version:
                major_str, *rest = version.split(".")
            else:
                # An empty entry ("", "py", or a trailing comma) raises
                # IndexError here; it is reported as an invalid header below.
                major_str = version[0]
                rest = [version[1:]] if len(version) > 1 else []
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader("minor version must be 7 for Python 2")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError, IndexError) as e:
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from e
    return False, versions
def patched_main() -> None:
    """Entry point that prepares the process, then runs the ``main`` command.

    ``freeze_support()`` must run first. It is multiprocessing's hook for
    frozen executables, and this server spawns worker processes through a
    ProcessPoolExecutor. ``black.patch_click()`` is applied before the click
    command is invoked.
    NOTE(review): patch_click's exact adjustments live in black — confirm there.
    """
    freeze_support()
    black.patch_click()
    main()


if __name__ == "__main__":
    patched_main()