-
Notifications
You must be signed in to change notification settings - Fork 10
/
client.py
449 lines (337 loc) · 13.8 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
from __future__ import annotations

import asyncio
import copy
import json
import logging
import os
import pathlib
import sys
import traceback
import typing
import warnings

from lsprotocol import types
from lsprotocol.converters import get_converter
from packaging.version import parse as parse_version
from pygls.exceptions import JsonRpcException
from pygls.exceptions import PyglsError
from pygls.lsp.client import BaseLanguageClient
from pygls.protocol import default_converter

from .checks import LspSpecificationWarning
from .protocol import LanguageClientProtocol

if sys.version_info < (3, 9):
    import importlib_resources as resources
else:
    from importlib import resources  # type: ignore[no-redef]

if typing.TYPE_CHECKING:
    from typing import Any
    from typing import Dict
    from typing import List
    from typing import Optional
    from typing import Type
    from typing import Union
# Version of this client; ``LanguageClient.__init__`` passes it to
# ``BaseLanguageClient.__init__`` alongside the name "pytest-lsp-client".
__version__ = "0.4.2"
# Module logger. ``make_test_lsp_client`` also routes the server's
# ``window/logMessage`` notifications through it.
logger = logging.getLogger(__name__)
class LanguageClient(BaseLanguageClient):
    """Used to drive language servers under test."""

    protocol: LanguageClientProtocol

    def __init__(self, *args, configuration: Optional[Dict[str, Any]] = None, **kwargs):
        if "protocol_cls" not in kwargs:
            kwargs["protocol_cls"] = LanguageClientProtocol

        super().__init__("pytest-lsp-client", __version__, *args, **kwargs)

        self.capabilities: Optional[types.ClientCapabilities] = None
        """The client's capabilities."""

        self.shown_documents: List[types.ShowDocumentParams] = []
        """Holds any received show document requests."""

        self.messages: List[types.ShowMessageParams] = []
        """Holds any received ``window/showMessage`` requests."""

        self.log_messages: List[types.LogMessageParams] = []
        """Holds any received ``window/logMessage`` requests."""

        self.diagnostics: Dict[str, List[types.Diagnostic]] = {}
        """Holds any received diagnostics."""

        self.progress_reports: Dict[
            types.ProgressToken, List[types.ProgressParams]
        ] = {}
        """Holds any received progress updates."""

        self.error: Optional[Exception] = None
        """Indicates if the client encountered an error."""

        # Deep-copy the caller's configuration: ``set_configuration`` mutates
        # nested dictionaries in place, and those must not be the (possibly
        # shared) dictionaries the caller passed in.
        config = copy.deepcopy(configuration) if configuration else {"": {}}
        # The global scope ("") must always exist so that lookups in
        # ``get_configuration`` have at least one candidate.
        config.setdefault("", {})
        self._configuration: Dict[str, Dict[str, Any]] = config
        """Holds ``workspace/configuration`` values."""

        self._setup_log_index = 0
        """Used to keep track of which log messages occurred during startup."""

        self._last_log_index = 0
        """Used to keep track of which log messages correspond with which test case."""

        self._stderr_forwarder: Optional[asyncio.Task] = None
        """A task that forwards the server's stderr to the test process."""

    async def start_io(self, cmd: str, *args, **kwargs):
        """Start the server process via the base client, then begin forwarding
        the server's stderr (if it has one) to this process' stderr."""
        await super().start_io(cmd, *args, **kwargs)

        # Forward the server's stderr to this process' stderr
        if self._server and self._server.stderr:
            self._stderr_forwarder = asyncio.create_task(forward_stderr(self._server))

    async def stop(self):
        """Cancel the stderr forwarding task (if any), then stop the client."""
        if self._stderr_forwarder:
            self._stderr_forwarder.cancel()

        return await super().stop()

    async def server_exit(self, server: asyncio.subprocess.Process):
        """Called when the server process exits.

        Unless the client's stop event is already set, schedules cancellation of
        all tasks so that anything still waiting on the server fails with a
        message containing the server's return code.
        """
        logger.debug("Server process exited with code: %s", server.returncode)

        if self._stop_event.is_set():
            return

        loop = asyncio.get_running_loop()
        loop.call_soon(
            cancel_all_tasks,
            f"Server process exited with return code: {server.returncode}",
        )

    def report_server_error(
        self,
        error: Exception,
        source: Union[Type[PyglsError], Type[JsonRpcException]],
    ):
        """Called when the server does something unexpected, e.g. sending malformed
        JSON.

        Records ``error`` on :attr:`error`, schedules cancellation of all tasks
        with a message naming the error class, the error and the current
        traceback, and sets the client's stop event.

        Parameters
        ----------
        error
            The exception that was raised.

        source
            The exception *class* describing where the error came from; only its
            ``__name__`` is used.
        """
        self.error = error
        tb = traceback.format_exc()
        message = f"{source.__name__}: {error}\n{tb}"

        loop = asyncio.get_running_loop()
        loop.call_soon(cancel_all_tasks, message)

        if self._stop_event:
            self._stop_event.set()

    def get_configuration(
        self, *, section: Optional[str] = None, scope_uri: Optional[str] = None
    ) -> Optional[Any]:
        """Get a configuration value.

        Parameters
        ----------
        section
           The optional section name to retrieve.
           If ``None`` the top level configuration object for the requested scope will
           be returned

        scope_uri
           The scope at which to look up the configuration.
           If ``None``, this will default to the global scope.

        Returns
        -------
        Optional[Any]
           The requested configuration value or ``None`` if not found.
        """
        section = section or ""
        scope = scope_uri or ""

        # Find the longest prefix of ``scope``. The empty string is a prefix of all
        # strings so there will always be at least one match
        candidates = [c for c in self._configuration.keys() if scope.startswith(c)]
        selected = max(candidates, key=len)

        if (item := self._configuration.get(selected, None)) is None:
            return None

        if section == "":
            return item

        # Walk the dotted section path, bailing out as soon as a segment is
        # missing or the current value cannot be indexed further.
        for segment in section.split("."):
            if not hasattr(item, "get"):
                return None

            if (item := item.get(segment, None)) is None:
                return None

        return item

    def set_configuration(
        self,
        item: Any,
        *,
        section: Optional[str] = None,
        scope_uri: Optional[str] = None,
    ):
        """Set a configuration value.

        Parameters
        ----------
        item
           The value to set

        section
           The optional section name to set.
           If ``None`` the top level configuration object will be overridden with
           ``item``.

        scope_uri
           The scope at which to set the configuration.
           If ``None``, this will default to the global scope.
        """
        section = section or ""
        scope = scope_uri or ""

        if section == "":
            self._configuration[scope] = item
            return

        # Create any intermediate dictionaries along the dotted section path.
        config = self._configuration.setdefault(scope, {})
        *parents, name = section.split(".")

        for segment in parents:
            config = config.setdefault(segment, {})

        config[name] = item

    async def initialize_session(
        self, params: types.InitializeParams
    ) -> types.InitializeResult:
        """Make an ``initialize`` request to a language server.

        It will also automatically send an ``initialized`` notification once
        the server responds.

        Parameters
        ----------
        params
           The parameters to send to the server.

           The following fields will be automatically set if left blank.

           - ``process_id``: Set to the PID of the current process.

        Returns
        -------
        InitializeResult
           The result received from the server.
        """
        self.capabilities = params.capabilities

        if params.process_id is None:
            params.process_id = os.getpid()

        response = await self.initialize_async(params)
        self.initialized(types.InitializedParams())
        return response

    async def shutdown_session(self) -> None:
        """Shutdown the server under test.

        Helper method that handles sending ``shutdown`` and ``exit`` messages in the
        correct order.

        .. note::

           This method will not attempt to send these messages if a fatal error has
           occurred, or if the session was never initialized.
        """
        if self.error is not None or self.capabilities is None:
            return

        await self.shutdown_async(None)
        self.exit(None)

    async def wait_for_notification(self, method: str):
        """Block until a notification with the given method is received.

        Parameters
        ----------
        method
           The notification method to wait for, e.g. ``textDocument/publishDiagnostics``
        """
        return await self.protocol.wait_for_notification_async(method)
