Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 96 additions & 51 deletions ai_review/algo/language_handler.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,53 @@
# Language Selection, source: https://github.com/bigcode-project/bigcode-dataset/blob/main/language_selection/programming-languages-to-file-extensions.json # noqa E501
from typing import Dict
from typing import Dict, List, TypeVar, Optional, Set
from dataclasses import dataclass
from collections import defaultdict
from pathlib import Path

from ai_review.config_loader import get_settings

FileType = TypeVar('FileType')  # type variable for file objects (expected to expose a `.filename` attribute — see filter_bad_extensions)

def filter_bad_extensions(files):
# Bad Extensions, source: https://github.com/EleutherAI/github-downloader/blob/345e7c4cbb9e0dc8a0615fd995a08bf9d73b3fe6/download_repo_text.py # noqa: E501
bad_extensions = get_settings().bad_extensions.default
if get_settings().config.use_extra_bad_extensions:
bad_extensions += get_settings().bad_extensions.extra
return [f for f in files if f.filename is not None and is_valid_file(f.filename, bad_extensions)]
@dataclass
class LanguageGroup:
    """A group of files that share the same main programming language."""
    language: str  # language name (e.g. "Python"), or "Other" for unmatched files
    files: List[FileType]  # files whose extension maps to this language

def get_file_extension(filename: str) -> Optional[str]:
    """Return the lower-cased extension of *filename*, including the leading dot.

    Args:
        filename: File name or path to inspect.
    Returns:
        The extension (e.g. ``".py"``), or ``None`` when *filename* is empty
        or has no extension.
    """
    if filename:
        suffix = Path(filename).suffix
        if suffix:
            return suffix.lower()
    return None

def get_bad_extensions() -> Set[str]:
    """Collect the set of file extensions that should be filtered out.

    Starts from the default bad-extensions list in settings and, when
    ``config.use_extra_bad_extensions`` is enabled, merges in the extra list.

    Returns:
        The combined set of bad extensions, as configured in settings.
    """
    cfg = get_settings()
    extensions = set(cfg.bad_extensions.default)
    if cfg.config.use_extra_bad_extensions:
        extensions |= set(cfg.bad_extensions.extra)
    return extensions

def filter_bad_extensions(files: List[FileType]) -> List[FileType]:
    """Filter out files whose extension is on the bad-extensions list.

    Files without a filename are dropped. Files without an extension are
    kept — there is no extension to match against the bad list (this also
    matches the behavior of is_valid_file, which keeps such files).

    Args:
        files: Files to filter; each must expose a ``filename`` attribute.
    Returns:
        The files whose extension is not a bad extension.
    """
    bad_extensions = get_bad_extensions()

    def _is_allowed(filename: str) -> bool:
        # get_file_extension() returns None for extension-less names; the
        # previous `.strip('.')` call on that None raised AttributeError.
        ext = get_file_extension(filename)
        return ext is None or ext.lstrip('.') not in bad_extensions

    return [f for f in files if f.filename and _is_allowed(f.filename)]


def is_valid_file(filename:str, bad_extensions=None) -> bool:
Expand All @@ -22,49 +60,56 @@ def is_valid_file(filename:str, bad_extensions=None) -> bool:
return filename.split('.')[-1] not in bad_extensions


def sort_files_by_main_languages(languages: Dict[str, int], files: List[FileType]) -> List[LanguageGroup]:
    """Group files by the repository's main languages.

    Languages are ranked by their usage count; each file is assigned to the
    language that owns its extension. Files whose extension belongs to no
    main language are collected under a trailing "Other" group.

    Args:
        languages: Mapping of language name -> usage count.
        files: Files to classify; each must expose a ``filename`` attribute.
    Returns:
        LanguageGroup entries ordered by language usage (descending), with an
        optional trailing "Other" group.
    """
    # No language statistics: everything (minus bad extensions) is "Other".
    if not languages:
        return [LanguageGroup(language="Other", files=filter_bad_extensions(files))]

    settings = get_settings()
    language_map = {k.lower(): set(v) for k, v in settings.language_extension_map_org.items()}

    # Map each known extension to the most-used language that claims it.
    # setdefault ensures that when two languages share an extension, the
    # higher-ranked (more used) language keeps it instead of being
    # overwritten by a later, less-used language.
    ext_to_lang: Dict[str, str] = {}
    all_main_extensions: Set[str] = set()
    for lang, _count in sorted(languages.items(), key=lambda item: item[1], reverse=True):
        for ext in language_map.get(lang.lower(), ()):
            ext = ext.lower()
            ext_to_lang.setdefault(ext, lang)
            all_main_extensions.add(ext)

    # Classify files by extension. filter_bad_extensions already drops
    # files with no filename. A None extension (extension-less file) never
    # matches a main extension and falls into the "Other" bucket.
    lang_groups: Dict[str, List[FileType]] = defaultdict(list)
    other_files: List[FileType] = []
    for file in filter_bad_extensions(files):
        ext = get_file_extension(file.filename)
        if ext in ext_to_lang:
            lang_groups[ext_to_lang[ext]].append(file)
        elif ext not in all_main_extensions:
            other_files.append(file)

    # Emit groups ordered by language usage, then the "Other" bucket.
    result = [
        LanguageGroup(language=lang, files=grouped)
        for lang, grouped in sorted(
            lang_groups.items(),
            key=lambda item: languages.get(item[0], 0),
            reverse=True,
        )
    ]
    if other_files:
        result.append(LanguageGroup(language="Other", files=other_files))
    return result
60 changes: 29 additions & 31 deletions ai_review/algo/token_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,56 +37,54 @@ class TokenHandler:
method.
"""

def __init__(self, pr=None, vars: dict = None, system: str = "", user: str = ""):
    """Initialize the TokenHandler.

    Args:
        pr: Pull request object; when None, ``prompt_tokens`` defaults to 0.
        vars: Variables used to render the prompt templates (defaults to {}).
        system: System prompt template string.
        user: User prompt template string.
    """
    self.vars = vars or {}  # normalize None (or any falsy value) to a fresh empty dict
    self.encoder = TokenEncoder.get_token_encoder()
    # Pre-compute the token count of the rendered system+user prompts.
    self.prompt_tokens = (self._get_system_user_tokens(pr, self.encoder, self.vars, system, user)
                          if pr is not None else 0)

def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
def _get_system_user_tokens(self, pr, encoder, vars: dict, system: str, user: str) -> int:
    """Count the tokens of the rendered system and user prompts.

    Args:
        pr: Pull request object (not used here; kept for interface stability).
        encoder: tiktoken encoder instance.
        vars: Variables used to render the Jinja2 templates.
        system: System prompt template string.
        user: User prompt template string.
    Returns:
        Total token count of both rendered prompts, or 0 on a rendering or
        encoding error (logged, not raised).
    """
    try:
        # StrictUndefined makes missing template variables raise instead of
        # silently rendering as empty strings.
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(system).render(vars)
        user_prompt = environment.from_string(user).render(vars)
        return len(encoder.encode(system_prompt)) + len(encoder.encode(user_prompt))
    except Exception as e:
        get_logger().error(f"Error in _get_system_user_tokens: {str(e)}")
        return 0

def count_tokens(self, patch: str) -> int:
    """Count the number of tokens in a patch string.

    Args:
        patch: The patch text to tokenize.
    Returns:
        Number of tokens in the patch.
    """
    # disallowed_special=() lets text resembling special tokens be encoded
    # as ordinary text instead of raising an error.
    return len(self.encoder.encode(patch, disallowed_special=()))
Loading