diff --git a/ai_review/algo/language_handler.py b/ai_review/algo/language_handler.py
index 37f0262..27ec96e 100644
--- a/ai_review/algo/language_handler.py
+++ b/ai_review/algo/language_handler.py
@@ -1,15 +1,53 @@
 # Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
-from typing import Dict
+from typing import Dict, List, TypeVar, Optional, Set
+from dataclasses import dataclass
+from collections import defaultdict
+from pathlib import Path

 from ai_review.config_loader import get_settings

+FileType = TypeVar('FileType')  # type variable for the provider-specific file objects

-def filter_bad_extensions(files):
-    # Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
-    bad_extensions = get_settings().bad_extensions.default
-    if get_settings().config.use_extra_bad_extensions:
-        bad_extensions += get_settings().bad_extensions.extra
-    return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)]
+@dataclass
+class LanguageGroup:
+    language: str
+    files: List[FileType]
+
+def get_file_extension(filename: str) -> Optional[str]:
+    """Return the extension of a filename.
+
+    Args:
+        filename: The file name.
+    Returns:
+        The lower-cased extension (including the leading dot), or None if the
+        filename is empty or has no extension.
+    """
+    if not filename:
+        return None
+    ext = Path(filename).suffix
+    return ext.lower() if ext else None
+
+def get_bad_extensions() -> Set[str]:
+    """Return the set of file extensions that should be filtered out.
+
+    Returns:
+        The set of undesired file extensions.
+    """
+    settings = get_settings()
+    bad_extensions = set(settings.bad_extensions.default)
+    if settings.config.use_extra_bad_extensions:
+        bad_extensions.update(settings.bad_extensions.extra)
+    return bad_extensions
+
+def filter_bad_extensions(files: List[FileType]) -> List[FileType]:
+    """Drop files whose extension is on the bad-extensions list.
+
+    Args:
+        files: The files to filter.
+    Returns:
+        The filtered list of files.
+    """
+    bad_extensions = get_bad_extensions()
+    # get_file_extension() may return None; treat extension-less files as valid
+    return [f for f in files
+            if f.filename and (get_file_extension(f.filename) or '').strip('.') not in bad_extensions]


 def is_valid_file(filename:str, bad_extensions=None) -> bool:
@@ -22,49 +60,56 @@ def is_valid_file(filename:str, bad_extensions=None) -> bool:
     return filename.split('.')[-1] not in bad_extensions


-def sort_files_by_main_languages(languages: Dict, files: list):
-    """
-    Sort files by their main language, put the files that are in the main language first and the rest files after
+def sort_files_by_main_languages(languages: Dict[str, int], files: List[FileType]) -> List[LanguageGroup]:
+    """Group files by the repository's main languages.
+
+    Args:
+        languages: Mapping from language name to its usage size.
+        files: The files to group.
+    Returns:
+        Files grouped by language, ordered by language prevalence.
     """
-    # sort languages by their size
-    languages_sorted_list = [k for k, v in sorted(languages.items(), key=lambda item: item[1], reverse=True)]
-    # languages_sorted = sorted(languages, key=lambda x: x[1], reverse=True)
-    # get all extensions for the languages
-    main_extensions = []
-    language_extension_map_org = get_settings().language_extension_map_org
-    language_extension_map = {k.lower(): v for k, v in language_extension_map_org.items()}
-    for language in languages_sorted_list:
-        if language.lower() in language_extension_map:
-            main_extensions.append(language_extension_map[language.lower()])
-        else:
-            main_extensions.append([])
-
-    # filter out files bad extensions
-    files_filtered = filter_bad_extensions(files)
-    # sort files by their extension, put the files that are in the main extension first
-    # and the rest files after, map languages_sorted to their respective files
-    files_sorted = []
-    rest_files = {}
-
-    # if no languages detected, put all files in the "Other" category
     if not languages:
-        files_sorted = [({"language": "Other", "files": list(files_filtered)})]
-        return files_sorted
-
-    main_extensions_flat = []
-    for ext in main_extensions:
-        main_extensions_flat.extend(ext)
-
-    for extensions, lang in zip(main_extensions, languages_sorted_list):  # noqa: B905
-        tmp = []
-        for file in files_filtered:
-            extension_str = f".{file.filename.split('.')[-1]}"
-            if extension_str in extensions:
-                tmp.append(file)
-            else:
-                if (file.filename not in rest_files) and (extension_str not in main_extensions_flat):
-                    rest_files[file.filename] = file
-        if len(tmp) > 0:
-            files_sorted.append({"language": lang, "files": tmp})
-    files_sorted.append({"language": "Other", "files": list(rest_files.values())})
-    return files_sorted
+        return [LanguageGroup(language="Other", files=filter_bad_extensions(files))]
+
+    settings = get_settings()
+    language_map = {k.lower(): set(v) for k, v in settings.language_extension_map_org.items()}
+
+    # pre-compute the extension -> language mapping, most used languages first
+    ext_to_lang = {}
+    all_main_extensions = set()
+    for lang, count in sorted(languages.items(), key=lambda x: x[1], reverse=True):
+        if lang.lower() in language_map:
+            for ext in language_map[lang.lower()]:
+                ext = ext.lower()
+                ext_to_lang[ext] = lang
+                all_main_extensions.add(ext)
+
+    # classify the files
+    lang_groups = defaultdict(list)
+    other_files = []
+
+    filtered_files = filter_bad_extensions(files)
+    for file in filtered_files:
+        if not file.filename:
+            continue
+        ext = get_file_extension(file.filename)
+        if ext in ext_to_lang:
+            lang_groups[ext_to_lang[ext]].append(file)
+        elif ext not in all_main_extensions:
+            other_files.append(file)
+
+    # build the result, ordered by language prevalence
+    result = [
+        LanguageGroup(language=lang, files=files)
+        for lang, files in sorted(
+            lang_groups.items(),
+            key=lambda x: languages.get(x[0], 0),
+            reverse=True
+        )
+    ]
+
+    if other_files:
+        result.append(LanguageGroup(language="Other", files=other_files))
+
+    return result
diff --git a/ai_review/algo/token_handler.py b/ai_review/algo/token_handler.py
index eff238a..b570a61 100644
--- a/ai_review/algo/token_handler.py
+++ b/ai_review/algo/token_handler.py
@@ -37,56 +37,54 @@ class TokenHandler:
     method.
     """