async def forward_stderr(server: asyncio.subprocess.Process):
    """Copy everything the server writes to its stderr onto this process' stderr.

    Runs until the server's stderr stream reaches EOF. Each line is flushed
    immediately so the server's output is not held back in a buffer or
    reordered relative to text written to ``sys.stderr``. If ``sys.stderr`` has
    been replaced by a text-only stream without a binary ``buffer`` (e.g.
    ``io.StringIO``), lines are decoded and written as text instead.

    Parameters
    ----------
    server
       The server process whose stderr should be forwarded. If it has no stderr
       pipe, this returns immediately.
    """
    if server.stderr is None:
        return

    # EOF is signalled with an empty bytestring
    while (line := await server.stderr.readline()) != b"":
        # Look the stream up for every line, since ``sys.stderr`` may be swapped
        # out while the server is running (e.g. by output capturing).
        stream = sys.stderr
        buffer = getattr(stream, "buffer", None)

        if buffer is not None:
            # Flush any pending text first so byte and text output stay ordered.
            stream.flush()
            buffer.write(line)
            buffer.flush()
        else:
            encoding = getattr(stream, "encoding", None) or "utf-8"
            stream.write(line.decode(encoding, errors="replace"))
            stream.flush()
def cancel_all_tasks(message: str):
    """Cancel every not-yet-finished task of the running event loop.

    Parameters
    ----------
    message
       The reason for the cancellation. On Python 3.9+ it is passed to
       ``Task.cancel``; older versions cannot carry a message, so tasks are
       cancelled without it.
    """
    pending = asyncio.all_tasks()

    if sys.version_info >= (3, 9):
        for task in pending:
            task.cancel(message)
    else:
        for task in pending:
            task.cancel()
