Commit 12cd414

jsondai authored and copybara-github committed
feat: GenAI Client(evals) - support setting autorater generation config for predefined rubric metrics
PiperOrigin-RevId: 833487047
1 parent dd4775b commit 12cd414

File tree (4 files changed, +123 −2 lines):

  tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py
  tests/unit/vertexai/genai/test_evals.py
  vertexai/_genai/_evals_metric_handlers.py
  vertexai/_genai/types/common.py
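Before the per-file diffs, here is a minimal usage sketch of what this change enables: attaching a judge-model GenerationConfig to a predefined rubric metric. It is a hedged illustration, not code from this commit — the `vertexai.Client(...)` construction and the project/location values are assumptions, while the metric, dataset, and config spellings come from the test diffs below.

```python
import pandas as pd
import vertexai  # assumed entry point; the tests below use a preconfigured client fixture
from vertexai import types
from google.genai import types as genai_types

# Placeholder project/location; replace with real values.
client = vertexai.Client(project="my-project", location="us-central1")

# New in this commit: predefined rubric metrics accept a generation config for
# the judge (autorater) model.
metric = types.RubricMetric.GENERAL_QUALITY(
    judge_model_generation_config=genai_types.GenerationConfig(
        temperature=0.1,
        max_output_tokens=1024,
    )
)

dataset = types.EvaluationDataset(
    eval_dataset_df=pd.DataFrame([{"prompt": "p1", "response": "r1"}]),
    candidate_name="gemini-2.5-flash",
)

result = client.evals.evaluate(dataset=dataset, metrics=[metric])
```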

tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py

Lines changed: 45 additions & 0 deletions
@@ -16,6 +16,7 @@

 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai import types
+from google.genai import types as genai_types
 import pandas as pd


@@ -60,6 +61,50 @@ def test_evaluation_result(client):
         assert case_result.response_candidate_results is not None


+def test_evaluation_result_with_autorater_config(client):
+    """Tests that evaluate() produces a correctly structured EvaluationResult."""
+    prompts_df = pd.DataFrame(
+        {
+            "prompt": ["Explain the concept of machine learning in simple terms."],
+            "response": [
+                "Machine learning is a type of artificial intelligence that allows"
+                " computers to learn from data without being explicitly programmed."
+            ],
+        }
+    )
+
+    eval_dataset = types.EvaluationDataset(
+        eval_dataset_df=prompts_df,
+        candidate_name="gemini-2.5-flash",
+    )
+
+    predefined_metric_with_autorater_config = types.RubricMetric.GENERAL_QUALITY(
+        judge_model_generation_config=genai_types.GenerationConfig(
+            temperature=0.1,
+            max_output_tokens=1024,
+        )
+    )
+
+    evaluation_result = client.evals.evaluate(
+        dataset=eval_dataset,
+        metrics=[predefined_metric_with_autorater_config],
+    )
+
+    assert isinstance(evaluation_result, types.EvaluationResult)
+
+    assert evaluation_result.summary_metrics is not None
+    for summary in evaluation_result.summary_metrics:
+        assert isinstance(summary, types.AggregatedMetricResult)
+        assert summary.metric_name == "general_quality_v1"
+        assert summary.mean_score is not None
+
+    assert evaluation_result.eval_case_results is not None
+    for case_result in evaluation_result.eval_case_results:
+        assert isinstance(case_result, types.EvalCaseResult)
+        assert case_result.eval_case_index is not None
+        assert case_result.response_candidate_results is not None
+
+
 def test_multi_turn_predefined_metric(client):
     """Tests that evaluate works with multi-turn predefined metrics."""
     prompts_data = {
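The assertions in the new replay test double as documentation of the result shape. A hedged sketch of reading those results back out, limited to the fields the test actually exercises (metric_name, mean_score, eval_case_index, response_candidate_results):

```python
# evaluation_result is the return value of client.evals.evaluate(...) above.
for summary in evaluation_result.summary_metrics or []:
    # For RubricMetric.GENERAL_QUALITY the metric name resolves to "general_quality_v1".
    print(f"{summary.metric_name}: mean score = {summary.mean_score}")

for case_result in evaluation_result.eval_case_results or []:
    candidates = case_result.response_candidate_results or []
    print(f"case {case_result.eval_case_index}: {len(candidates)} candidate result(s)")
```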

tests/unit/vertexai/genai/test_evals.py

Lines changed: 48 additions & 0 deletions
@@ -189,6 +189,54 @@ def test_eval_evaluate_with_agent_info(self, mock_execute_evaluation):
         assert "agent_info" in kwargs
         assert kwargs["agent_info"] == agent_info

+    def test_evaluate_predefined_metric_with_autorater_config(self):
+        dataset = vertexai_genai_types.EvaluationDataset(
+            eval_dataset_df=pd.DataFrame([{"prompt": "p1", "response": "r1"}])
+        )
+        generation_config = genai_types.GenerationConfig(
+            temperature=0.1,
+            max_output_tokens=1024,
+        )
+        metrics = [
+            vertexai_genai_types.RubricMetric.GENERAL_QUALITY(
+                judge_model_generation_config=generation_config
+            )
+        ]
+
+        mock_prebuilt_metric = vertexai_genai_types.LLMMetric(
+            name="general_quality_v1",
+            prompt_template="Is this quality? {response}",
+        )
+        mock_prebuilt_metric._is_predefined = True
+        mock_prebuilt_metric._config_source = (
+            "gs://mock-metrics/general_quality/v1.yaml"
+        )
+        mock_prebuilt_metric._version = "v1"
+
+        with mock.patch(
+            "vertexai._genai._evals_metric_loaders.LazyLoadedPrebuiltMetric._fetch_and_parse",
+            return_value=mock_prebuilt_metric,
+        ), mock.patch(
+            "vertexai._genai.evals.Evals._evaluate_instances"
+        ) as mock_evaluate_instances, mock.patch(
+            "vertexai._genai._evals_metric_handlers._evals_constant.SUPPORTED_PREDEFINED_METRICS",
+            frozenset(["general_quality_v1"]),
+        ):
+            mock_evaluate_instances.return_value = (
+                vertexai_genai_types.EvaluateInstancesResponse(
+                    metric_results=[vertexai_genai_types.MetricResult(score=0.9)]
+                )
+            )
+            self.client.evals.evaluate(
+                dataset=dataset,
+                metrics=metrics,
+            )
+
+            mock_evaluate_instances.assert_called_once()
+            _, kwargs = mock_evaluate_instances.call_args
+            assert "autorater_config" in kwargs
+            assert kwargs["autorater_config"].generation_config == generation_config
+

 class TestEvalsVisualization:
     @mock.patch(

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 23 additions & 2 deletions
@@ -620,6 +620,10 @@ def _add_autorater_config(self, payload: dict[str, Any]) -> None:
         autorater_config = {}
         if self.metric.judge_model:
             autorater_config["autorater_model"] = self.metric.judge_model
+        if self.metric.judge_model_generation_config:
+            autorater_config["generation_config"] = (
+                self.metric.judge_model_generation_config
+            )
         if self.metric.judge_model_sampling_count:
             autorater_config["sampling_count"] = self.metric.judge_model_sampling_count  # type: ignore[assignment]

@@ -986,10 +990,25 @@ def _build_request_payload(
             agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
         )

-        return {
+        request_payload = {
             "instance": instance_payload,
         }

+        autorater_config = {}
+        if self.metric.judge_model:
+            autorater_config["autorater_model"] = self.metric.judge_model
+        if self.metric.judge_model_generation_config:
+            autorater_config["generation_config"] = (
+                self.metric.judge_model_generation_config
+            )
+        if self.metric.judge_model_sampling_count:
+            autorater_config["sampling_count"] = self.metric.judge_model_sampling_count
+        if autorater_config:
+            request_payload["autorater_config"] = genai_types.AutoraterConfig(
+                **autorater_config
+            )
+        return request_payload
+
     @override
     def get_metric_result(
         self, eval_case: types.EvalCase, response_index: int
@@ -1001,7 +1020,9 @@ def get_metric_result(
         for attempt in range(_MAX_RETRIES):
             try:
                 api_response = self.module._evaluate_instances(
-                    metrics=[self.metric], instance=payload.get("instance")
+                    metrics=[self.metric],
+                    instance=payload.get("instance"),
+                    autorater_config=payload.get("autorater_config"),
                 )
                 break
             except genai_errors.ClientError as e:
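In short, the handler now folds the new generation config into the same optional autorater_config block as the existing judge_model and judge_model_sampling_count knobs, and only attaches that block when at least one of them is set. A hedged, standalone restatement of that mapping (the real logic lives in PredefinedMetricHandler._build_request_payload above; the helper name here is invented for illustration):

```python
from typing import Any, Optional


def collect_autorater_config_kwargs(metric: Any) -> Optional[dict[str, Any]]:
    """Gathers the judge-model knobs set on a metric (illustrative helper only).

    _build_request_payload wraps a non-empty result in genai_types.AutoraterConfig
    and passes it to _evaluate_instances next to the instance payload.
    """
    config: dict[str, Any] = {}
    if metric.judge_model:
        config["autorater_model"] = metric.judge_model
    if metric.judge_model_generation_config:  # new in this commit
        config["generation_config"] = metric.judge_model_generation_config
    if metric.judge_model_sampling_count:
        config["sampling_count"] = metric.judge_model_sampling_count
    return config or None  # omit the block entirely when no knob is set
```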

vertexai/_genai/types/common.py

Lines changed: 7 additions & 0 deletions
@@ -2622,6 +2622,10 @@ class Metric(_common.BaseModel):
     judge_model: Optional[str] = Field(
         default=None, description="""The judge model for the metric."""
     )
+    judge_model_generation_config: Optional[genai_types.GenerationConfig] = Field(
+        default=None,
+        description="""The generation config for the judge LLM (temperature, top_k, top_p, etc).""",
+    )
     judge_model_sampling_count: Optional[int] = Field(
         default=None, description="""The sampling count for the judge model."""
     )
@@ -2825,6 +2829,9 @@ class MetricDict(TypedDict, total=False):
     judge_model: Optional[str]
     """The judge model for the metric."""

+    judge_model_generation_config: Optional[genai_types.GenerationConfigDict]
+    """The generation config for the judge LLM (temperature, top_k, top_p, etc)."""
+
     judge_model_sampling_count: Optional[int]
     """The sampling count for the judge model."""
