豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit c8a75d8

Browse files
ktran authored and Devtools-frontend LUCI CQ committed
[ai-eval] Split data collection from stats computation
This is a cleanup to split the data collection part from the stats computation.

Bug: 475195894
Change-Id: I5533d158723ad659b0ab3fa844dbdc3b8aa66084
Reviewed-on: https://chromium-review.googlesource.com/c/devtools/devtools-frontend/+/7548045
Commit-Queue: Kim-Anh Tran <kimanh@chromium.org>
Reviewed-by: Kateryna Prokopenko <kprokopenko@chromium.org>
1 parent ac222e6 commit c8a75d8

File tree

1 file changed

+175
-142
lines changed

1 file changed

+175
-142
lines changed

scripts/ai_assistance/suite/helpers/evaluators.ts

Lines changed: 175 additions & 142 deletions
Original file line number | Diff line number | Diff line change
@@ -314,112 +314,78 @@ function calculateStandardDeviation(values: number[]): number {
314314
return Math.sqrt(variance);
315315
}
316316

317+
/**
318+
* Calculate the overall score for a conversation based on rubric importance.
319+
*/
320+
function calculateWeightedScore(rubricScores: RubricScore[], weights: RubricWeights): number {
321+
let totalWeightedScore = 0;
322+
let totalWeight = 0;
323+
for (const {rubric, score} of rubricScores) {
324+
const weight = weights[rubric] ?? IMPORTANCE_WEIGHTS.minor;
325+
totalWeightedScore += score * weight;
326+
totalWeight += weight;
327+
}
328+
return totalWeight > 0 ? totalWeightedScore / totalWeight : 0;
329+
}
330+
317331
export async function itEval(config: ItEval): Promise<void> {
318332
const state = stateStorage.getStore();
319333
assert.ok(state);
320-
if ('succeed' in config) {
321-
for (const [date, outputs] of Object.entries(state.outputsByDate)) {
322-
if (!outputs) {
323-
continue;
324-
}
325334

326-
const allDevToolsConversations = outputs.flatMap(o => o.contents.conversations);
335+
let goldenText = '';
336+
if ('rouge' in config) {
337+
const golden = await getGolden(state.store.type, state.store.label);
338+
goldenText = golden?.queries.at(-1)?.response.text ?? '';
339+
}
327340

328-
let total = 0;
329-
let succeeded = 0;
330-
for (const conversation of allDevToolsConversations) {
331-
total++;
332-
if (config.succeed(conversation)) {
333-
succeeded++;
334-
}
335-
state.store.saveResult(config.test, date, {type: 'BINARY', success: succeeded, total});
336-
}
341+
for (const [date, outputs] of Object.entries(state.outputsByDate)) {
342+
if (!outputs) {
343+
continue;
337344
}
338-
} else if ('judge' in config) {
339-
for (const [date, outputs] of Object.entries(state.outputsByDate)) {
340-
if (!outputs) {
341-
continue;
342-
}
343-
const allDevToolsConversations = outputs.flatMap(o => o.contents.conversations);
345+
const conversations = outputs.flatMap(o => o.contents.conversations);
346+
347+
if ('succeed' in config) {
348+
const details = conversations.map(conversation => ({
349+
success: config.succeed(conversation),
350+
conversation,
351+
}));
352+
state.store.saveResult(config.test, date, {type: 'BINARY', details});
353+
} else if ('judge' in config) {
344354
const repeatCount = argv.repeat;
345-
346-
// Collect scores from all examples.
347-
const results = await Promise.all(allDevToolsConversations.flatMap(example => {
348-
return Array.from({length: repeatCount}, () => geminiLimiter.run(() => config.judge(example)));
349-
}));
350-
351-
// Calculate stats for each rubric (take the average for each rubric among multiple conversations)
352-
const inputCount = allDevToolsConversations.length;
353-
const statsByRubric: Record<RubricName, RubricStats> = {};
354-
const allRubrics = new Set(results.flatMap(r => r.rubricScores.map(i => i.rubric)));
355-
for (const rubric of allRubrics) {
356-
const scores = results.flatMap(r => r.rubricScores.filter(i => i.rubric === rubric).map(i => i.score));
357-
const total = scores.reduce((acc, score) => acc + score, 0);
358-
359-
const average = scores.length ? total / scores.length : 0;
360-
const standardDeviation = calculateStandardDeviation(scores);
361-
statsByRubric[rubric] = {average, standardDeviation, allScores: scores};
362-
}
363-
364-
// Weight all rubrics (take the first result's weights, as they should be the same for all results).
365-
const weights = results.length > 0 ? results[0].rubricWeights : {};
366-
const allWeightedScores: number[] = [];
367-
for (const rubricScores of results) {
368-
let totalWeightedScore = 0;
369-
let totalWeight = 0;
370-
for (const {rubric, score} of rubricScores.rubricScores) {
371-
const weight = weights[rubric] ?? IMPORTANCE_WEIGHTS.minor;
372-
totalWeightedScore += score * weight;
373-
totalWeight += weight;
374-
}
375-
if (totalWeight > 0) {
376-
allWeightedScores.push(totalWeightedScore / totalWeight);
377-
}
378-
}
379-
380-
// Calculate average and standard deviation for the overall score.
381-
const overallAverage = allWeightedScores.length ?
382-
allWeightedScores.reduce((acc, score) => acc + score, 0) / allWeightedScores.length :
383-
0;
384-
const overallStandardDeviation = calculateStandardDeviation(allWeightedScores);
385-
const overallStats: RubricStats = {
386-
average: overallAverage,
387-
standardDeviation: overallStandardDeviation,
388-
allScores: allWeightedScores,
389-
};
355+
const scoredEvals =
356+
conversations.flatMap(conversation => Array.from({length: repeatCount}, () => geminiLimiter.run(async () => {
357+
const res = await config.judge(conversation);
358+
return {
359+
conversation,
360+
rubricScores: res.rubricScores,
361+
rubricWeights: res.rubricWeights,
362+
};
363+
})));
364+
const results = await Promise.all(scoredEvals);
365+
const rubricWeights = results.length > 0 ? results[0].rubricWeights : {};
390366

391367
state.store.saveResult(config.test, date, {
392368
type: 'JUDGE',
393-
statsByRubric,
394-
overallStats,
395-
inputCount,
396369
repetitionCount: repeatCount,
370+
rubricWeights,
371+
details: results.map(r => ({
372+
conversation: r.conversation,
373+
rubricScores: r.rubricScores,
374+
})),
397375
});
398-
}
399-
} else if ('rouge' in config) {
400-
const golden = await getGolden(state.store.type, state.store.label);
401-
const goldenText = golden?.queries.at(-1)?.response.text ?? '';
402-
403-
for (const [date, outputs] of Object.entries(state.outputsByDate)) {
404-
if (!outputs) {
405-
continue;
406-
}
407-
const allDevToolsConversations = outputs.flatMap(o => o.contents.conversations);
408-
const scores: number[] = [];
409-
410-
for (const conversation of allDevToolsConversations) {
376+
} else if ('rouge' in config) {
377+
const details = conversations.map(conversation => {
411378
const candidateText = conversation.queries.at(-1)?.response.text ?? '';
412-
const score = ROUGE.score(candidateText, goldenText);
413-
scores.push(score);
414-
}
415-
416-
const total = scores.reduce((acc, score) => acc + score, 0);
417-
const average = scores.length ? total / scores.length : 0;
418-
const standardDeviation = calculateStandardDeviation(scores);
379+
return {
380+
conversation,
381+
score: ROUGE.score(candidateText, goldenText),
382+
goldenResponse: goldenText,
383+
};
384+
});
419385

420386
state.store.saveResult(config.test, date, {
421387
type: 'ROUGE',
422-
stats: {average, standardDeviation, allScores: scores},
388+
details,
423389
});
424390
}
425391
}
@@ -443,9 +409,9 @@ export async function evalGroup(config: GroupConfig, cb: (() => Promise<void>)):
443409

444410
await stateStorage.run(state, async () => {
445411
await cb();
446-
console.log(state.logs.join('\n'));
447-
printResults(state.store);
448412
});
413+
log(0, state.logs.join('\n'));
414+
printResults(state.store);
449415
}
450416

451417
function log(indentation: number, message: string): void {
@@ -462,69 +428,72 @@ function formatRubricStats(stats: RubricStats): string {
462428
return `${stats.average.toFixed(2)} (mean) ±${stats.standardDeviation.toFixed(2)}`;
463429
}
464430

465-
function formatJudgeResult(result: Extract<Result, {type: 'JUDGE'}>, row: string): string {
431+
function formatJudgeResult(stats: JudgeStats, row: string, repetitionCount: number): string {
466432
if (row === NUM_CONVERSATIONS) {
467-
return String(result.inputCount);
433+
return String(stats.inputCount);
468434
}
469435
if (row === NUM_EVALS_PER_CONVERSATION) {
470-
return String(result.repetitionCount);
436+
return String(repetitionCount);
471437
}
472438
if (row === OVERALL_STATS) {
473-
return formatRubricStats(result.overallStats);
439+
return formatRubricStats(stats.overallStats);
440+
}
441+
const rubricStats = stats.statsByRubric[row];
442+
return rubricStats ? formatRubricStats(rubricStats) : '-';
443+
}
444+
445+
function populateTableData(tableData: Record<string, Record<string, string>>, date: string, result: Result): void {
446+
const stats = calculateStats(result);
447+
if (!stats) {
448+
return;
449+
}
450+
451+
switch (result.type) {
452+
case 'BINARY': {
453+
if (!tableData[PASS_RATE]) {
454+
tableData[PASS_RATE] = {};
455+
}
456+
const binary = stats as BinaryStats;
457+
tableData[PASS_RATE][date] = `${binary.success} / ${binary.total} passed`;
458+
break;
459+
}
460+
case 'JUDGE': {
461+
const judge = stats as JudgeStats;
462+
const rubrics = Object.keys(judge.statsByRubric).sort();
463+
for (const row of [OVERALL_STATS, ...rubrics, NUM_CONVERSATIONS, NUM_EVALS_PER_CONVERSATION]) {
464+
if (!tableData[row]) {
465+
tableData[row] = {};
466+
}
467+
tableData[row][date] = formatJudgeResult(judge, row, result.repetitionCount);
468+
}
469+
break;
470+
}
471+
case 'ROUGE': {
472+
if (!tableData[ROUGE_L_SUM]) {
473+
tableData[ROUGE_L_SUM] = {};
474+
}
475+
const rouge = stats as RougeStats;
476+
tableData[ROUGE_L_SUM][date] = formatRubricStats(rouge);
477+
break;
478+
}
474479
}
475-
const stats = result.statsByRubric[row];
476-
return stats ? formatRubricStats(stats) : '-';
477480
}
478481

479482
function printResults(store: ResultStore): void {
480483
log(0, `Results for: ${store.type}/${store.label}`);
481484

482485
for (const [test, dateToResult] of store.results) {
483-
if (Object.keys(Object.fromEntries(dateToResult)).length === 0) {
486+
if (dateToResult.size === 0) {
484487
continue;
485488
}
486489
log(0, `\nTest: ${test}`);
487490

488491
const sortedDates = Array.from(dateToResult.keys()).sort();
489-
// Collect all rubric names, if this is a LLM-as-a-judge rating
490-
const allRubrics = new Set<string>();
491-
for (const result of dateToResult.values()) {
492-
if (result.type === 'JUDGE') {
493-
Object.keys(result.statsByRubric).forEach(r => allRubrics.add(r));
494-
}
495-
}
496-
497492
const tableData: Record<string, Record<string, string>> = {};
498493
for (const date of sortedDates) {
499494
const result = dateToResult.get(date);
500-
if (!result) {
501-
continue;
502-
}
503-
504-
switch (result.type) {
505-
case 'BINARY':
506-
if (!tableData[PASS_RATE]) {
507-
tableData[PASS_RATE] = {};
508-
}
509-
tableData[PASS_RATE][date] = `${result.success} / ${result.total} passed`;
510-
break;
511-
case 'JUDGE':
512-
// Create a row for each rubric, including overall and runs per query/input
513-
for (const row
514-
of [OVERALL_STATS, ...Array.from(allRubrics).sort(), NUM_CONVERSATIONS,
515-
NUM_EVALS_PER_CONVERSATION]) {
516-
if (!tableData[row]) {
517-
tableData[row] = {};
518-
}
519-
tableData[row][date] = formatJudgeResult(result, row);
520-
}
521-
break;
522-
case 'ROUGE':
523-
if (!tableData[ROUGE_L_SUM]) {
524-
tableData[ROUGE_L_SUM] = {};
525-
}
526-
tableData[ROUGE_L_SUM][date] = formatRubricStats(result.stats);
527-
break;
495+
if (result) {
496+
populateTableData(tableData, date, result);
528497
}
529498
}
530499
console.table(tableData);
@@ -537,22 +506,86 @@ interface RubricStats {
537506
allScores: number[];
538507
}
539508

509+
interface BinaryStats {
510+
success: number;
511+
total: number;
512+
}
513+
type RougeStats = RubricStats;
514+
interface JudgeStats {
515+
statsByRubric: Record<string, RubricStats>;
516+
overallStats: RubricStats;
517+
inputCount: number;
518+
}
519+
540520
type Result = {
541521
type: 'BINARY',
542-
total: number,
543-
success: number,
522+
details: Array<{success: boolean, conversation: Conversation}>,
544523
}|{
545524
type: 'JUDGE',
546-
547-
statsByRubric: Record<RubricName, RubricStats>,
548-
overallStats: RubricStats,
549-
inputCount: number,
550525
repetitionCount: number,
526+
rubricWeights: RubricWeights,
527+
details: Array<{conversation: Conversation, rubricScores: RubricScore[]}>,
551528
}|{
552529
type: 'ROUGE',
553-
stats: RubricStats,
530+
details: Array<{conversation: Conversation, score: number, goldenResponse: string}>,
554531
};
555532

533+
function calculateBinaryStats(result: Extract<Result, {type: 'BINARY'}>): BinaryStats {
534+
const success = result.details.filter(d => d.success).length;
535+
return {success, total: result.details.length};
536+
}
537+
538+
function calculateRougeStats(result: Extract<Result, {type: 'ROUGE'}>): RougeStats {
539+
const scores = result.details.map(d => d.score);
540+
const average = scores.length ? scores.reduce((a, b) => a + b, 0) / scores.length : 0;
541+
return {average, standardDeviation: calculateStandardDeviation(scores), allScores: scores};
542+
}
543+
544+
function calculateJudgeStats(result: Extract<Result, {type: 'JUDGE'}>): JudgeStats {
545+
const statsByRubric: Record<string, RubricStats> = {};
546+
assert.ok(result.details.length > 0, 'A judge result must have at least one conversation');
547+
548+
const allRubrics = result.details[0].rubricScores.map(s => s.rubric).sort();
549+
for (const detail of result.details) {
550+
const currentRubrics = detail.rubricScores.map(s => s.rubric).sort();
551+
assert.deepStrictEqual(
552+
currentRubrics, allRubrics, 'All conversations in a judge result must have the same rubrics');
553+
}
554+
555+
for (const rubric of allRubrics) {
556+
const scores = result.details.flatMap(d => d.rubricScores.filter(s => s.rubric === rubric).map(s => s.score));
557+
const average = scores.length ? scores.reduce((a, b) => a + b, 0) / scores.length : 0;
558+
statsByRubric[rubric] = {average, standardDeviation: calculateStandardDeviation(scores), allScores: scores};
559+
}
560+
561+
const overallScores = result.details.map(d => calculateWeightedScore(d.rubricScores, result.rubricWeights));
562+
const overallAverage = overallScores.length ? overallScores.reduce((a, b) => a + b, 0) / overallScores.length : 0;
563+
const overallStats = {
564+
average: overallAverage,
565+
standardDeviation: calculateStandardDeviation(overallScores),
566+
allScores: overallScores,
567+
};
568+
569+
return {
570+
statsByRubric,
571+
overallStats,
572+
inputCount: result.details.length / result.repetitionCount,
573+
};
574+
}
575+
576+
function calculateStats(result: Result): BinaryStats|RougeStats|JudgeStats|null {
577+
switch (result.type) {
578+
case 'BINARY':
579+
return calculateBinaryStats(result);
580+
case 'ROUGE':
581+
return calculateRougeStats(result);
582+
case 'JUDGE':
583+
return calculateJudgeStats(result);
584+
default:
585+
return null;
586+
}
587+
}
588+
556589
class ResultStore {
557590
// Map of testName => YYYY-MM-DD => Result
558591
#results = new Map<string, Map<string, Result>>();

0 commit comments

Comments
 (0)