def make_test_lsp_client() -> LanguageClient:
    """Construct a new test client instance with the handlers needed to capture
    additional responses from the server.

    The registered handlers:

    - answer ``workspace/configuration`` requests from the client's stored
      configuration,
    - record published diagnostics, progress reports, log messages, shown
      messages and show-document requests on the client, and
    - emit an ``LspSpecificationWarning`` for duplicate or unknown progress
      tokens.
    """
    client = LanguageClient(
        converter_factory=default_converter,
    )

    @client.feature(types.WORKSPACE_CONFIGURATION)
    def configuration(client: LanguageClient, params: types.ConfigurationParams):
        return [
            client.get_configuration(section=item.section, scope_uri=item.scope_uri)
            for item in params.items
        ]

    @client.feature(types.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS)
    def publish_diagnostics(
        client: LanguageClient, params: types.PublishDiagnosticsParams
    ):
        client.diagnostics[params.uri] = params.diagnostics

    @client.feature(types.WINDOW_WORK_DONE_PROGRESS_CREATE)
    def create_work_done_progress(
        client: LanguageClient, params: types.WorkDoneProgressCreateParams
    ):
        if params.token in client.progress_reports:
            # TODO: Send an error response to the server - might require changes
            # to pygls...
            warnings.warn(
                f"Duplicate progress token: {params.token!r}",
                LspSpecificationWarning,
                stacklevel=2,
            )

        client.progress_reports.setdefault(params.token, [])

    @client.feature(types.PROGRESS)
    def progress(client: LanguageClient, params: types.ProgressParams):
        if params.token not in client.progress_reports:
            warnings.warn(
                f"Unknown progress token: {params.token!r}",
                LspSpecificationWarning,
                stacklevel=2,
            )

        if not params.value:
            return

        # The raw value is unstructured; pick the concrete progress type from
        # its ``kind`` field before converting it.
        if (kind := params.value.get("kind", None)) == "begin":
            type_: Type[Any] = types.WorkDoneProgressBegin
        elif kind == "report":
            type_ = types.WorkDoneProgressReport
        elif kind == "end":
            type_ = types.WorkDoneProgressEnd
        else:
            raise TypeError(f"Unknown progress kind: {kind!r}")

        value = client.protocol._converter.structure(params.value, type_)
        client.progress_reports.setdefault(params.token, []).append(value)

    # Map the server's message types onto local log levels. ``Log`` and any
    # type beyond ``Info`` fall back to debug level rather than raising an
    # ``IndexError`` as positional list indexing would.
    log_levels = {
        types.MessageType.Error: logger.error,
        types.MessageType.Warning: logger.warning,
        types.MessageType.Info: logger.info,
    }

    @client.feature(types.WINDOW_LOG_MESSAGE)
    def log_message(client: LanguageClient, params: types.LogMessageParams):
        client.log_messages.append(params)
        log_levels.get(params.type, logger.debug)(params.message)

    @client.feature(types.WINDOW_SHOW_MESSAGE)
    def show_message(client: LanguageClient, params: types.ShowMessageParams):
        client.messages.append(params)

    @client.feature(types.WINDOW_SHOW_DOCUMENT)
    def show_document(
        client: LanguageClient, params: types.ShowDocumentParams
    ) -> types.ShowDocumentResult:
        client.shown_documents.append(params)
        return types.ShowDocumentResult(success=True)

    return client
