diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 4a46103ec..e5201ffe4 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -1531,6 +1531,74 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="gemini-2.5-flash-preview-09-2025",
+        roles=("user", "model"),
+        sep_style=SeparatorStyle.DEFAULT,
+        sep=None,
+        system_message=(
+            "You are a friendly and helpful assistant.\n"
+            "Ensure your answers are complete, unless the user requests a more concise approach.\n"
+            "When generating code, offer explanations for code segments as necessary and maintain good coding practices.\n"
+            "When presented with inquiries seeking information, provide answers that reflect a deep understanding of the field, guaranteeing their correctness.\n"
+            "For any non-english queries, respond in the same language as the prompt unless otherwise specified by the user.\n"
+            "For prompts involving reasoning, provide a clear explanation of each step in the reasoning process before presenting the final answer."
+        ),
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="gemini-2.5-flash-preview-05-20",
+        roles=("user", "model"),
+        sep_style=SeparatorStyle.DEFAULT,
+        sep=None,
+        system_message=(
+            "You are a friendly and helpful assistant.\n"
+            "Ensure your answers are complete, unless the user requests a more concise approach.\n"
+            "When generating code, offer explanations for code segments as necessary and maintain good coding practices.\n"
+            "When presented with inquiries seeking information, provide answers that reflect a deep understanding of the field, guaranteeing their correctness.\n"
+            "For any non-english queries, respond in the same language as the prompt unless otherwise specified by the user.\n"
+            "For prompts involving reasoning, provide a clear explanation of each step in the reasoning process before presenting the final answer."
+        ),
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="gemini-2.5-flash",
+        roles=("user", "model"),
+        sep_style=SeparatorStyle.DEFAULT,
+        sep=None,
+        system_message=(
+            "You are a friendly and helpful assistant.\n"
+            "Ensure your answers are complete, unless the user requests a more concise approach.\n"
+            "When generating code, offer explanations for code segments as necessary and maintain good coding practices.\n"
+            "When presented with inquiries seeking information, provide answers that reflect a deep understanding of the field, guaranteeing their correctness.\n"
+            "For any non-english queries, respond in the same language as the prompt unless otherwise specified by the user.\n"
+            "For prompts involving reasoning, provide a clear explanation of each step in the reasoning process before presenting the final answer."
+        ),
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="gemini-2.5-pro",
+        roles=("user", "model"),
+        sep_style=SeparatorStyle.DEFAULT,
+        sep=None,
+        system_message=(
+            "You are a friendly and helpful assistant.\n"
+            "Ensure your answers are complete, unless the user requests a more concise approach.\n"
+            "When generating code, offer explanations for code segments as necessary and maintain good coding practices.\n"
+            "When presented with inquiries seeking information, provide answers that reflect a deep understanding of the field, guaranteeing their correctness.\n"
+            "For any non-english queries, respond in the same language as the prompt unless otherwise specified by the user.\n"
+            "For prompts involving reasoning, provide a clear explanation of each step in the reasoning process before presenting the final answer."
+        ),
+    )
+)
+
 # BiLLa default template
 register_conv_template(
     Conversation(
diff --git a/fastchat/llm_judge/common.py b/fastchat/llm_judge/common.py
index d2640d601..63891fab2 100644
--- a/fastchat/llm_judge/common.py
+++ b/fastchat/llm_judge/common.py
@@ -13,11 +13,13 @@
 
 import openai
 import anthropic
+from google import genai
 
 from fastchat.model.model_adapter import (
     get_conversation_template,
     ANTHROPIC_MODEL_LIST,
     OPENAI_MODEL_LIST,
+    GOOGLE_MODEL_LIST,
 )
 
 # API setting constants
@@ -169,6 +171,10 @@ def run_judge_single(question, answer, judge, ref_answer, multi_turn=False):
         judgment = chat_completion_anthropic(
             model, conv, temperature=0, max_tokens=1024
         )
+    elif model in GOOGLE_MODEL_LIST:
+        judgment = chat_completion_google(
+            model, conv, temperature=0, max_tokens=2048
+        )
     else:
         raise ValueError(f"Invalid judge model name: {model}")
 
@@ -493,6 +499,35 @@ def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=Non
     return output.strip()
 
 
+def chat_completion_google(model, conv, temperature, max_tokens, api_dict=None):
+    if api_dict is not None and "api_key" in api_dict:
+        api_key = api_dict["api_key"]
+    else:
+        api_key = os.environ["GOOGLE_API_KEY"]
+
+    output = API_ERROR_OUTPUT
+    for _ in range(API_MAX_RETRY):
+        try:
+            client = genai.Client(api_key=api_key)
+            prompt = conv.get_prompt()
+            response = client.models.generate_content(
+                model=model,
+                contents=prompt,
+                config={
+                    "max_output_tokens": max_tokens,
+                    "temperature": temperature,
+                },
+            )
+            output = response.text
+            if output is None:
+                output = ""
+            break
+        except genai.errors.APIError as e:
+            print(type(e), e)
+            time.sleep(API_RETRY_SLEEP)
+    return output.strip()
+
+
 def chat_completion_palm(chat_state, model, conv, temperature, max_tokens):
     from fastchat.serve.api_provider import init_palm_chat
 
diff --git a/fastchat/llm_judge/gen_api_answer.py b/fastchat/llm_judge/gen_api_answer.py
index 8f9c62624..03c20ccbb 100644
--- a/fastchat/llm_judge/gen_api_answer.py
+++ b/fastchat/llm_judge/gen_api_answer.py
@@ -18,10 +18,11 @@
     temperature_config,
     chat_completion_openai,
     chat_completion_anthropic,
+    chat_completion_google,
     chat_completion_palm,
 )
 from fastchat.llm_judge.gen_model_answer import reorg_answer_file
-from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST
+from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST, GOOGLE_MODEL_LIST
 
 
 def get_answer(
@@ -55,6 +56,8 @@ def get_answer(
         chat_state, output = chat_completion_palm(
             chat_state, model, conv, temperature, max_tokens
         )
+    elif model in GOOGLE_MODEL_LIST:
+        output = chat_completion_google(model, conv, temperature, max_tokens)
     else:
         output = chat_completion_openai(model, conv, temperature, max_tokens)
 
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 16cf5d2b6..ce8f3dc8d 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -93,6 +93,12 @@
     "o1-mini",
 )
 
+GOOGLE_MODEL_LIST = (
+    "gemini-2.5-flash-preview-09-2025",
+    "gemini-2.5-flash",
+    "gemini-2.5-pro",
+)
+
 
 class BaseModelAdapter:
     """The base and the default model adapter."""
@@ -2244,6 +2250,10 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
     def get_default_conv_template(self, model_path: str) -> Conversation:
         if "gemini-1.5-pro" in model_path:
             return get_conv_template("gemini-1.5-pro")
+        elif "gemini-2.5-flash" in model_path:
+            return get_conv_template("gemini-2.5-flash")
+        elif "gemini-2.5-pro" in model_path:
+            return get_conv_template("gemini-2.5-pro")
         return get_conv_template("gemini")
diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index 2eed9649e..8634574ef 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -91,6 +91,10 @@ def get_model_info(name: str) -> ModelInfo:
 
 register_model_info(
     [
+        "gemini-2.5-flash-preview-09-2025",
+        "gemini-2.5-flash-preview-05-20",
+        "gemini-2.5-flash",
+        "gemini-2.5-pro",
         "gemini-1.5-pro-exp-0827",
         "gemini-1.5-pro-exp-0801",
         "gemini-1.5-flash-exp-0827",
diff --git a/pyproject.toml b/pyproject.toml
index 916aaeae0..9bfe1ed19 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
 model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"]
 webui = ["gradio>=4.10", "plotly", "scipy"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
-llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
+llm_judge = ["openai<1", "anthropic>=0.3", "ray", "google-genai"]
 dev = ["black==23.3.0", "pylint==2.8.2"]
 
 [project.urls]
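For reviewers, a minimal smoke-test sketch of the new Google path. It is not part of the diff above; it assumes the google-genai extra is installed, GOOGLE_API_KEY is exported, and the "gemini-2.5-pro" template registered in this change is available.

# Hypothetical example, not included in this change: exercises chat_completion_google
# end to end with one of the newly registered Gemini 2.5 conversation templates.
from fastchat.llm_judge.common import chat_completion_google
from fastchat.model.model_adapter import get_conversation_template

# Resolves to the "gemini-2.5-pro" template via the new elif branch in GeminiAdapter.
conv = get_conversation_template("gemini-2.5-pro")
conv.append_message(conv.roles[0], "In one sentence, what does FastChat do?")
conv.append_message(conv.roles[1], None)

# Reads GOOGLE_API_KEY from the environment; pass api_dict={"api_key": "..."} to override.
print(chat_completion_google("gemini-2.5-pro", conv, temperature=0, max_tokens=256))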