195 lines
5.6 KiB
Python
195 lines
5.6 KiB
Python
"""Spell checking."""
|
|
import re
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
|
|
import pygls.uris as Uri
|
|
from docutils import nodes
|
|
from lsprotocol.types import CodeAction
|
|
from lsprotocol.types import CodeActionKind
|
|
from lsprotocol.types import CodeActionParams
|
|
from lsprotocol.types import Diagnostic
|
|
from lsprotocol.types import DiagnosticSeverity
|
|
from lsprotocol.types import DidSaveTextDocumentParams
|
|
from lsprotocol.types import Position
|
|
from lsprotocol.types import Range
|
|
from lsprotocol.types import TextEdit
|
|
from lsprotocol.types import WorkspaceEdit
|
|
from spellchecker import SpellChecker # type: ignore
|
|
|
|
from esbonio.lsp.rst import LanguageFeature
|
|
from esbonio.lsp.sphinx import SphinxLanguageServer
|
|
|
|
IGNORED_NODES = {nodes.raw, nodes.literal, nodes.literal_block}
|
|
"""Don't spell check Text contained in any of these nodes."""
|
|
|
|
|
|
class Spelling(LanguageFeature):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
|
|
self.lang = SpellChecker()
|
|
self.errors: Dict[str, List[MisSpelling]] = {}
|
|
|
|
def code_action(self, params: CodeActionParams) -> List[CodeAction]:
|
|
uri = params.text_document.uri
|
|
ranges = {d.range for d in params.context.diagnostics}
|
|
|
|
if uri in self.errors:
|
|
errors = self.errors[uri]
|
|
else:
|
|
errors = self.find_errors_for_uri(uri)
|
|
|
|
diagnostics = errors_to_diagnostics(errors)
|
|
actions = []
|
|
|
|
for error, diagnostic in zip(errors, diagnostics):
|
|
if len(ranges) > 0 and diagnostic.range not in ranges:
|
|
continue
|
|
|
|
fixes = self.lang.candidates(error.text)
|
|
if fixes is None:
|
|
continue
|
|
|
|
for fix in fixes:
|
|
actions.append(
|
|
CodeAction(
|
|
title=f"Correct '{error.text}' -> '{fix}'",
|
|
kind=CodeActionKind.QuickFix,
|
|
diagnostics=[diagnostic],
|
|
edit=WorkspaceEdit(
|
|
changes={
|
|
uri: [TextEdit(range=diagnostic.range, new_text=fix)]
|
|
}
|
|
),
|
|
)
|
|
)
|
|
|
|
return actions
|
|
|
|
def save(self, params: DidSaveTextDocumentParams):
|
|
self.find_errors_for_uri(params.text_document.uri)
|
|
|
|
def find_errors_for_uri(self, uri: str) -> List["MisSpelling"]:
|
|
"""Find any mis-spellings in the given document."""
|
|
|
|
doctree = self.rst.get_doctree(uri=uri)
|
|
if doctree is None:
|
|
return []
|
|
|
|
errors = []
|
|
|
|
for text in doctree.traverse(condition=nodes.Text):
|
|
parent = text.parent
|
|
# Don't spell check code block, raw blocks etc.
|
|
|
|
if type(parent) in IGNORED_NODES:
|
|
continue
|
|
|
|
# Don't spell check text we cannot tie back to the user's actual source file
|
|
self.logger.debug("%s: %s", type(parent), parent.source)
|
|
if parent.source != Uri.to_fs_path(uri):
|
|
continue
|
|
|
|
for word in find_words(
|
|
text.astext(), startline=parent.line, source=parent.source
|
|
):
|
|
# Ignore short words or any "word" that contains digits or other
|
|
# punctuation.
|
|
if len(word) <= 1 or re.search("[-/\\_\\d.=']", str(word)):
|
|
continue
|
|
|
|
if self.lang.unknown([str(word)]):
|
|
errors.append(word)
|
|
|
|
self.errors[uri] = errors
|
|
self.rst.set_diagnostics("spellcheck[en]", uri, errors_to_diagnostics(errors))
|
|
self.rst.sync_diagnostics()
|
|
|
|
return errors
|
|
|
|
|
|
class MisSpelling:
|
|
"""Represents an incorrectly spelled word."""
|
|
|
|
def __init__(self, line: int, character: int, text: str, source: Optional[str]):
|
|
self.line = line
|
|
self.character = character
|
|
self.text = text
|
|
self.source = source
|
|
|
|
def __str__(self):
|
|
return self.text
|
|
|
|
def __contains__(self, item):
|
|
return item in self.text
|
|
|
|
def __len__(self):
|
|
return len(self.text)
|
|
|
|
def __repr__(self):
|
|
return f"MisSpelling<{self.line}:{self.character}, {self.text}>"
|
|
|
|
|
|
def errors_to_diagnostics(errors: List[MisSpelling]) -> List[Diagnostic]:
|
|
diagnostics = []
|
|
|
|
for error in errors:
|
|
range_ = Range(
|
|
start=Position(line=error.line - 1, character=error.character),
|
|
end=Position(line=error.line - 1, character=error.character + len(error)),
|
|
)
|
|
|
|
diagnostics.append(
|
|
Diagnostic(
|
|
range=range_,
|
|
message=f"Incorrect spelling: '{error}'",
|
|
severity=DiagnosticSeverity.Warning,
|
|
source="spellcheck[en]",
|
|
)
|
|
)
|
|
|
|
return diagnostics
|
|
|
|
|
|
def find_words(
|
|
text: str, startline: int = 0, source: Optional[str] = None
|
|
) -> List[MisSpelling]:
|
|
words = []
|
|
delimiters = " \n"
|
|
skip_characters = ",'\"()[]"
|
|
current_word = None
|
|
|
|
line = startline
|
|
col = -1
|
|
|
|
for c in text:
|
|
col += 1
|
|
|
|
if c in skip_characters:
|
|
continue
|
|
|
|
if c not in delimiters:
|
|
if current_word is None:
|
|
current_word = MisSpelling(line, col, c, source)
|
|
else:
|
|
current_word.text += c
|
|
else:
|
|
if current_word is not None:
|
|
words.append(current_word)
|
|
current_word = None
|
|
|
|
if c == "\n":
|
|
line += 1
|
|
col = -1
|
|
|
|
if current_word is not None:
|
|
words.append(current_word)
|
|
|
|
return words
|
|
|
|
|
|
def esbonio_setup(rst: SphinxLanguageServer):
|
|
rst.add_feature(Spelling(rst))
|