Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions plugin/code_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .core.views import entire_content_region
from .core.views import first_selection_region
from .core.views import format_code_actions_for_quick_panel
from .core.views import kind_contains_other_kind
from .core.views import text_document_code_action_params
from .save_command import LspSaveCommand
from .save_command import SaveTask
Expand Down Expand Up @@ -92,18 +93,23 @@ def response_filter(sb: SessionBufferProtocol, actions: list[CodeActionOrCommand
if manual and only_kinds == MENU_ACTIONS_KINDS:
for action in code_actions:
kind = action.get('kind')
if kinds_include_kind([CodeActionKind.Refactor], kind):
if kind and kind_contains_other_kind(CodeActionKind.Refactor, kind):
self.refactor_actions_cache.append((sb.session.config.name, action))
elif kinds_include_kind([CodeActionKind.Source], kind):
elif kind and kind_contains_other_kind(CodeActionKind.Source, kind):
self.source_actions_cache.append((sb.session.config.name, action))
return [action for action in code_actions if kinds_include_kind(only_kinds, action.get('kind'))]
return [
action for action in code_actions
if any(kind_contains_other_kind(only_kind, action.get('kind', '')) for only_kind in only_kinds)
]
if manual:
return [a for a in actions if not a.get('disabled')]
# On implicit (selection change) request, only return commands and quick fix kinds.
return [
a for a in actions
if is_command(a) or not a.get('disabled') and
kinds_include_kind([CodeActionKind.QuickFix], a.get('kind', CodeActionKind.QuickFix))
if is_command(a) or (
not a.get('disabled') and
kind_contains_other_kind(CodeActionKind.QuickFix, a.get('kind', CodeActionKind.QuickFix))
)
]

task = self._collect_code_actions_async(listener, request_factory, response_filter)
Expand Down Expand Up @@ -208,22 +214,6 @@ def get_matching_on_save_kinds(
return matching_kinds


def kinds_include_kind(kinds: list[CodeActionKind], kind: CodeActionKind | None) -> bool:
"""
The "kinds" include "kind" if "kind" matches one of the "kinds" exactly or one of the "kinds" is a prefix
of the whole "kind" (where prefix must be followed by a dot).
"""
if not kind:
return False
for kinds_item in kinds:
if kinds_item == kind:
return True
kinds_item_len = len(kinds_item)
if len(kind) > kinds_item_len and kind.startswith(kinds_item) and kind[kinds_item_len] == '.':
return True
return False


class CodeActionOnSaveTask(SaveTask):
"""
The main task that requests code actions from sessions and runs them.
Expand Down
27 changes: 20 additions & 7 deletions plugin/core/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@
from .protocol import ProgressParams
from .protocol import ProgressToken
from .protocol import PublishDiagnosticsParams
from .protocol import RegistrationParams
from .protocol import Range
from .protocol import RegistrationParams
from .protocol import Request
from .protocol import Response
from .protocol import ResponseError
Expand All @@ -82,8 +82,8 @@
from .protocol import WorkspaceDiagnosticParams
from .protocol import WorkspaceDiagnosticReport
from .protocol import WorkspaceDocumentDiagnosticReport
from .protocol import WorkspaceFullDocumentDiagnosticReport
from .protocol import WorkspaceEdit
from .protocol import WorkspaceFullDocumentDiagnosticReport
from .settings import client_configs
from .settings import globalprefs
from .settings import userprefs
Expand All @@ -106,6 +106,7 @@
from .version import __version__
from .views import extract_variables
from .views import get_uri_and_range_from_location
from .views import kind_contains_other_kind
from .views import MarkdownLangMap
from .workspace import is_subpath_of
from .workspace import WorkspaceFolder
Expand Down Expand Up @@ -1317,6 +1318,7 @@ def __init__(self, manager: Manager, logger: Logger, workspace_folders: list[Wor
self._plugin: AbstractPlugin | None = None
self._status_messages: dict[str, str] = {}
self._semantic_tokens_map = get_semantic_tokens_map(config.semantic_tokens)
self._is_executing_refactoring_command = False

def __getattr__(self, name: str) -> Any:
"""
Expand Down Expand Up @@ -1650,7 +1652,8 @@ def _template_variables(self) -> dict[str, str]:
return variables

def execute_command(
self, command: ExecuteCommandParams, progress: bool, view: sublime.View | None = None
self, command: ExecuteCommandParams, *, progress: bool = False, view: sublime.View | None = None,
is_refactoring: bool = False,
) -> Promise:
"""Run a command from any thread. Your .then() continuations will run in Sublime's worker thread."""
if self._plugin:
Expand Down Expand Up @@ -1678,13 +1681,20 @@ def run_async() -> None:
sublime.set_timeout_async(run_async)
return Promise.resolve(None)
# TODO: Our Promise class should be able to handle errors/exceptions
return Promise(
execute_command = Promise(
lambda resolve: self.send_request(
Request("workspace/executeCommand", command, None, progress),
resolve,
lambda err: resolve(Error(err["code"], err["message"], err.get("data")))
)
)
if is_refactoring:
self._is_executing_refactoring_command = True
execute_command.then(lambda _: self._reset_is_executing_refactoring_command())
return execute_command

def _reset_is_executing_refactoring_command(self) -> None:
self._is_executing_refactoring_command = False

def run_code_action_async(
self, code_action: Command | CodeAction, progress: bool, view: sublime.View | None = None
Expand All @@ -1697,7 +1707,8 @@ def run_code_action_async(
arguments = code_action.get('arguments', None)
if isinstance(arguments, list):
command_params['arguments'] = arguments
return self.execute_command(command_params, progress, view)
is_refactoring = kind_contains_other_kind(CodeActionKind.Refactor, code_action.get('kind', ''))
return self.execute_command(command_params, progress=progress, view=view, is_refactoring=is_refactoring)
# At this point it cannot be a command anymore, it has to be a proper code action.
# A code action can have an edit and/or command. Note that it can have *both*. In case both are present, we
# must apply the edits before running the command.
Expand Down Expand Up @@ -1826,7 +1837,7 @@ def _apply_code_action_async(
self.window.status_message(f"Failed to apply code action: {code_action}")
return Promise.resolve(None)
edit = code_action.get("edit")
is_refactoring = code_action.get('kind') == CodeActionKind.Refactor
is_refactoring = kind_contains_other_kind(CodeActionKind.Refactor, code_action.get('kind', ''))
promise = self.apply_workspace_edit_async(edit, is_refactoring) if edit else Promise.resolve(None)
command = code_action.get("command")
if command is not None:
Expand All @@ -1836,14 +1847,16 @@ def _apply_code_action_async(
arguments = command.get("arguments")
if arguments is not None:
execute_command['arguments'] = arguments
return promise.then(lambda _: self.execute_command(execute_command, progress=False, view=view))
return promise.then(lambda _: self.execute_command(execute_command, progress=False, view=view,
is_refactoring=is_refactoring))
return promise

def apply_workspace_edit_async(self, edit: WorkspaceEdit, is_refactoring: bool = False) -> Promise[None]:
"""
Apply workspace edits, and return a promise that resolves on the async thread again after the edits have been
applied.
"""
is_refactoring = self._is_executing_refactoring_command or is_refactoring
return self.apply_parsed_workspace_edits(parse_workspace_edit(edit), is_refactoring)

def apply_parsed_workspace_edits(self, changes: WorkspaceChanges, is_refactoring: bool = False) -> Promise[None]:
Expand Down
13 changes: 13 additions & 0 deletions plugin/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,3 +853,16 @@ def format_code_actions_for_quick_panel(
if code_action.get('isPreferred', False):
selected_index = idx
return items, selected_index


def kind_contains_other_kind(kind: str, other_kind: str) -> bool:
"""
Check if `other_kind` is a sub-kind of `kind`.

The kind `"refactor.extract"` for example contains `"refactor.extract"` and ``"refactor.extract.function"`,
but not `"unicorn.refactor.extract"`, or `"refactor.extractAll"` or `refactor`.
"""
if kind == other_kind:
return True
kind_len = len(kind)
return len(other_kind) > kind_len and other_kind.startswith(kind + '.')
25 changes: 13 additions & 12 deletions tests/test_code_actions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations
from copy import deepcopy
from LSP.plugin.code_actions import get_matching_on_save_kinds, kinds_include_kind
from LSP.plugin.code_actions import get_matching_on_save_kinds
from LSP.plugin.core.constants import RegionKey
from LSP.plugin.core.protocol import Point, Range
from LSP.plugin.core.url import filename_to_uri
from LSP.plugin.core.views import entire_content
from LSP.plugin.documents import DocumentSyncListener
from LSP.plugin.core.views import kind_contains_other_kind
from LSP.plugin.core.views import Point
from LSP.plugin.core.views import Range
from LSP.plugin.core.views import versioned_text_document_identifier
from LSP.plugin.documents import DocumentSyncListener
from setup import TextDocumentTestCase
from test_single_document import TEST_FILE_PATH
from typing import Any, Generator
Expand Down Expand Up @@ -254,16 +256,15 @@ def test_does_not_match_disabled_action(self) -> None:

def test_kind_matching(self) -> None:
# Positive
self.assertTrue(kinds_include_kind(['a'], 'a.b'))
self.assertTrue(kinds_include_kind(['a.b'], 'a.b'))
self.assertTrue(kinds_include_kind(['a.b', 'b'], 'b.c'))
self.assertTrue(kind_contains_other_kind('a', 'a.b'))
self.assertTrue(kind_contains_other_kind('a.b', 'a.b'))
# Negative
self.assertFalse(kinds_include_kind(['a'], 'b.a'))
self.assertFalse(kinds_include_kind(['a.b'], 'b'))
self.assertFalse(kinds_include_kind(['a.b'], 'a'))
self.assertFalse(kinds_include_kind(['aa'], 'a'))
self.assertFalse(kinds_include_kind(['aa.b'], 'a'))
self.assertFalse(kinds_include_kind(['aa.b'], 'b'))
self.assertFalse(kind_contains_other_kind('a', 'b.a'))
self.assertFalse(kind_contains_other_kind('a.b', 'b'))
self.assertFalse(kind_contains_other_kind('a.b', 'a'))
self.assertFalse(kind_contains_other_kind('aa', 'a'))
self.assertFalse(kind_contains_other_kind('aa.b', 'a'))
self.assertFalse(kind_contains_other_kind('aa.b', 'b'))


class CodeActionsListenerTestCase(TextDocumentTestCase):
Expand Down
Loading