-    def __init__(self, pr=None, vars: dict = None, system="", user=""):
-        if vars is None:
-            vars = {}
-        ...
+    def __init__(self, pr=None, vars: dict = None, system: str = "", user: str = ""):
         """
-        Initializes the TokenHandler object.
-
+        Initialize the TokenHandler object.
+
         Args:
-        - pr: The pull request object.
-        - vars: A dictionary of variables.
-        - system: The system string.
-        - user: The user string.
+            pr: The pull request object.
+            vars: A dictionary of variables (defaults to an empty dict).
+            system: The system prompt string.
+            user: The user prompt string.
         """
+        self.vars = vars or {}  # default to an empty dict when no vars are given
         self.encoder = TokenEncoder.get_token_encoder()
-        if pr is not None:
-            self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
+        self.prompt_tokens = (self._get_system_user_tokens(pr, self.encoder, self.vars, system, user)
+                              if pr is not None else 0)

-    def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
+    def _get_system_user_tokens(self, pr, encoder, vars: dict, system: str, user: str) -> int:
         """
-        Calculates the number of tokens in the system and user strings.
-
+        计算系统和用户字符串中的令牌数
+
         Args:
-        - pr: The pull request object.
-        - encoder: An object of the encoding_for_model class from the tiktoken module.
-        - vars: A dictionary of variables.
-        - system: The system string.
-        - user: The user string.
- + pr: Pull Request 对象 + encoder: tiktoken 编码器实例 + vars: 变量字典 + system: 系统提示字符串 + user: 用户提示字符串 + Returns: - The sum of the number of tokens in the system and user strings. + int: 系统和用户字符串中的总令牌数 """ try: environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(system).render(vars) user_prompt = environment.from_string(user).render(vars) - system_prompt_tokens = len(encoder.encode(system_prompt)) - user_prompt_tokens = len(encoder.encode(user_prompt)) - return system_prompt_tokens + user_prompt_tokens + + return len(encoder.encode(system_prompt)) + len(encoder.encode(user_prompt)) + except Exception as e: - get_logger().error(f"Error in _get_system_user_tokens: {e}") + get_logger().error(f"Error in _get_system_user_tokens: {str(e)}") return 0 def count_tokens(self, patch: str) -> int: """ - Counts the number of tokens in a given patch string. - + 计算给定补丁字符串中的令牌数 + Args: - - patch: The patch string. - + patch: 补丁字符串 + Returns: - The number of tokens in the patch string. + int: 补丁字符串中的令牌数 """ return len(self.encoder.encode(patch, disallowed_special=())) diff --git a/ai_review/algo/utils.py b/ai_review/algo/utils.py index 9def125..4c0d7cd 100644 --- a/ai_review/algo/utils.py +++ b/ai_review/algo/utils.py @@ -1,1083 +1,3 @@ -from __future__ import annotations - -import copy -import difflib -import hashlib -import html -import json -import os -import re -import sys -import textwrap -import time -import traceback -from datetime import datetime -from enum import Enum -from importlib.metadata import PackageNotFoundError, version -from typing import Any, List, Tuple - -import html2text -import requests -import yaml -from pydantic import BaseModel -from starlette_context import context - -from ai_review.algo import MAX_TOKENS -from ai_review.algo.git_patch_processing import extract_hunk_lines_from_patch -from ai_review.algo.token_handler import TokenEncoder -from ai_review.algo.types import FilePatchInfo -from ai_review.config_loader import get_settings, global_settings -from ai_review.log import get_logger - - -def get_weak_model() -> str: - if get_settings().get("config.model_weak"): - return get_settings().config.model_weak - return get_settings().config.model - - -class Range(BaseModel): - line_start: int # should be 0-indexed - line_end: int - column_start: int = -1 - column_end: int = -1 - -class ModelType(str, Enum): - REGULAR = "regular" - WEAK = "weak" - -class PRReviewHeader(str, Enum): - REGULAR = "## PR Reviewer Guide" - INCREMENTAL = "## Incremental PR Reviewer Guide" - - -class PRDescriptionHeader(str, Enum): - CHANGES_WALKTHROUGH = "### **Changes walkthrough** 📝" - - -def get_setting(key: str) -> Any: - try: - key = key.upper() - return context.get("settings", global_settings).get(key, global_settings.get(key, None)) - except Exception: - return global_settings.get(key, None) - - -def emphasize_header(text: str, only_markdown=False, reference_link=None) -> str: - try: - # Finding the position of the first occurrence of ": " - colon_position = text.find(": ") - - # Splitting the string and wrapping the first part in tags - if colon_position != -1: - # Everything before the colon (inclusive) is wrapped in tags - if only_markdown: - if reference_link: - transformed_string = f"[**{text[:colon_position + 1]}**]({reference_link})\n" + text[colon_position + 1:] - else: - transformed_string = f"**{text[:colon_position + 1]}**\n" + text[colon_position + 1:] - else: - if reference_link: - transformed_string = f"{text[:colon_position + 1]}
" + text[colon_position + 1:] - else: - transformed_string = "" + text[:colon_position + 1] + "" +'
' + text[colon_position + 1:] - else: - # If there's no ": ", return the original string - transformed_string = text - - return transformed_string - except Exception as e: - get_logger().exception(f"Failed to emphasize header: {e}") - return text - - -def unique_strings(input_list: List[str]) -> List[str]: - if not input_list or not isinstance(input_list, list): - return input_list - seen = set() - unique_list = [] - for item in input_list: - if item not in seen: - unique_list.append(item) - seen.add(item) - return unique_list - -def convert_to_markdown_v2(output_data: dict, - gfm_supported: bool = True, - incremental_review=None, - git_provider=None, - files=None) -> str: - """ - Convert a dictionary of data into markdown format. - Args: - output_data (dict): A dictionary containing data to be converted to markdown format. - Returns: - str: The markdown formatted text generated from the input dictionary. - """ - - emojis = { - "Can be split": "🔀", - "Key issues to review": "⚡", - "Recommended focus areas for review": "⚡", - "Score": "🏅", - "Relevant tests": "🧪", - "Focused PR": "✨", - "Relevant ticket": "🎫", - "Security concerns": "🔒", - "Insights from user's answers": "📝", - "Code feedback": "🤖", - "Estimated effort to review [1-5]": "⏱️", - "Ticket compliance check": "🎫", - } - markdown_text = "" - if not incremental_review: - markdown_text += f"{PRReviewHeader.REGULAR.value} 🔍\n\n" - else: - markdown_text += f"{PRReviewHeader.INCREMENTAL.value} 🔍\n\n" - markdown_text += f"⏮️ Review for commits since previous PR-Agent review {incremental_review}.\n\n" - if not output_data or not output_data.get('review', {}): - return "" - - if get_settings().get("pr_reviewer.enable_intro_text", False): - markdown_text += f"Here are some key observations to aid the review process:\n\n" - - if gfm_supported: - markdown_text += "\n" - - for key, value in output_data['review'].items(): - if value is None or value == '' or value == {} or value == []: - if key.lower() not in ['can_be_split', 'key_issues_to_review']: - continue - key_nice = key.replace('_', ' ').capitalize() - emoji = emojis.get(key_nice, "") - if 'Estimated effort to review' in key_nice: - key_nice = 'Estimated effort to review' - value = str(value).strip() - if value.isnumeric(): - value_int = int(value) - else: - try: - value_int = int(value.split(',')[0]) - except ValueError: - continue - blue_bars = '🔵' * value_int - white_bars = '⚪' * (5 - value_int) - value = f"{value_int} {blue_bars}{white_bars}" - if gfm_supported: - markdown_text += f"\n" - else: - markdown_text += f"### {emoji} {key_nice}: {value}\n\n" - elif 'relevant tests' in key_nice.lower(): - value = str(value).strip().lower() - if gfm_supported: - markdown_text += f"\n" - else: - if is_value_no(value): - markdown_text += f'### {emoji} No relevant tests\n\n' - else: - markdown_text += f"### {emoji} PR contains tests\n\n" - elif 'ticket compliance check' in key_nice.lower(): - markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) - elif 'security concerns' in key_nice.lower(): - if gfm_supported: - markdown_text += f"\n" - else: - if is_value_no(value): - markdown_text += f'### {emoji} No security concerns identified\n\n' - else: - markdown_text += f"### {emoji} Security concerns\n\n" - value = emphasize_header(value.strip(), only_markdown=True) - markdown_text += f"{value}\n\n" - elif 'can be split' in key_nice.lower(): - if gfm_supported: - markdown_text += f"\n" - elif 'key issues to review' in key_nice.lower(): - # value is a list of issues - if 
is_value_no(value): - if gfm_supported: - markdown_text += f"\n" - else: - markdown_text += f"### {emoji} No major issues detected\n\n" - else: - issues = value - if gfm_supported: - markdown_text += f"\n" - else: - if gfm_supported: - markdown_text += f"\n" - else: - markdown_text += f"### {emoji} {key_nice}: {value}\n\n" - - if gfm_supported: - markdown_text += "
" - markdown_text += f"{emoji} {key_nice}: {value}" - markdown_text += f"
" - if is_value_no(value): - markdown_text += f"{emoji} No relevant tests" - else: - markdown_text += f"{emoji} PR contains tests" - markdown_text += f"
" - if is_value_no(value): - markdown_text += f"{emoji} No security concerns identified" - else: - markdown_text += f"{emoji} Security concerns