def client_capabilities(client_spec: str) -> types.ClientCapabilities:
    """Find the capabilities that correspond to the given client spec.

    This function supports the following syntax

    ``client-name`` or ``client-name@latest``
       Return the capabilities of the latest version of ``client-name``

    ``client-name@v2``
       Return the latest release of the ``v2`` of ``client-name``

    ``client-name@v2.3.1``
       Return exactly ``v2.3.1`` of ``client-name``

    Version prefixes match whole version components only, so ``@v0.1`` selects
    ``0.1.x`` releases but not ``0.10.x``.

    Parameters
    ----------
    client_spec
       The string describing the client to load the corresponding
       capabilities for.

    Raises
    ------
    ValueError
       If the requested client's capabilities could not be found

    Returns
    -------
    ClientCapabilities
       The requested client capabilities
    """

    candidates: Dict[str, pathlib.Path] = {}

    # Capability files are named ``<client_name>_v<version>.json`` with
    # underscores in place of dashes in the client name.
    client_spec = client_spec.replace("-", "_")
    target_version = "latest"

    if "@" in client_spec:
        client_spec, target_version = client_spec.split("@")
        if target_version.startswith("v"):
            target_version = target_version[1:]

    # An empty version (``name@`` or ``name@v``) behaves like ``latest``.
    match_any = target_version in ("", "latest")

    for resource in resources.files("pytest_lsp.clients").iterdir():
        filename = typing.cast(pathlib.Path, resource)

        # Skip the README or any other files that we don't care about.
        if filename.suffix != ".json":
            continue

        # Split on the *last* ``_v`` so the version is always the final part;
        # skip JSON files that do not follow the naming scheme.
        name, separator, version = filename.stem.rpartition("_v")
        if not separator or name != client_spec:
            continue

        if (
            match_any
            or version == target_version
            or version.startswith(f"{target_version}.")
        ):
            candidates[version] = filename

    if len(candidates) == 0:
        raise ValueError(
            f"Could not find capabilities for '{client_spec}@{target_version}'"
        )

    # Out of the available candidates, choose the latest version
    selected_version = max(candidates.keys(), key=parse_version)
    filename = candidates[selected_version]

    converter = get_converter()
    capabilities = json.loads(filename.read_text(encoding="utf-8"))
    params = converter.structure(capabilities, types.InitializeParams)
    logger.info(
        "Selected %s v%s",
        params.client_info.name,  # type: ignore[union-attr]
        params.client_info.version,  # type: ignore[union-attr]
    )

    return params.capabilities