usse/scrape/venv/lib/python3.10/site-packages/pygls/protocol/json_rpc.py
2023-12-22 15:26:01 +01:00

558 lines
20 KiB
Python

############################################################################
# Copyright(c) Open Law Library. All rights reserved. #
# See ThirdPartyNotices.txt in the project root for additional notices. #
# #
# Licensed under the Apache License, Version 2.0 (the "License") #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http: // www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
from __future__ import annotations
import asyncio
import enum
import json
import logging
import re
import sys
import uuid
import traceback
from concurrent.futures import Future
from functools import partial
from typing import (
Any,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
if TYPE_CHECKING:
from pygls.server import LanguageServer, WebSocketTransportAdapter
import attrs
from cattrs.errors import ClassValidationError
from lsprotocol.types import (
CANCEL_REQUEST,
EXIT,
WORKSPACE_EXECUTE_COMMAND,
ResponseErrorMessage,
)
from pygls.exceptions import (
JsonRpcException,
JsonRpcInternalError,
JsonRpcInvalidParams,
JsonRpcMethodNotFound,
JsonRpcRequestCancelled,
FeatureNotificationError,
FeatureRequestError,
)
from pygls.feature_manager import FeatureManager, is_thread_function
logger = logging.getLogger(__name__)
@attrs.define
class JsonRPCNotification:
"""A class that represents a generic json rpc notification message.
Used as a fallback for unknown types.
"""
method: str
jsonrpc: str
params: Any
@attrs.define
class JsonRPCRequestMessage:
"""A class that represents a generic json rpc request message.
Used as a fallback for unknown types.
"""
id: Union[int, str]
method: str
jsonrpc: str
params: Any
@attrs.define
class JsonRPCResponseMessage:
"""A class that represents a generic json rpc response message.
Used as a fallback for unknown types.
"""
id: Union[int, str]
jsonrpc: str
result: Any
class JsonRPCProtocol(asyncio.Protocol):
"""Json RPC protocol implementation using on top of `asyncio.Protocol`.
Specification of the protocol can be found here:
https://www.jsonrpc.org/specification
This class provides bidirectional communication which is needed for LSP.
"""
CHARSET = "utf-8"
CONTENT_TYPE = "application/vscode-jsonrpc"
MESSAGE_PATTERN = re.compile(
rb"^(?:[^\r\n]+\r\n)*"
+ rb"Content-Length: (?P<length>\d+)\r\n"
+ rb"(?:[^\r\n]+\r\n)*\r\n"
+ rb"(?P<body>{.*)",
re.DOTALL,
)
VERSION = "2.0"
def __init__(self, server: LanguageServer, converter):
self._server = server
self._converter = converter
self._shutdown = False
# Book keeping for in-flight requests
self._request_futures: Dict[str, Future[Any]] = {}
self._result_types: Dict[str, Any] = {}
self.fm = FeatureManager(server, converter)
self.transport: Optional[
Union[asyncio.WriteTransport, WebSocketTransportAdapter]
] = None
self._message_buf: List[bytes] = []
self._send_only_body = False
def __call__(self):
return self
def _execute_notification(self, handler, *params):
"""Executes notification message handler."""
if asyncio.iscoroutinefunction(handler):
future = asyncio.ensure_future(handler(*params))
future.add_done_callback(self._execute_notification_callback)
else:
if is_thread_function(handler):
self._server.thread_pool.apply_async(handler, (*params,))
else:
handler(*params)
def _execute_notification_callback(self, future):
"""Success callback used for coroutine notification message."""
if future.exception():
try:
raise future.exception()
except Exception:
error = JsonRpcInternalError.of(sys.exc_info()).to_dict()
logger.exception('Exception occurred in notification: "%s"', error)
# Revisit. Client does not support response with msg_id = None
# https://stackoverflow.com/questions/31091376/json-rpc-2-0-allow-notifications-to-have-an-error-response
# self._send_response(None, error=error)
def _execute_request(self, msg_id, handler, params):
"""Executes request message handler."""
if asyncio.iscoroutinefunction(handler):
future = asyncio.ensure_future(handler(params))
self._request_futures[msg_id] = future
future.add_done_callback(partial(self._execute_request_callback, msg_id))
else:
# Can't be canceled
if is_thread_function(handler):
self._server.thread_pool.apply_async(
handler,
(params,),
callback=partial(
self._send_response,
msg_id,
),
error_callback=partial(self._execute_request_err_callback, msg_id),
)
else:
self._send_response(msg_id, handler(params))
def _execute_request_callback(self, msg_id, future):
"""Success callback used for coroutine request message."""
try:
if not future.cancelled():
self._send_response(msg_id, result=future.result())
else:
self._send_response(
msg_id,
error=JsonRpcRequestCancelled(
f'Request with id "{msg_id}" is canceled'
),
)
self._request_futures.pop(msg_id, None)
except Exception:
error = JsonRpcInternalError.of(sys.exc_info()).to_dict()
logger.exception('Exception occurred for message "%s": %s', msg_id, error)
self._send_response(msg_id, error=error)
def _execute_request_err_callback(self, msg_id, exc):
"""Error callback used for coroutine request message."""
exc_info = (type(exc), exc, None)
error = JsonRpcInternalError.of(exc_info).to_dict()
logger.exception('Exception occurred for message "%s": %s', msg_id, error)
self._send_response(msg_id, error=error)
def _get_handler(self, feature_name):
"""Returns builtin or used defined feature by name if exists."""
try:
return self.fm.builtin_features[feature_name]
except KeyError:
try:
return self.fm.features[feature_name]
except KeyError:
raise JsonRpcMethodNotFound.of(feature_name)
def _handle_cancel_notification(self, msg_id):
"""Handles a cancel notification from the client."""
future = self._request_futures.pop(msg_id, None)
if not future:
logger.warning('Cancel notification for unknown message id "%s"', msg_id)
return
# Will only work if the request hasn't started executing
if future.cancel():
logger.info('Cancelled request with id "%s"', msg_id)
def _handle_notification(self, method_name, params):
"""Handles a notification from the client."""
if method_name == CANCEL_REQUEST:
self._handle_cancel_notification(params.id)
return
try:
handler = self._get_handler(method_name)
self._execute_notification(handler, params)
except (KeyError, JsonRpcMethodNotFound):
logger.warning('Ignoring notification for unknown method "%s"', method_name)
except Exception as error:
logger.exception(
'Failed to handle notification "%s": %s',
method_name,
params,
exc_info=True,
)
self._server._report_server_error(error, FeatureNotificationError)
def _handle_request(self, msg_id, method_name, params):
"""Handles a request from the client."""
try:
handler = self._get_handler(method_name)
# workspace/executeCommand is a special case
if method_name == WORKSPACE_EXECUTE_COMMAND:
handler(params, msg_id)
else:
self._execute_request(msg_id, handler, params)
except JsonRpcException as error:
logger.exception(
"Failed to handle request %s %s %s",
msg_id,
method_name,
params,
exc_info=True,
)
self._send_response(msg_id, None, error.to_dict())
self._server._report_server_error(error, FeatureRequestError)
except Exception as error:
logger.exception(
"Failed to handle request %s %s %s",
msg_id,
method_name,
params,
exc_info=True,
)
err = JsonRpcInternalError.of(sys.exc_info()).to_dict()
self._send_response(msg_id, None, err)
self._server._report_server_error(error, FeatureRequestError)
def _handle_response(self, msg_id, result=None, error=None):
"""Handles a response from the client."""
future = self._request_futures.pop(msg_id, None)
if not future:
logger.warning('Received response to unknown message id "%s"', msg_id)
return
if error is not None:
logger.debug('Received error response to message "%s": %s', msg_id, error)
future.set_exception(JsonRpcException.from_error(error))
else:
logger.debug('Received result for message "%s": %s', msg_id, result)
future.set_result(result)
def _serialize_message(self, data):
"""Function used to serialize data sent to the client."""
if hasattr(data, "__attrs_attrs__"):
return self._converter.unstructure(data)
if isinstance(data, enum.Enum):
return data.value
return data.__dict__
def _deserialize_message(self, data):
"""Function used to deserialize data recevied from the client."""
if "jsonrpc" not in data:
return data
try:
if "id" in data:
if "error" in data:
return self._converter.structure(data, ResponseErrorMessage)
elif "method" in data:
request_type = (
self.get_message_type(data["method"]) or JsonRPCRequestMessage
)
return self._converter.structure(data, request_type)
else:
response_type = (
self._result_types.pop(data["id"]) or JsonRPCResponseMessage
)
return self._converter.structure(data, response_type)
else:
method = data.get("method", "")
notification_type = self.get_message_type(method) or JsonRPCNotification
return self._converter.structure(data, notification_type)
except ClassValidationError as exc:
logger.error("Unable to deserialize message\n%s", traceback.format_exc())
raise JsonRpcInvalidParams() from exc
except Exception as exc:
logger.error("Unable to deserialize message\n%s", traceback.format_exc())
raise JsonRpcInternalError() from exc
def _procedure_handler(self, message):
"""Delegates message to handlers depending on message type."""
if message.jsonrpc != JsonRPCProtocol.VERSION:
logger.warning('Unknown message "%s"', message)
return
if self._shutdown and getattr(message, "method", "") != EXIT:
logger.warning("Server shutting down. No more requests!")
return
if hasattr(message, "method"):
if hasattr(message, "id"):
logger.debug("Request message received.")
self._handle_request(message.id, message.method, message.params)
else:
logger.debug("Notification message received.")
self._handle_notification(message.method, message.params)
else:
if hasattr(message, "error"):
logger.debug("Error message received.")
self._handle_response(message.id, None, message.error)
else:
logger.debug("Response message received.")
self._handle_response(message.id, message.result)
def _send_data(self, data):
"""Sends data to the client."""
if not data:
return
if self.transport is None:
logger.error("Unable to send data, no available transport!")
return
try:
body = json.dumps(data, default=self._serialize_message)
logger.info("Sending data: %s", body)
if self._send_only_body:
# Mypy/Pyright seem to think `write()` wants `"bytes | bytearray | memoryview"`
# But runtime errors with anything but `str`.
self.transport.write(body) # type: ignore
return
header = (
f"Content-Length: {len(body)}\r\n"
f"Content-Type: {self.CONTENT_TYPE}; charset={self.CHARSET}\r\n\r\n"
).encode(self.CHARSET)
self.transport.write(header + body.encode(self.CHARSET))
except Exception as error:
logger.exception("Error sending data", exc_info=True)
self._server._report_server_error(error, JsonRpcInternalError)
def _send_response(self, msg_id, result=None, error=None):
"""Sends a JSON RPC response to the client.
Args:
msg_id(str): Id from request
result(any): Result returned by handler
error(any): Error returned by handler
"""
if error is not None:
response = ResponseErrorMessage(id=msg_id, error=error)
else:
response_type = self._result_types.pop(msg_id, JsonRPCResponseMessage)
response = response_type(
id=msg_id, result=result, jsonrpc=JsonRPCProtocol.VERSION
)
self._send_data(response)
def connection_lost(self, exc):
"""Method from base class, called when connection is lost, in which case we
want to shutdown the server's process as well.
"""
logger.error("Connection to the client is lost! Shutting down the server.")
sys.exit(1)
def connection_made( # type: ignore # see: https://github.com/python/typeshed/issues/3021
self,
transport: asyncio.Transport,
):
"""Method from base class, called when connection is established"""
self.transport = transport
def data_received(self, data: bytes):
try:
self._data_received(data)
except Exception as error:
logger.exception("Error receiving data", exc_info=True)
self._server._report_server_error(error, JsonRpcInternalError)
def _data_received(self, data: bytes):
"""Method from base class, called when server receives the data"""
logger.debug("Received %r", data)
while len(data):
# Append the incoming chunk to the message buffer
self._message_buf.append(data)
# Look for the body of the message
message = b"".join(self._message_buf)
found = JsonRPCProtocol.MESSAGE_PATTERN.fullmatch(message)
body = found.group("body") if found else b""
length = int(found.group("length")) if found else 1
if len(body) < length:
# Message is incomplete; bail until more data arrives
return
# Message is complete;
# extract the body and any remaining data,
# and reset the buffer for the next message
body, data = body[:length], body[length:]
self._message_buf = []
# Parse the body
self._procedure_handler(
json.loads(
body.decode(self.CHARSET), object_hook=self._deserialize_message
)
)
def get_message_type(self, method: str) -> Optional[Type]:
"""Return the type definition of the message associated with the given method."""
return None
def get_result_type(self, method: str) -> Optional[Type]:
"""Return the type definition of the result associated with the given method."""
return None
def notify(self, method: str, params=None):
"""Sends a JSON RPC notification to the client."""
logger.debug("Sending notification: '%s' %s", method, params)
notification_type = self.get_message_type(method) or JsonRPCNotification
notification = notification_type(
method=method, params=params, jsonrpc=JsonRPCProtocol.VERSION
)
self._send_data(notification)
def send_request(self, method, params=None, callback=None, msg_id=None):
"""Sends a JSON RPC request to the client.
Args:
method(str): The method name of the message to send
params(any): The payload of the message
Returns:
Future that will be resolved once a response has been received
"""
if msg_id is None:
msg_id = str(uuid.uuid4())
request_type = self.get_message_type(method) or JsonRPCRequestMessage
logger.debug('Sending request with id "%s": %s %s', msg_id, method, params)
request = request_type(
id=msg_id,
method=method,
params=params,
jsonrpc=JsonRPCProtocol.VERSION,
)
future = Future() # type: ignore[var-annotated]
# If callback function is given, call it when result is received
if callback:
def wrapper(future: Future):
result = future.result()
logger.info("Client response for %s received: %s", params, result)
callback(result)
future.add_done_callback(wrapper)
self._request_futures[msg_id] = future
self._result_types[msg_id] = self.get_result_type(method)
self._send_data(request)
return future
def send_request_async(self, method, params=None, msg_id=None):
"""Calls `send_request` and wraps `concurrent.futures.Future` with
`asyncio.Future` so it can be used with `await` keyword.
Args:
method(str): The method name of the message to send
params(any): The payload of the message
msg_id(str|int): Optional, message id
Returns:
`asyncio.Future` that can be awaited
"""
return asyncio.wrap_future(
self.send_request(method, params=params, msg_id=msg_id)
)
def thread(self):
"""Decorator that mark function to execute it in a thread."""
return self.fm.thread()