\n\n" - value = emphasize_header(value.strip()) - markdown_text += f"{value}" - markdown_text += f"
" - markdown_text += process_can_be_split(emoji, value) - markdown_text += f"
" - markdown_text += f"{emoji} No major issues detected" - markdown_text += f"
" - # markdown_text += f"{emoji} {key_nice}

\n\n" - markdown_text += f"{emoji} Recommended focus areas for review

\n\n" - else: - markdown_text += f"### {emoji} Recommended focus areas for review\n\n#### \n" - for i, issue in enumerate(issues): - try: - if not issue or not isinstance(issue, dict): - continue - relevant_file = issue.get('relevant_file', '').strip() - issue_header = issue.get('issue_header', '').strip() - if issue_header.lower() == 'possible bug': - issue_header = 'Possible Issue' # Make the header less frightening - issue_content = issue.get('issue_content', '').strip() - start_line = int(str(issue.get('start_line', 0)).strip()) - end_line = int(str(issue.get('end_line', 0)).strip()) - - relevant_lines_str = extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=True) - if git_provider: - reference_link = git_provider.get_line_link(relevant_file, start_line, end_line) - else: - reference_link = None - - if gfm_supported: - if reference_link is not None and len(reference_link) > 0: - if relevant_lines_str: - issue_str = f"
{issue_header}\n\n{issue_content}\n\n{relevant_lines_str}\n\n
" - else: - issue_str = f"{issue_header}
{issue_content}" - else: - issue_str = f"{issue_header}
{issue_content}" - else: - if reference_link is not None and len(reference_link) > 0: - issue_str = f"[**{issue_header}**]({reference_link})\n\n{issue_content}\n\n" - else: - issue_str = f"**{issue_header}**\n\n{issue_content}\n\n" - markdown_text += f"{issue_str}\n\n" - except Exception as e: - get_logger().exception(f"Failed to process 'Recommended focus areas for review': {e}") - if gfm_supported: - markdown_text += f"
" - markdown_text += f"{emoji} {key_nice}: {value}" - markdown_text += f"
\n" - - return markdown_text - - -def extract_relevant_lines_str(end_line, files, relevant_file, start_line, dedent=False) -> str: - """ - Finds 'relevant_file' in 'files', and extracts the lines from 'start_line' to 'end_line' string from the file content. - """ - try: - relevant_lines_str = "" - if files: - files = set_file_languages(files) - for file in files: - if file.filename.strip() == relevant_file: - if not file.head_file: - # as a fallback, extract relevant lines directly from patch - patch = file.patch - get_logger().info(f"No content found in file: '{file.filename}' for 'extract_relevant_lines_str'. Using patch instead") - _, selected_lines = extract_hunk_lines_from_patch(patch, file.filename, start_line, end_line,side='right') - if not selected_lines: - get_logger().error(f"Failed to extract relevant lines from patch: {file.filename}") - return "" - # filter out '-' lines - relevant_lines_str = "" - for line in selected_lines.splitlines(): - if line.startswith('-'): - continue - relevant_lines_str += line[1:] + '\n' - else: - relevant_file_lines = file.head_file.splitlines() - relevant_lines_str = "\n".join(relevant_file_lines[start_line - 1:end_line]) - - if dedent and relevant_lines_str: - # Remove the longest leading string of spaces and tabs common to all lines. - relevant_lines_str = textwrap.dedent(relevant_lines_str) - relevant_lines_str = f"```{file.language}\n{relevant_lines_str}\n```" - break - - return relevant_lines_str - except Exception as e: - get_logger().exception(f"Failed to extract relevant lines: {e}") - return "" - - -def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str: - ticket_compliance_str = "" - compliance_emoji = '' - # Track compliance levels across all tickets - all_compliance_levels = [] - - if isinstance(value, list): - for ticket_analysis in value: - try: - ticket_url = ticket_analysis.get('ticket_url', '').strip() - explanation = '' - ticket_compliance_level = '' # Individual ticket compliance - fully_compliant_str = ticket_analysis.get('fully_compliant_requirements', '').strip() - not_compliant_str = ticket_analysis.get('not_compliant_requirements', '').strip() - requires_further_human_verification = ticket_analysis.get('requires_further_human_verification', - '').strip() - - if not fully_compliant_str and not not_compliant_str: - get_logger().debug(f"Ticket compliance has no requirements", - artifact={'ticket_url': ticket_url}) - continue - - # Calculate individual ticket compliance level - if fully_compliant_str: - if not_compliant_str: - ticket_compliance_level = 'Partially compliant' - else: - if not requires_further_human_verification: - ticket_compliance_level = 'Fully compliant' - else: - ticket_compliance_level = 'PR Code Verified' - elif not_compliant_str: - ticket_compliance_level = 'Not compliant' - - # Store the compliance level for aggregation - if ticket_compliance_level: - all_compliance_levels.append(ticket_compliance_level) - - # build compliance string - if fully_compliant_str: - explanation += f"Compliant requirements:\n\n{fully_compliant_str}\n\n" - if not_compliant_str: - explanation += f"Non-compliant requirements:\n\n{not_compliant_str}\n\n" - if requires_further_human_verification: - explanation += f"Requires further human verification:\n\n{requires_further_human_verification}\n\n" - ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {ticket_compliance_level}**\n\n{explanation}\n\n" - - # for debugging - if requires_further_human_verification: - 
get_logger().debug(f"Ticket compliance requires further human verification", - artifact={'ticket_url': ticket_url, - 'requires_further_human_verification': requires_further_human_verification, - 'compliance_level': ticket_compliance_level}) - - except Exception as e: - get_logger().exception(f"Failed to process ticket compliance: {e}") - continue - - # Calculate overall compliance level and emoji - if all_compliance_levels: - if all(level == 'Fully compliant' for level in all_compliance_levels): - compliance_level = 'Fully compliant' - compliance_emoji = '✅' - elif all(level == 'PR Code Verified' for level in all_compliance_levels): - compliance_level = 'PR Code Verified' - compliance_emoji = '✅' - elif any(level == 'Not compliant' for level in all_compliance_levels): - # If there's a mix of compliant and non-compliant tickets - if any(level in ['Fully compliant', 'PR Code Verified'] for level in all_compliance_levels): - compliance_level = 'Partially compliant' - compliance_emoji = '🔶' - else: - compliance_level = 'Not compliant' - compliance_emoji = '❌' - elif any(level == 'Partially compliant' for level in all_compliance_levels): - compliance_level = 'Partially compliant' - compliance_emoji = '🔶' - else: - compliance_level = 'PR Code Verified' - compliance_emoji = '✅' - - # Set extra statistics outside the ticket loop - get_settings().set('config.extra_statistics', {'compliance_level': compliance_level}) - - # editing table row for ticket compliance analysis - if gfm_supported: - markdown_text += f"\n\n" - markdown_text += f"**{emoji} Ticket compliance analysis {compliance_emoji}**\n\n" - markdown_text += ticket_compliance_str - markdown_text += f"\n" - else: - markdown_text += f"### {emoji} Ticket compliance analysis {compliance_emoji}\n\n" - markdown_text += ticket_compliance_str + "\n\n" - - return markdown_text - - -def process_can_be_split(emoji, value): - try: - # key_nice = "Can this PR be split?" - key_nice = "Multiple PR themes" - markdown_text = "" - if not value or isinstance(value, list) and len(value) == 1: - value = "No" - # markdown_text += f" {emoji} {key_nice}\n\n{value}\n\n\n" - # markdown_text += f"### {emoji} No multiple PR themes\n\n" - markdown_text += f"{emoji} No multiple PR themes\n\n" - else: - markdown_text += f"{emoji} {key_nice}

\n\n" - for i, split in enumerate(value): - title = split.get('title', '') - relevant_files = split.get('relevant_files', []) - markdown_text += f"
\nSub-PR theme: {title}\n\n" - markdown_text += f"___\n\nRelevant files:\n\n" - for file in relevant_files: - markdown_text += f"- {file}\n" - markdown_text += f"___\n\n" - markdown_text += f"
\n\n" - - # markdown_text += f"#### Sub-PR theme: {title}\n\n" - # markdown_text += f"Relevant files:\n\n" - # for file in relevant_files: - # markdown_text += f"- {file}\n" - # markdown_text += "\n" - # number_of_splits = len(value) - # markdown_text += f" {emoji} {key_nice}\n" - # for i, split in enumerate(value): - # title = split.get('title', '') - # relevant_files = split.get('relevant_files', []) - # if i == 0: - # markdown_text += f"
\nSub-PR theme:
{title}
\n\n" - # markdown_text += f"
\n" - # markdown_text += f"Relevant files:\n" - # markdown_text += f"
    \n" - # for file in relevant_files: - # markdown_text += f"
  • {file}
  • \n" - # markdown_text += f"
\n\n
\n" - # else: - # markdown_text += f"\n
\nSub-PR theme:
{title}
\n\n" - # markdown_text += f"
\n" - # markdown_text += f"Relevant files:\n" - # markdown_text += f"
    \n" - # for file in relevant_files: - # markdown_text += f"
  • {file}
  • \n" - # markdown_text += f"
\n\n
\n" - except Exception as e: - get_logger().exception(f"Failed to process can be split: {e}") - return "" - return markdown_text - - -def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool = True) -> str: - """ - Convert a dictionary of data into markdown format. - - Args: - code_suggestion (dict): A dictionary containing data to be converted to markdown format. - - Returns: - str: A string containing the markdown formatted text generated from the input dictionary. - """ - markdown_text = "" - if gfm_supported and 'relevant_line' in code_suggestion: - markdown_text += '' - for sub_key, sub_value in code_suggestion.items(): - try: - if sub_key.lower() == 'relevant_file': - relevant_file = sub_value.strip('`').strip('"').strip("'") - markdown_text += f"" - # continue - elif sub_key.lower() == 'suggestion': - markdown_text += (f"" - f"") - elif sub_key.lower() == 'relevant_line': - markdown_text += f"" - sub_value_list = sub_value.split('](') - relevant_line = sub_value_list[0].lstrip('`').lstrip('[') - if len(sub_value_list) > 1: - link = sub_value_list[1].rstrip(')').strip('`') - markdown_text += f"" - else: - markdown_text += f"" - markdown_text += "" - except Exception as e: - get_logger().exception(f"Failed to parse code suggestion: {e}") - pass - markdown_text += '
relevant file{relevant_file}
{sub_key}      \n\n\n\n{sub_value.strip()}\n\n\n
relevant line{relevant_line}{relevant_line}
' - markdown_text += "
" - else: - for sub_key, sub_value in code_suggestion.items(): - if isinstance(sub_key, str): - sub_key = sub_key.rstrip() - if isinstance(sub_value,str): - sub_value = sub_value.rstrip() - if isinstance(sub_value, dict): # "code example" - markdown_text += f" - **{sub_key}:**\n" - for code_key, code_value in sub_value.items(): # 'before' and 'after' code - code_str = f"```\n{code_value}\n```" - code_str_indented = textwrap.indent(code_str, ' ') - markdown_text += f" - **{code_key}:**\n{code_str_indented}\n" - else: - if "relevant_file" in sub_key.lower(): - markdown_text += f"\n - **{sub_key}:** {sub_value} \n" - else: - markdown_text += f" **{sub_key}:** {sub_value} \n" - if "relevant_line" not in sub_key.lower(): # nicer presentation - # markdown_text = markdown_text.rstrip('\n') + "\\\n" # works for gitlab - markdown_text = markdown_text.rstrip('\n') + " \n" # works for gitlab and bitbucker - - markdown_text += "\n" - return markdown_text - - -def try_fix_json(review, max_iter=10, code_suggestions=False): - """ - Fix broken or incomplete JSON messages and return the parsed JSON data. - - Args: - - review: A string containing the JSON message to be fixed. - - max_iter: An integer representing the maximum number of iterations to try and fix the JSON message. - - code_suggestions: A boolean indicating whether to try and fix JSON messages with code feedback. - - Returns: - - data: A dictionary containing the parsed JSON data. - - The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion. - If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the - message. - If code_suggestions is True and the JSON message contains code feedback, the function tries to fix the JSON - message by parsing until the last valid code suggestion. - The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or - newlines. - It tries to parse the JSON message with the closing bracket and checks if it is valid. - If the JSON message is valid, the parsed JSON data is returned. - If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON - message is obtained or the maximum number of iterations is reached. - If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned. 
- """ - - if review.endswith("}"): - return fix_json_escape_char(review) - - data = {} - if code_suggestions: - closing_bracket = "]}" - else: - closing_bracket = "]}}" - - if (review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0) or \ - (review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0) : - last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 - valid_json = False - iter_count = 0 - - while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter: - try: - data = json.loads(review[:last_code_suggestion_ind] + closing_bracket) - valid_json = True - review = review[:last_code_suggestion_ind].strip() + closing_bracket - except json.decoder.JSONDecodeError: - review = review[:last_code_suggestion_ind] - last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 - iter_count += 1 - - if not valid_json: - get_logger().error("Unable to decode JSON response from AI") - data = {} - - return data - - -def fix_json_escape_char(json_message=None): - """ - Fix broken or incomplete JSON messages and return the parsed JSON data. - - Args: - json_message (str): A string containing the JSON message to be fixed. - - Returns: - dict: A dictionary containing the parsed JSON data. - - Raises: - None - - """ - try: - result = json.loads(json_message) - except Exception as e: - # Find the offending character index: - idx_to_replace = int(str(e).split(' ')[-1].replace(')', '')) - # Remove the offending character: - json_message = list(json_message) - json_message[idx_to_replace] = ' ' - new_message = ''.join(json_message) - return fix_json_escape_char(json_message=new_message) - return result - - -def convert_str_to_datetime(date_str): - """ - Convert a string representation of a date and time into a datetime object. - - Args: - date_str (str): A string representation of a date and time in the format '%a, %d %b %Y %H:%M:%S %Z' - - Returns: - datetime: A datetime object representing the input date and time. - - Example: - >>> convert_str_to_datetime('Mon, 01 Jan 2022 12:00:00 UTC') - datetime.datetime(2022, 1, 1, 12, 0, 0) - """ - datetime_format = '%a, %d %b %Y %H:%M:%S %Z' - return datetime.strptime(date_str, datetime_format) - - -def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str, show_warning: bool = True) -> str: - """ - Generate a patch for a modified file by comparing the original content of the file with the new content provided as - input. - """ - if not original_file_content_str and not new_file_content_str: - return "" - - try: - original_file_content_str = (original_file_content_str or "").rstrip() + "\n" - new_file_content_str = (new_file_content_str or "").rstrip() + "\n" - diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), - new_file_content_str.splitlines(keepends=True)) - if get_settings().config.verbosity_level >= 2 and show_warning: - get_logger().info(f"File was modified, but no patch was found. Manually creating patch: {filename}.") - patch = ''.join(diff) - return patch - except Exception as e: - get_logger().exception(f"Failed to generate patch for file: {filename}") - return "" - - -def update_settings_from_args(args: List[str]) -> List[str]: - """ - Update the settings of the Dynaconf object based on the arguments passed to the function. - - Args: - args: A list of arguments passed to the function. 
- Example args: ['--pr_code_suggestions.extra_instructions="be funny', - '--pr_code_suggestions.num_code_suggestions=3'] - - Returns: - None - - Raises: - ValueError: If the argument is not in the correct format. - - """ - other_args = [] - if args: - for arg in args: - arg = arg.strip() - if arg.startswith('--'): - arg = arg.strip('-').strip() - vals = arg.split('=', 1) - if len(vals) != 2: - if len(vals) > 2: # --extended is a valid argument - get_logger().error(f'Invalid argument format: {arg}') - other_args.append(arg) - continue - key, value = _fix_key_value(*vals) - get_settings().set(key, value) - get_logger().info(f'Updated setting {key} to: "{value}"') - else: - other_args.append(arg) - return other_args - - -def _fix_key_value(key: str, value: str): - key = key.strip().upper() - value = value.strip() - try: - value = yaml.safe_load(value) - except Exception as e: - get_logger().debug(f"Failed to parse YAML for config override {key}={value}", exc_info=e) - return key, value - - -def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key="") -> dict: - response_text = response_text.strip('\n').removeprefix('```yaml').rstrip().removesuffix('```') - try: - data = yaml.safe_load(response_text) - except Exception as e: - get_logger().warning(f"Initial failure to parse AI prediction: {e}") - data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key) - if not data: - get_logger().error(f"Failed to parse AI prediction after fallbacks", - artifact={'response_text': response_text}) - else: - get_logger().info(f"Successfully parsed AI prediction after fallbacks", - artifact={'response_text': response_text}) - return data - - - -def try_fix_yaml(response_text: str, - keys_fix_yaml: List[str] = [], - first_key="", - last_key="",) -> dict: - response_text_lines = response_text.split('\n') - - keys_yaml = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:'] - keys_yaml = keys_yaml + keys_fix_yaml - # first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...' 
- response_text_lines_copy = response_text_lines.copy() - for i in range(0, len(response_text_lines_copy)): - for key in keys_yaml: - if key in response_text_lines_copy[i] and not '|' in response_text_lines_copy[i]: - response_text_lines_copy[i] = response_text_lines_copy[i].replace(f'{key}', - f'{key} |\n ') - try: - data = yaml.safe_load('\n'.join(response_text_lines_copy)) - get_logger().info(f"Successfully parsed AI prediction after adding |-\n") - return data - except: - pass - - # second fallback - try to extract only range from first ```yaml to ```` - snippet_pattern = r'```(yaml)?[\s\S]*?```' - snippet = re.search(snippet_pattern, '\n'.join(response_text_lines_copy)) - if snippet: - snippet_text = snippet.group() - try: - data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`')) - get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet") - return data - except: - pass - - - # third fallback - try to remove leading and trailing curly brackets - response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n') - try: - data = yaml.safe_load(response_text_copy) - get_logger().info(f"Successfully parsed AI prediction after removing curly brackets") - return data - except: - pass - - - # forth fallback - try to extract yaml snippet by 'first_key' and 'last_key' - # note that 'last_key' can be in practice a key that is not the last key in the yaml snippet. - # it just needs to be some inner key, so we can look for newlines after it - if first_key and last_key: - index_start = response_text.find(f"\n{first_key}:") - if index_start == -1: - index_start = response_text.find(f"{first_key}:") - index_last_code = response_text.rfind(f"{last_key}:") - index_end = response_text.find("\n\n", index_last_code) # look for newlines after last_key - if index_end == -1: - index_end = len(response_text) - response_text_copy = response_text[index_start:index_end].strip().strip('```yaml').strip('`').strip() - try: - data = yaml.safe_load(response_text_copy) - get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet") - return data - except: - pass - - # fifth fallback - try to remove leading '+' (sometimes added by AI for 'existing code' and 'improved code') - response_text_lines_copy = response_text_lines.copy() - for i in range(0, len(response_text_lines_copy)): - response_text_lines_copy[i] = ' ' + response_text_lines_copy[i][1:] - try: - data = yaml.safe_load('\n'.join(response_text_lines_copy)) - get_logger().info(f"Successfully parsed AI prediction after removing leading '+'") - return data - except: - pass - - # sixth fallback - try to remove last lines - for i in range(1, len(response_text_lines)): - response_text_lines_tmp = '\n'.join(response_text_lines[:-i]) - try: - data = yaml.safe_load(response_text_lines_tmp) - get_logger().info(f"Successfully parsed AI prediction after removing {i} lines") - return data - except: - pass - - -def set_custom_labels(variables, git_provider=None): - if not get_settings().config.enable_custom_labels: - return - - labels = get_settings().get('custom_labels', {}) - if not labels: - # set default labels - labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other'] - labels_list = "\n - ".join(labels) if labels else "" - labels_list = f" - {labels_list}" if labels_list else "" - variables["custom_labels"] = labels_list - return - - # Set custom labels - variables["custom_labels_class"] = "class Label(str, Enum):" - counter = 0 - 
labels_minimal_to_labels_dict = {} - for k, v in labels.items(): - description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'" - # variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}" - variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}" - labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k - counter += 1 - variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict - -def get_user_labels(current_labels: List[str] = None): - """ - Only keep labels that has been added by the user - """ - try: - enable_custom_labels = get_settings().config.get('enable_custom_labels', False) - custom_labels = get_settings().get('custom_labels', []) - if current_labels is None: - current_labels = [] - user_labels = [] - for label in current_labels: - if label.lower() in ['bug fix', 'tests', 'enhancement', 'documentation', 'other']: - continue - if enable_custom_labels: - if label in custom_labels: - continue - user_labels.append(label) - if user_labels: - get_logger().debug(f"Keeping user labels: {user_labels}") - except Exception as e: - get_logger().exception(f"Failed to get user labels: {e}") - return current_labels - return user_labels - - -def get_max_tokens(model): - """ - Get the maximum number of tokens allowed for a model. - logic: - (1) If the model is in './pr_agent/algo/__init__.py', use the value from there. - (2) else, the user needs to define explicitly 'config.custom_model_max_tokens' - - For both cases, we further limit the number of tokens to 'config.max_model_tokens' if it is set. - This aims to improve the algorithmic quality, as the AI model degrades in performance when the input is too long. - """ - settings = get_settings() - if model in MAX_TOKENS: - max_tokens_model = MAX_TOKENS[model] - elif settings.config.custom_model_max_tokens > 0: - max_tokens_model = settings.config.custom_model_max_tokens - else: - raise Exception(f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens") - - if settings.config.max_model_tokens and settings.config.max_model_tokens > 0: - max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model) - return max_tokens_model - - -def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str: - """ - Clip the number of tokens in a string to a maximum number of tokens. - - Args: - text (str): The string to clip. - max_tokens (int): The maximum number of tokens allowed in the string. - add_three_dots (bool, optional): A boolean indicating whether to add three dots at the end of the clipped - Returns: - str: The clipped string. 
- """ - if not text: - return text - - try: - if num_input_tokens is None: - encoder = TokenEncoder.get_token_encoder() - num_input_tokens = len(encoder.encode(text)) - if num_input_tokens <= max_tokens: - return text - if max_tokens < 0: - return "" - - # calculate the number of characters to keep - num_chars = len(text) - chars_per_token = num_chars / num_input_tokens - factor = 0.9 # reduce by 10% to be safe - num_output_chars = int(factor * chars_per_token * max_tokens) - - # clip the text - if num_output_chars > 0: - clipped_text = text[:num_output_chars] - if delete_last_line: - clipped_text = clipped_text.rsplit('\n', 1)[0] - if add_three_dots: - clipped_text += "\n...(truncated)" - else: # if the text is empty - clipped_text = "" - - return clipped_text - except Exception as e: - get_logger().warning(f"Failed to clip tokens: {e}") - return text - -def replace_code_tags(text): - """ - Replace odd instances of ` with and even instances of ` with - """ - text = html.escape(text) - parts = text.split('`') - for i in range(1, len(parts), 2): - parts[i] = '' + parts[i] + '' - return ''.join(parts) - - -def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], - relevant_file: str, - relevant_line_in_file: str, - absolute_position: int = None) -> Tuple[int, int]: - position = -1 - if absolute_position is None: - absolute_position = -1 - re_hunk_header = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") - - if not diff_files: - return position, absolute_position - - for file in diff_files: - if file.filename and (file.filename.strip() == relevant_file): - patch = file.patch - patch_lines = patch.splitlines() - delta = 0 - start1, size1, start2, size2 = 0, 0, 0, 0 - if absolute_position != -1: # matching absolute to relative - for i, line in enumerate(patch_lines): - # new hunk - if line.startswith('@@'): - delta = 0 - match = re_hunk_header.match(line) - start1, size1, start2, size2 = map(int, match.groups()[:4]) - elif not line.startswith('-'): - delta += 1 - - # - absolute_position_curr = start2 + delta - 1 - - if absolute_position_curr == absolute_position: - position = i - break - else: - # try to find the line in the patch using difflib, with some margin of error - matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file, - patch_lines, n=3, cutoff=0.93) - if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'): - relevant_line_in_file = matches_difflib[0] - - - for i, line in enumerate(patch_lines): - if line.startswith('@@'): - delta = 0 - match = re_hunk_header.match(line) - start1, size1, start2, size2 = map(int, match.groups()[:4]) - elif not line.startswith('-'): - delta += 1 - - if relevant_line_in_file in line and line[0] != '-': - position = i - absolute_position = start2 + delta - 1 - break - - if position == -1 and relevant_line_in_file[0] == '+': - no_plus_line = relevant_line_in_file[1:].lstrip() - for i, line in enumerate(patch_lines): - if line.startswith('@@'): - delta = 0 - match = re_hunk_header.match(line) - start1, size1, start2, size2 = map(int, match.groups()[:4]) - elif not line.startswith('-'): - delta += 1 - - if no_plus_line in line and line[0] != '-': - # The model might add a '+' to the beginning of the relevant_line_in_file even if originally - # it's a context line - position = i - absolute_position = start2 + delta - 1 - break - return position, absolute_position - -def get_rate_limit_status(github_token) -> dict: - GITHUB_API_URL = 
get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com" - # GITHUB_API_URL = "https://api.github.com" - RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit" - HEADERS = { - "Accept": "application/vnd.github.v3+json", - "Authorization": f"token {github_token}" - } - - response = requests.get(RATE_LIMIT_URL, headers=HEADERS) - try: - rate_limit_info = response.json() - if rate_limit_info.get('message') == 'Rate limiting is not enabled.': # for github enterprise - return {'resources': {}} - response.raise_for_status() # Check for HTTP errors - except: # retry - time.sleep(0.1) - response = requests.get(RATE_LIMIT_URL, headers=HEADERS) - return response.json() - return rate_limit_info - - -def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool: - try: - rate_limit_status = get_rate_limit_status(github_token) - if installation_id: - get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}") - # validate that the rate limit is not exceeded - # validate that the rate limit is not exceeded - for key, value in rate_limit_status['resources'].items(): - if value['remaining'] < value['limit'] * threshold: - get_logger().error(f"key: {key}, value: {value}") - return False - return True - except Exception as e: - get_logger().error(f"Error in rate limit {e}", - artifact={"traceback": traceback.format_exc()}) - return True - - -def validate_and_await_rate_limit(github_token): - try: - rate_limit_status = get_rate_limit_status(github_token) - # validate that the rate limit is not exceeded - for key, value in rate_limit_status['resources'].items(): - if value['remaining'] < value['limit'] // 80: - get_logger().error(f"key: {key}, value: {value}") - sleep_time_sec = value['reset'] - datetime.now().timestamp() - sleep_time_hour = sleep_time_sec / 3600.0 - get_logger().error(f"Rate limit exceeded. 
Sleeping for {sleep_time_hour} hours") - if sleep_time_sec > 0: - time.sleep(sleep_time_sec + 1) - rate_limit_status = get_rate_limit_status(github_token) - return rate_limit_status - except: - get_logger().error("Error in rate limit") - return None - - -def github_action_output(output_data: dict, key_name: str): - try: - if not get_settings().get('github_action_config.enable_output', False): - return - - key_data = output_data.get(key_name, {}) - with open(os.environ['GITHUB_OUTPUT'], 'a') as fh: - print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh) - except Exception as e: - get_logger().error(f"Failed to write to GitHub Action output: {e}") - return - - -def show_relevant_configurations(relevant_section: str) -> str: - skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect", - 'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME'] extra_skip_keys = get_settings().config.get('config.skip_keys', []) if extra_skip_keys: skip_keys.extend(extra_skip_keys) diff --git a/ai_review/analyzers/ast_parser.py b/ai_review/analyzers/ast_parser.py index 8771874..a7a1703 100644 --- a/ai_review/analyzers/ast_parser.py +++ b/ai_review/analyzers/ast_parser.py @@ -1,11 +1,21 @@ import ast +from typing import List, Dict, Any class ASTAnalyzer(ast.NodeVisitor): + """AST分析器,用于检测代码质量问题""" + def __init__(self): - self.findings = [] - self.function_complexity = {} - - def visit_FunctionDef(self, node): + self.findings: List[str] = [] + self.function_complexity: Dict[str, int] = {} + self.references: List[ast.Name] = [] + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + """ + 分析函数定义,计算复杂度 + + Args: + node: 函数定义节点 + """ # 计算圈复杂度 complexity = len(node.body) # 计算函数复杂度 @@ -14,17 +24,56 @@ def visit_FunctionDef(self, node): self.function_complexity[node.name] = complexity self.generic_visit(node) - def visit_Name(self, node): - # 检测未使用变量 + def calculate_function_complexity(self, node: ast.FunctionDef) -> int: + """ + 计算函数的圈复杂度 + + Args: + node: 函数定义节点 + Returns: + int: 函数的圈复杂度 + """ + complexity = 1 # 基础复杂度 + + for child in ast.walk(node): + if isinstance(child, (ast.If, ast.While, ast.For, ast.ExceptHandler)): + complexity += 1 + elif isinstance(child, ast.BoolOp): + complexity += len(child.values) - 1 + + return complexity + + def visit_Name(self, node: ast.Name) -> None: + """ + 检查变量使用情况 + + Args: + node: 变量节点 + """ if isinstance(node.ctx, ast.Store): - if not any(ref.id == node.id for ref in self.references): - self.findings.append(f"未使用变量: {node.id}") + self._check_unused_variable(node) + + def visit_Assign(self, node: ast.Assign) -> None: + """ + 检查赋值语句 + + Args: + node: 赋值节点 + """ + for target in node.targets: + if isinstance(target, ast.Name) and isinstance(target.ctx, ast.Store): + self._check_unused_variable(target) + self.generic_visit(node) + + def _check_unused_variable(self, node: ast.Name) -> None: + """ + 检查变量是否未使用 + + Args: + node: 变量节点 + """ + if not any(ref.id == node.id for ref in self.references): + self.findings.append(f"未使用变量: {node.id}") - def visit_Assign(self, node): - # 检测未使用变量 - if isinstance(node.targets[0].ctx, ast.Store): - if not any(ref.id == node.targets[0].id for ref in self.references): - self.findings.append(f"未使用变量: {node.targets[0].id}") - def visit_Assert(self, node: ast.Assert) -> ast.Any: return super().visit_Assert(node) \ No newline at end of file diff --git a/ai_review/core/analyzer.py b/ai_review/core/analyzer.py 
index e7f2bdb..98e1e5b 100644 --- a/ai_review/core/analyzer.py +++ b/ai_review/core/analyzer.py @@ -1,64 +1,117 @@ -from typing import Dict, List +from typing import Dict, List, Callable, Optional, Any +from dataclasses import dataclass from .parser import ASTParser - import importlib -from typing import Dict, List, Callable +import logging +from pathlib import Path + +@dataclass +class Finding: + """Represents a code analysis finding.""" + rule: str + message: str + severity: str + line: int + code_snippet: Optional[str] = None class CodeAnalyzer: + """A code analyzer that applies multiple rules to analyze code quality.""" + def __init__(self, rule_modules: List[str]): - """Initialize the analyzer with rule modules.""" + """Initialize the analyzer with rule modules. + + Args: + rule_modules: List of module names containing rule functions. + + Raises: + ValueError: If no rule modules are specified. + ImportError: If a rule module cannot be imported. + RuntimeError: If there's an error loading rules from a module. + """ if not rule_modules: raise ValueError("No rule modules specified") + + self.logger = logging.getLogger(__name__) self.rules = self._load_rules(rule_modules) def _load_rules(self, modules: List[str]) -> Dict[str, Callable]: - """动态加载规则检测器""" - rules = {} + """Dynamically load rule functions from specified modules. + + Args: + modules: List of module names to load rules from. + + Returns: + Dictionary mapping rule names to their corresponding functions. + + Raises: + ImportError: If a module cannot be imported. + RuntimeError: If there's an error loading rules from a module. + """ + rules: Dict[str, Callable] = {} + for module_name in modules: try: module = importlib.import_module(module_name) for rule_name in dir(module): if rule_name.startswith('_'): continue + rule = getattr(module, rule_name) if callable(rule): rules[rule_name] = rule + except ImportError as e: + self.logger.error(f"Failed to import rule module {module_name}: {e}") raise ImportError(f"Failed to import rule module {module_name}: {e}") except Exception as e: + self.logger.error(f"Error loading rules from {module_name}: {e}") raise RuntimeError(f"Error loading rules from {module_name}: {e}") + return rules - def analyze(self, code: str) -> List[dict]: - """执行多维度代码审查""" + + def analyze(self, code: str) -> List[Finding]: + """Perform multi-dimensional code analysis. + + Args: + code: The source code to analyze. + + Returns: + List of Finding objects containing analysis results. 
+ """ try: ast_parser = ASTParser(code) except ValueError as e: - return [{ - 'rule': 'syntax_check', - 'message': str(e), - 'severity': 'critical', - 'line': 0 - }] + return [Finding( + rule='syntax_check', + message=str(e), + severity='critical', + line=0 + )] - findings = [] + findings: List[Finding] = [] + code_lines = ast_parser.raw_code.splitlines() - # 执行静态规则检查 + # Execute static rule checks for rule_name, check_func in self.rules.items(): try: if issues := check_func(ast_parser.tree): - findings.extend({ - 'rule': rule_name, - 'message': issue['msg'], - 'severity': issue['level'], - 'line': issue['lineno'], - 'code_snippet': ast_parser.raw_code.splitlines()[issue['lineno']-1] - } for issue in issues) + findings.extend( + Finding( + rule=rule_name, + message=issue['msg'], + severity=issue['level'], + line=issue['lineno'], + code_snippet=code_lines[issue['lineno']-1] if 0 <= issue['lineno']-1 < len(code_lines) else None + ) + for issue in issues + ) except Exception as e: - findings.append({ - 'rule': rule_name, - 'message': f"Rule execution failed: {e}", - 'severity': 'error', - 'line': 0 - }) + self.logger.error(f"Rule {rule_name} execution failed: {e}") + findings.append(Finding( + rule=rule_name, + message=f"Rule execution failed: {e}", + severity='error', + line=0 + )) return findings \ No newline at end of file