diff --git a/changelog.md b/changelog.md index 3644bcf2..1aabc556 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,7 @@ Features -------- * "Eager" completions for the `source` command, limited to `*.sql` files. * Suggest column names from all tables in the current database after SELECT (#212) +* Put fuzzy completions more often to the bottom of the suggestion list. Bug Fixes diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 40a7d49d..6996949c 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import Counter +from enum import IntEnum import logging import re from typing import Any, Collection, Generator, Iterable, Literal @@ -20,6 +21,14 @@ _logger = logging.getLogger(__name__) +class Fuzziness(IntEnum): + PERFECT = 0 + REGEX = 1 + UNDER_WORDS = 2 + CAMEL_CASE = 3 + RAPIDFUZZ = 4 + + class SQLCompleter(Completer): favorite_keywords = [ 'SELECT', @@ -956,7 +965,7 @@ def find_matches( start_only: bool = False, fuzzy: bool = True, casing: str | None = None, - ) -> Generator[Completion, None, None]: + ) -> Generator[tuple[str, int], None, None]: """Find completion matches for the given text. Given the user's input text and a collection of available @@ -975,10 +984,14 @@ def find_matches( # unicode support not possible without adding the regex dependency case_change_pat = re.compile("(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])") - completions: list[str] = [] + completions: list[tuple[str, int]] = [] + + def empty_generator(): + for item in []: + yield item if re.match(r'^[\d\.]', text): - return (Completion(x, -len(text)) for x in completions) + return empty_generator() if fuzzy: regex = ".{0,3}?".join(map(re.escape, text)) @@ -989,7 +1002,7 @@ def find_matches( for item in collection: r = pat.search(item.lower()) if r: - completions.append(item) + completions.append((item, Fuzziness.REGEX)) continue under_words_item = [x for x in item.lower().split('_') if x] @@ -1000,7 +1013,7 @@ def find_matches( occurrences += 1 break if occurrences >= len(under_words_text): - completions.append(item) + completions.append((item, Fuzziness.UNDER_WORDS)) continue case_words_item = re.split(case_change_pat, item) @@ -1011,7 +1024,7 @@ def find_matches( occurrences += 1 break if occurrences >= len(case_words_text): - completions.append(item) + completions.append((item, Fuzziness.CAMEL_CASE)) continue if len(text) >= 4: @@ -1031,31 +1044,25 @@ def find_matches( continue if item in completions: continue - completions.append(item) + completions.append((item, Fuzziness.RAPIDFUZZ)) else: match_end_limit = len(text) if start_only else None for item in collection: match_point = item.lower().find(text, 0, match_end_limit) if match_point >= 0: - completions.append(item) + completions.append((item, Fuzziness.PERFECT)) if casing == "auto": casing = "lower" if last and (last[0].islower() or last[-1].islower()) else "upper" - def apply_case(kw: str) -> str: + def apply_case(tup: tuple[str, int]) -> tuple[str, int]: + kw, fuzziness = tup if casing == "upper": - return kw.upper() - return kw.lower() - - def exact_leading_key(item: str, text: str): - if text and item.lower().startswith(text): - return -1000 + len(item) - return 0 + return (kw.upper(), fuzziness) + return (kw.lower(), fuzziness) - completions = sorted(completions, key=lambda item: exact_leading_key(item, text)) - - return (Completion(x if casing is None else apply_case(x), -len(text)) for x in completions) + return (x if casing is None else apply_case(x) for x in completions) def get_completions( self, @@ -1064,19 +1071,26 @@ def get_completions( smart_completion: bool | None = None, ) -> Iterable[Completion]: word_before_cursor = document.get_word_before_cursor(WORD=True) + last_for_len = last_word(word_before_cursor, include="most_punctuations") + text_for_len = last_for_len.lower() + if smart_completion is None: smart_completion = self.smart_completion # If smart_completion is off then match any word that starts with # 'word_before_cursor'. if not smart_completion: - return self.find_matches(word_before_cursor, self.all_completions, start_only=True, fuzzy=False) + matches = self.find_matches(word_before_cursor, self.all_completions, start_only=True, fuzzy=False) + return (Completion(x[0], -len(text_for_len)) for x in matches) - completions: list[Completion] = [] + completions: list[tuple[str, int, int]] = [] suggestions = suggest_type(document.text, document.text_before_cursor) + rigid_sort = False + rank = 0 for suggestion in suggestions: _logger.debug("Suggestion type: %r", suggestion["type"]) + rank += 1 if suggestion["type"] == "column": tables = suggestion["tables"] @@ -1093,13 +1107,13 @@ def get_completions( scoped_cols = sorted(set(scoped_cols), key=lambda s: s.strip('`')) cols = self.find_matches(word_before_cursor, scoped_cols) - completions.extend(cols) + completions.extend([(*x, rank) for x in cols]) elif suggestion["type"] == "function": # suggest user-defined functions using substring matching funcs = self.populate_schema_objects(suggestion["schema"], "functions") user_funcs = self.find_matches(word_before_cursor, funcs) - completions.extend(user_funcs) + completions.extend([(*x, rank) for x in user_funcs]) # suggest hardcoded functions using startswith matching only if # there is no schema qualifier. If a schema qualifier is @@ -1109,67 +1123,69 @@ def get_completions( predefined_funcs = self.find_matches( word_before_cursor, self.functions, start_only=True, fuzzy=False, casing=self.keyword_casing ) - completions.extend(predefined_funcs) + completions.extend([(*x, rank) for x in predefined_funcs]) elif suggestion["type"] == "procedure": procs = self.populate_schema_objects(suggestion["schema"], "procedures") procs_m = self.find_matches(word_before_cursor, procs) - completions.extend(procs_m) + completions.extend([(*x, rank) for x in procs_m]) elif suggestion["type"] == "table": tables = self.populate_schema_objects(suggestion["schema"], "tables") tables_m = self.find_matches(word_before_cursor, tables) - completions.extend(tables_m) + completions.extend([(*x, rank) for x in tables_m]) elif suggestion["type"] == "view": views = self.populate_schema_objects(suggestion["schema"], "views") views_m = self.find_matches(word_before_cursor, views) - completions.extend(views_m) + completions.extend([(*x, rank) for x in views_m]) elif suggestion["type"] == "alias": aliases = suggestion["aliases"] aliases_m = self.find_matches(word_before_cursor, aliases) - completions.extend(aliases_m) + completions.extend([(*x, rank) for x in aliases_m]) elif suggestion["type"] == "database": dbs_m = self.find_matches(word_before_cursor, self.databases) - completions.extend(dbs_m) + completions.extend([(*x, rank) for x in dbs_m]) elif suggestion["type"] == "keyword": keywords_m = self.find_matches(word_before_cursor, self.keywords, casing=self.keyword_casing) - completions.extend(keywords_m) + completions.extend([(*x, rank) for x in keywords_m]) elif suggestion["type"] == "show": show_items_m = self.find_matches( word_before_cursor, self.show_items, start_only=False, fuzzy=True, casing=self.keyword_casing ) - completions.extend(show_items_m) + completions.extend([(*x, rank) for x in show_items_m]) elif suggestion["type"] == "change": change_items_m = self.find_matches(word_before_cursor, self.change_items, start_only=False, fuzzy=True) - completions.extend(change_items_m) + completions.extend([(*x, rank) for x in change_items_m]) elif suggestion["type"] == "user": users_m = self.find_matches(word_before_cursor, self.users, start_only=False, fuzzy=True) - completions.extend(users_m) + completions.extend([(*x, rank) for x in users_m]) elif suggestion["type"] == "special": special_m = self.find_matches(word_before_cursor, self.special_commands, start_only=True, fuzzy=False) # specials are special, and go early in the candidates, first if possible - completions = list(special_m) + completions + completions.extend([(*x, 0) for x in special_m]) elif suggestion["type"] == "favoritequery": if hasattr(FavoriteQueries, 'instance') and hasattr(FavoriteQueries.instance, 'list'): queries_m = self.find_matches(word_before_cursor, FavoriteQueries.instance.list(), start_only=False, fuzzy=True) - completions.extend(queries_m) + completions.extend([(*x, rank) for x in queries_m]) elif suggestion["type"] == "table_format": formats_m = self.find_matches(word_before_cursor, self.table_formats) - completions.extend(formats_m) + completions.extend([(*x, rank) for x in formats_m]) elif suggestion["type"] == "file_name": file_names_m = self.find_files(word_before_cursor) - completions.extend(file_names_m) + completions.extend([(*x, rank) for x in file_names_m]) + # for filenames we _really_ want directories to go last + rigid_sort = True elif suggestion["type"] == "llm": if not word_before_cursor: tokens = document.text.split()[1:] @@ -1182,7 +1198,7 @@ def get_completions( start_only=False, fuzzy=True, ) - completions.extend(subcommands_m) + completions.extend([(*x, rank) for x in subcommands_m]) elif suggestion["type"] == "enum_value": enum_values = self.populate_enum_values( suggestion["tables"], @@ -1191,23 +1207,44 @@ def get_completions( ) if enum_values: quoted_values = [self._quote_sql_string(value) for value in enum_values] - return list(self.find_matches(word_before_cursor, quoted_values)) + completions = [(*x, rank) for x in self.find_matches(word_before_cursor, quoted_values)] + break + + def completion_sort_key(item: tuple[str, int, int], text_for_len: str): + candidate, fuzziness, rank = item + if not text_for_len: + # sort only by the rank (the order of the completion type) + return (0, rank, 0) + elif candidate.lower().startswith(text_for_len): + # sort only by the length of the candidate + return (0, 0, -1000 + len(candidate)) + # sort by fuzziness and rank + # todo add alpha here, or original order? + return (fuzziness, rank, 0) + + if rigid_sort: + uniq_completions_str = dict.fromkeys(x[0] for x in completions) + else: + sorted_completions = sorted(completions, key=lambda item: completion_sort_key(item, text_for_len.lower())) + uniq_completions_str = dict.fromkeys(x[0] for x in sorted_completions) - return completions + return (Completion(x, -len(text_for_len)) for x in uniq_completions_str) - def find_files(self, word: str) -> Generator[Completion, None, None]: + def find_files(self, word: str) -> Generator[tuple[str, int], None, None]: """Yield matching directory or file names. :param word: :return: iterable """ + # todo position is ignored, but may need to be used + # todo fuzzy matches for filenames base_path, last_path, position = parse_path(word) paths = suggest_path(word) for name in paths: suggestion = complete_path(name, last_path) if suggestion: - yield Completion(suggestion, position) + yield (suggestion, Fuzziness.PERFECT) def populate_scoped_cols(self, scoped_tbls: list[tuple[str | None, str, str | None]]) -> list[str]: """Find all columns in a set of scoped_tables diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 0ee337cf..2afa8eab 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -58,6 +58,7 @@ def complete_event(): def test_use_database_completion(completer, complete_event): text = "USE " position = len(text) + special.register_special_command(..., 'use', '\\u', 'Change to a new database.', aliases=['\\u']) result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == [ Completion(text="test", start_position=0), @@ -69,7 +70,7 @@ def test_special_name_completion(completer, complete_event): text = "\\d" position = len("\\d") result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) - assert result == [Completion(text="\\dt", start_position=-2)] + assert list(result) == [Completion(text="\\dt", start_position=-2)] def test_empty_string_completion(completer, complete_event): @@ -136,14 +137,12 @@ def test_function_name_completion(completer, complete_event): position = len("SELECT MA") result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == [ - Completion(text='email', start_position=-2), Completion(text='MAX', start_position=-2), + Completion(text='MATCH', start_position=-2), + Completion(text='MASTER', start_position=-2), Completion(text='MAKE_SET', start_position=-2), Completion(text='MAKEDATE', start_position=-2), Completion(text='MAKETIME', start_position=-2), - Completion(text='MASTER_POS_WAIT', start_position=-2), - Completion(text='MATCH', start_position=-2), - Completion(text='MASTER', start_position=-2), Completion(text='MAX_ROWS', start_position=-2), Completion(text='MAX_SIZE', start_position=-2), Completion(text='MAXVALUE', start_position=-2), @@ -157,6 +156,7 @@ def test_function_name_completion(completer, complete_event): Completion(text='MASTER_LOG_POS', start_position=-2), Completion(text='MASTER_SSL_CRL', start_position=-2), Completion(text='MASTER_SSL_KEY', start_position=-2), + Completion(text='MASTER_POS_WAIT', start_position=-2), Completion(text='MASTER_LOG_FILE', start_position=-2), Completion(text='MASTER_PASSWORD', start_position=-2), Completion(text='MASTER_SSL_CERT', start_position=-2), @@ -177,6 +177,7 @@ def test_function_name_completion(completer, complete_event): Completion(text='MASTER_COMPRESSION_ALGORITHMS', start_position=-2), Completion(text='MASTER_SSL_VERIFY_SERVER_CERT', start_position=-2), Completion(text='MASTER_ZSTD_COMPRESSION_LEVEL', start_position=-2), + Completion(text='email', start_position=-2), Completion(text='DECIMAL', start_position=-2), Completion(text='SMALLINT', start_position=-2), Completion(text='TIMESTAMP', start_position=-2), @@ -231,7 +232,7 @@ def test_suggested_column_names(completer, complete_event): ] + list(map(Completion, completer.functions)) + [Completion(text="users", start_position=0)] - + list(map(Completion, completer.keywords)) + + [x for x in map(Completion, completer.keywords) if x.text not in completer.functions] ) @@ -318,7 +319,7 @@ def test_suggested_multiple_column_names(completer, complete_event): ] + list(map(Completion, completer.functions)) + [Completion(text="u", start_position=0)] - + list(map(Completion, completer.keywords)) + + [x for x in map(Completion, completer.keywords) if x.text not in completer.functions] ) @@ -460,32 +461,31 @@ def test_table_names_fuzzy(completer, complete_event): def test_auto_escaped_col_names(completer, complete_event): text = "SELECT from `select`" position = len("SELECT ") - result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) - assert result == [ - Completion(text="*", start_position=0), - Completion(text="id", start_position=0), - Completion(text="`insert`", start_position=0), - Completion(text="ABC", start_position=0), - ] + list(map(Completion, completer.functions)) + [Completion(text="select", start_position=0)] + list( - map(Completion, completer.keywords) + result = [x.text for x in completer.get_completions(Document(text=text, cursor_position=position), complete_event)] + expected = ( + [ + "*", + "id", + "`insert`", + "ABC", + ] + + completer.functions + + ["select"] + + [x for x in completer.keywords if x not in completer.functions] ) + assert result == expected def test_un_escaped_table_names(completer, complete_event): text = "SELECT from réveillé" position = len("SELECT ") - result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) - assert result == list( - [ - Completion(text="*", start_position=0), - Completion(text="id", start_position=0), - Completion(text="`insert`", start_position=0), - Completion(text="ABC", start_position=0), - ] - + list(map(Completion, completer.functions)) - + [Completion(text="réveillé", start_position=0)] - + list(map(Completion, completer.keywords)) - ) + result = [x.text for x in completer.get_completions(Document(text=text, cursor_position=position), complete_event)] + assert result == [ + "*", + "id", + "`insert`", + "ABC", + ] + completer.functions + ["réveillé"] + [x for x in completer.keywords if x not in completer.functions] # todo: the fixtures are insufficient; the database name should also appear in the result @@ -551,18 +551,18 @@ def dummy_list_path(dir_name): @patch("mycli.packages.filepaths.list_path", new=dummy_list_path) @pytest.mark.parametrize( "text,expected", + # it may be that the cursor positions should be 0, but the position + # info is currently being dropped in find_files() [ - # ('source ', [('~', 0), - # ('/', 0), - # ('.', 0), - # ('..', 0)]), - ("source /", [("dir1", 0), ("file1.sql", 0), ("file2.sql", 0)]), - ("source /dir1/", [("subdir1", 0), ("subfile1.sql", 0), ("subfile2.sql", 0)]), - ("source /dir1/subdir1/", [("lastfile.sql", 0)]), + ('source ', [('/', 0), ('~', 0), ('.', 0), ('..', 0)]), + ("source /", [("dir1", -1), ("file1.sql", -1), ("file2.sql", -1)]), + ("source /dir1/", [("subdir1", -6), ("subfile1.sql", -6), ("subfile2.sql", -6)]), + ("source /dir1/subdir1/", [("lastfile.sql", -14)]), ], ) def test_file_name_completion(completer, complete_event, text, expected): position = len(text) + special.register_special_command(..., 'source', '\\. filename', 'Execute commands from file.', aliases=['\\.']) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) expected = [Completion(txt, pos) for txt, pos in expected] assert result == expected @@ -599,6 +599,7 @@ def test_source_eager_completion(completer, complete_event): script_filename = 'script_for_test_suite.sql' f = open(script_filename, 'w') f.close() + special.register_special_command(..., 'source', '\\. filename', 'Execute commands from file.', aliases=['\\.']) result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) success = True error = 'unknown'