@@ -314,112 +314,78 @@ function calculateStandardDeviation(values: number[]): number {
314314 return Math . sqrt ( variance ) ;
315315}
316316
317+ /**
318+ * Calculate the overall score for a conversation based on rubric importance.
319+ */
320+ function calculateWeightedScore ( rubricScores : RubricScore [ ] , weights : RubricWeights ) : number {
321+ let totalWeightedScore = 0 ;
322+ let totalWeight = 0 ;
323+ for ( const { rubric, score} of rubricScores ) {
324+ const weight = weights [ rubric ] ?? IMPORTANCE_WEIGHTS . minor ;
325+ totalWeightedScore += score * weight ;
326+ totalWeight += weight ;
327+ }
328+ return totalWeight > 0 ? totalWeightedScore / totalWeight : 0 ;
329+ }
330+
317331export async function itEval ( config : ItEval ) : Promise < void > {
318332 const state = stateStorage . getStore ( ) ;
319333 assert . ok ( state ) ;
320- if ( 'succeed' in config ) {
321- for ( const [ date , outputs ] of Object . entries ( state . outputsByDate ) ) {
322- if ( ! outputs ) {
323- continue ;
324- }
325334
326- const allDevToolsConversations = outputs . flatMap ( o => o . contents . conversations ) ;
335+ let goldenText = '' ;
336+ if ( 'rouge' in config ) {
337+ const golden = await getGolden ( state . store . type , state . store . label ) ;
338+ goldenText = golden ?. queries . at ( - 1 ) ?. response . text ?? '' ;
339+ }
327340
328- let total = 0 ;
329- let succeeded = 0 ;
330- for ( const conversation of allDevToolsConversations ) {
331- total ++ ;
332- if ( config . succeed ( conversation ) ) {
333- succeeded ++ ;
334- }
335- state . store . saveResult ( config . test , date , { type : 'BINARY' , success : succeeded , total} ) ;
336- }
341+ for ( const [ date , outputs ] of Object . entries ( state . outputsByDate ) ) {
342+ if ( ! outputs ) {
343+ continue ;
337344 }
338- } else if ( 'judge' in config ) {
339- for ( const [ date , outputs ] of Object . entries ( state . outputsByDate ) ) {
340- if ( ! outputs ) {
341- continue ;
342- }
343- const allDevToolsConversations = outputs . flatMap ( o => o . contents . conversations ) ;
345+ const conversations = outputs . flatMap ( o => o . contents . conversations ) ;
346+
347+ if ( 'succeed' in config ) {
348+ const details = conversations . map ( conversation => ( {
349+ success : config . succeed ( conversation ) ,
350+ conversation,
351+ } ) ) ;
352+ state . store . saveResult ( config . test , date , { type : 'BINARY' , details} ) ;
353+ } else if ( 'judge' in config ) {
344354 const repeatCount = argv . repeat ;
345-
346- // Collect scores from all examples.
347- const results = await Promise . all ( allDevToolsConversations . flatMap ( example => {
348- return Array . from ( { length : repeatCount } , ( ) => geminiLimiter . run ( ( ) => config . judge ( example ) ) ) ;
349- } ) ) ;
350-
351- // Calculate stats for each rubric (take the average for each rubric among multiple conversations)
352- const inputCount = allDevToolsConversations . length ;
353- const statsByRubric : Record < RubricName , RubricStats > = { } ;
354- const allRubrics = new Set ( results . flatMap ( r => r . rubricScores . map ( i => i . rubric ) ) ) ;
355- for ( const rubric of allRubrics ) {
356- const scores = results . flatMap ( r => r . rubricScores . filter ( i => i . rubric === rubric ) . map ( i => i . score ) ) ;
357- const total = scores . reduce ( ( acc , score ) => acc + score , 0 ) ;
358-
359- const average = scores . length ? total / scores . length : 0 ;
360- const standardDeviation = calculateStandardDeviation ( scores ) ;
361- statsByRubric [ rubric ] = { average, standardDeviation, allScores : scores } ;
362- }
363-
364- // Weight all rubrics (take the first result's weights, as they should be the same for all results).
365- const weights = results . length > 0 ? results [ 0 ] . rubricWeights : { } ;
366- const allWeightedScores : number [ ] = [ ] ;
367- for ( const rubricScores of results ) {
368- let totalWeightedScore = 0 ;
369- let totalWeight = 0 ;
370- for ( const { rubric, score} of rubricScores . rubricScores ) {
371- const weight = weights [ rubric ] ?? IMPORTANCE_WEIGHTS . minor ;
372- totalWeightedScore += score * weight ;
373- totalWeight += weight ;
374- }
375- if ( totalWeight > 0 ) {
376- allWeightedScores . push ( totalWeightedScore / totalWeight ) ;
377- }
378- }
379-
380- // Calculate average and standard deviation for the overall score.
381- const overallAverage = allWeightedScores . length ?
382- allWeightedScores . reduce ( ( acc , score ) => acc + score , 0 ) / allWeightedScores . length :
383- 0 ;
384- const overallStandardDeviation = calculateStandardDeviation ( allWeightedScores ) ;
385- const overallStats : RubricStats = {
386- average : overallAverage ,
387- standardDeviation : overallStandardDeviation ,
388- allScores : allWeightedScores ,
389- } ;
355+ const scoredEvals =
356+ conversations . flatMap ( conversation => Array . from ( { length : repeatCount } , ( ) => geminiLimiter . run ( async ( ) => {
357+ const res = await config . judge ( conversation ) ;
358+ return {
359+ conversation,
360+ rubricScores : res . rubricScores ,
361+ rubricWeights : res . rubricWeights ,
362+ } ;
363+ } ) ) ) ;
364+ const results = await Promise . all ( scoredEvals ) ;
365+ const rubricWeights = results . length > 0 ? results [ 0 ] . rubricWeights : { } ;
390366
391367 state . store . saveResult ( config . test , date , {
392368 type : 'JUDGE' ,
393- statsByRubric,
394- overallStats,
395- inputCount,
396369 repetitionCount : repeatCount ,
370+ rubricWeights,
371+ details : results . map ( r => ( {
372+ conversation : r . conversation ,
373+ rubricScores : r . rubricScores ,
374+ } ) ) ,
397375 } ) ;
398- }
399- } else if ( 'rouge' in config ) {
400- const golden = await getGolden ( state . store . type , state . store . label ) ;
401- const goldenText = golden ?. queries . at ( - 1 ) ?. response . text ?? '' ;
402-
403- for ( const [ date , outputs ] of Object . entries ( state . outputsByDate ) ) {
404- if ( ! outputs ) {
405- continue ;
406- }
407- const allDevToolsConversations = outputs . flatMap ( o => o . contents . conversations ) ;
408- const scores : number [ ] = [ ] ;
409-
410- for ( const conversation of allDevToolsConversations ) {
376+ } else if ( 'rouge' in config ) {
377+ const details = conversations . map ( conversation => {
411378 const candidateText = conversation . queries . at ( - 1 ) ?. response . text ?? '' ;
412- const score = ROUGE . score ( candidateText , goldenText ) ;
413- scores . push ( score ) ;
414- }
415-
416- const total = scores . reduce ( ( acc , score ) => acc + score , 0 ) ;
417- const average = scores . length ? total / scores . length : 0 ;
418- const standardDeviation = calculateStandardDeviation ( scores ) ;
379+ return {
380+ conversation,
381+ score : ROUGE . score ( candidateText , goldenText ) ,
382+ goldenResponse : goldenText ,
383+ } ;
384+ } ) ;
419385
420386 state . store . saveResult ( config . test , date , {
421387 type : 'ROUGE' ,
422- stats : { average , standardDeviation , allScores : scores } ,
388+ details ,
423389 } ) ;
424390 }
425391 }
@@ -443,9 +409,9 @@ export async function evalGroup(config: GroupConfig, cb: (() => Promise<void>)):
443409
444410 await stateStorage . run ( state , async ( ) => {
445411 await cb ( ) ;
446- console . log ( state . logs . join ( '\n' ) ) ;
447- printResults ( state . store ) ;
448412 } ) ;
413+ log ( 0 , state . logs . join ( '\n' ) ) ;
414+ printResults ( state . store ) ;
449415}
450416
451417function log ( indentation : number , message : string ) : void {
@@ -462,69 +428,72 @@ function formatRubricStats(stats: RubricStats): string {
462428 return `${ stats . average . toFixed ( 2 ) } (mean) ±${ stats . standardDeviation . toFixed ( 2 ) } ` ;
463429}
464430
465- function formatJudgeResult ( result : Extract < Result , { type : 'JUDGE' } > , row : string ) : string {
431+ function formatJudgeResult ( stats : JudgeStats , row : string , repetitionCount : number ) : string {
466432 if ( row === NUM_CONVERSATIONS ) {
467- return String ( result . inputCount ) ;
433+ return String ( stats . inputCount ) ;
468434 }
469435 if ( row === NUM_EVALS_PER_CONVERSATION ) {
470- return String ( result . repetitionCount ) ;
436+ return String ( repetitionCount ) ;
471437 }
472438 if ( row === OVERALL_STATS ) {
473- return formatRubricStats ( result . overallStats ) ;
439+ return formatRubricStats ( stats . overallStats ) ;
440+ }
441+ const rubricStats = stats . statsByRubric [ row ] ;
442+ return rubricStats ? formatRubricStats ( rubricStats ) : '-' ;
443+ }
444+
445+ function populateTableData ( tableData : Record < string , Record < string , string > > , date : string , result : Result ) : void {
446+ const stats = calculateStats ( result ) ;
447+ if ( ! stats ) {
448+ return ;
449+ }
450+
451+ switch ( result . type ) {
452+ case 'BINARY' : {
453+ if ( ! tableData [ PASS_RATE ] ) {
454+ tableData [ PASS_RATE ] = { } ;
455+ }
456+ const binary = stats as BinaryStats ;
457+ tableData [ PASS_RATE ] [ date ] = `${ binary . success } / ${ binary . total } passed` ;
458+ break ;
459+ }
460+ case 'JUDGE' : {
461+ const judge = stats as JudgeStats ;
462+ const rubrics = Object . keys ( judge . statsByRubric ) . sort ( ) ;
463+ for ( const row of [ OVERALL_STATS , ...rubrics , NUM_CONVERSATIONS , NUM_EVALS_PER_CONVERSATION ] ) {
464+ if ( ! tableData [ row ] ) {
465+ tableData [ row ] = { } ;
466+ }
467+ tableData [ row ] [ date ] = formatJudgeResult ( judge , row , result . repetitionCount ) ;
468+ }
469+ break ;
470+ }
471+ case 'ROUGE' : {
472+ if ( ! tableData [ ROUGE_L_SUM ] ) {
473+ tableData [ ROUGE_L_SUM ] = { } ;
474+ }
475+ const rouge = stats as RougeStats ;
476+ tableData [ ROUGE_L_SUM ] [ date ] = formatRubricStats ( rouge ) ;
477+ break ;
478+ }
474479 }
475- const stats = result . statsByRubric [ row ] ;
476- return stats ? formatRubricStats ( stats ) : '-' ;
477480}
478481
479482function printResults ( store : ResultStore ) : void {
480483 log ( 0 , `Results for: ${ store . type } /${ store . label } ` ) ;
481484
482485 for ( const [ test , dateToResult ] of store . results ) {
483- if ( Object . keys ( Object . fromEntries ( dateToResult ) ) . length === 0 ) {
486+ if ( dateToResult . size === 0 ) {
484487 continue ;
485488 }
486489 log ( 0 , `\nTest: ${ test } ` ) ;
487490
488491 const sortedDates = Array . from ( dateToResult . keys ( ) ) . sort ( ) ;
489- // Collect all rubric names, if this is a LLM-as-a-judge rating
490- const allRubrics = new Set < string > ( ) ;
491- for ( const result of dateToResult . values ( ) ) {
492- if ( result . type === 'JUDGE' ) {
493- Object . keys ( result . statsByRubric ) . forEach ( r => allRubrics . add ( r ) ) ;
494- }
495- }
496-
497492 const tableData : Record < string , Record < string , string > > = { } ;
498493 for ( const date of sortedDates ) {
499494 const result = dateToResult . get ( date ) ;
500- if ( ! result ) {
501- continue ;
502- }
503-
504- switch ( result . type ) {
505- case 'BINARY' :
506- if ( ! tableData [ PASS_RATE ] ) {
507- tableData [ PASS_RATE ] = { } ;
508- }
509- tableData [ PASS_RATE ] [ date ] = `${ result . success } / ${ result . total } passed` ;
510- break ;
511- case 'JUDGE' :
512- // Create a row for each rubric, including overall and runs per query/input
513- for ( const row
514- of [ OVERALL_STATS , ...Array . from ( allRubrics ) . sort ( ) , NUM_CONVERSATIONS ,
515- NUM_EVALS_PER_CONVERSATION ] ) {
516- if ( ! tableData [ row ] ) {
517- tableData [ row ] = { } ;
518- }
519- tableData [ row ] [ date ] = formatJudgeResult ( result , row ) ;
520- }
521- break ;
522- case 'ROUGE' :
523- if ( ! tableData [ ROUGE_L_SUM ] ) {
524- tableData [ ROUGE_L_SUM ] = { } ;
525- }
526- tableData [ ROUGE_L_SUM ] [ date ] = formatRubricStats ( result . stats ) ;
527- break ;
495+ if ( result ) {
496+ populateTableData ( tableData , date , result ) ;
528497 }
529498 }
530499 console . table ( tableData ) ;
@@ -537,22 +506,86 @@ interface RubricStats {
537506 allScores : number [ ] ;
538507}
539508
509+ interface BinaryStats {
510+ success : number ;
511+ total : number ;
512+ }
513+ type RougeStats = RubricStats ;
514+ interface JudgeStats {
515+ statsByRubric : Record < string , RubricStats > ;
516+ overallStats : RubricStats ;
517+ inputCount : number ;
518+ }
519+
540520type Result = {
541521 type : 'BINARY' ,
542- total : number ,
543- success : number ,
522+ details : Array < { success : boolean , conversation : Conversation } > ,
544523} | {
545524 type : 'JUDGE' ,
546-
547- statsByRubric : Record < RubricName , RubricStats > ,
548- overallStats : RubricStats ,
549- inputCount : number ,
550525 repetitionCount : number ,
526+ rubricWeights : RubricWeights ,
527+ details : Array < { conversation : Conversation , rubricScores : RubricScore [ ] } > ,
551528} | {
552529 type : 'ROUGE' ,
553- stats : RubricStats ,
530+ details : Array < { conversation : Conversation , score : number , goldenResponse : string } > ,
554531} ;
555532
533+ function calculateBinaryStats ( result : Extract < Result , { type : 'BINARY' } > ) : BinaryStats {
534+ const success = result . details . filter ( d => d . success ) . length ;
535+ return { success, total : result . details . length } ;
536+ }
537+
538+ function calculateRougeStats ( result : Extract < Result , { type : 'ROUGE' } > ) : RougeStats {
539+ const scores = result . details . map ( d => d . score ) ;
540+ const average = scores . length ? scores . reduce ( ( a , b ) => a + b , 0 ) / scores . length : 0 ;
541+ return { average, standardDeviation : calculateStandardDeviation ( scores ) , allScores : scores } ;
542+ }
543+
544+ function calculateJudgeStats ( result : Extract < Result , { type : 'JUDGE' } > ) : JudgeStats {
545+ const statsByRubric : Record < string , RubricStats > = { } ;
546+ assert . ok ( result . details . length > 0 , 'A judge result must have at least one conversation' ) ;
547+
548+ const allRubrics = result . details [ 0 ] . rubricScores . map ( s => s . rubric ) . sort ( ) ;
549+ for ( const detail of result . details ) {
550+ const currentRubrics = detail . rubricScores . map ( s => s . rubric ) . sort ( ) ;
551+ assert . deepStrictEqual (
552+ currentRubrics , allRubrics , 'All conversations in a judge result must have the same rubrics' ) ;
553+ }
554+
555+ for ( const rubric of allRubrics ) {
556+ const scores = result . details . flatMap ( d => d . rubricScores . filter ( s => s . rubric === rubric ) . map ( s => s . score ) ) ;
557+ const average = scores . length ? scores . reduce ( ( a , b ) => a + b , 0 ) / scores . length : 0 ;
558+ statsByRubric [ rubric ] = { average, standardDeviation : calculateStandardDeviation ( scores ) , allScores : scores } ;
559+ }
560+
561+ const overallScores = result . details . map ( d => calculateWeightedScore ( d . rubricScores , result . rubricWeights ) ) ;
562+ const overallAverage = overallScores . length ? overallScores . reduce ( ( a , b ) => a + b , 0 ) / overallScores . length : 0 ;
563+ const overallStats = {
564+ average : overallAverage ,
565+ standardDeviation : calculateStandardDeviation ( overallScores ) ,
566+ allScores : overallScores ,
567+ } ;
568+
569+ return {
570+ statsByRubric,
571+ overallStats,
572+ inputCount : result . details . length / result . repetitionCount ,
573+ } ;
574+ }
575+
576+ function calculateStats ( result : Result ) : BinaryStats | RougeStats | JudgeStats | null {
577+ switch ( result . type ) {
578+ case 'BINARY' :
579+ return calculateBinaryStats ( result ) ;
580+ case 'ROUGE' :
581+ return calculateRougeStats ( result ) ;
582+ case 'JUDGE' :
583+ return calculateJudgeStats ( result ) ;
584+ default :
585+ return null ;
586+ }
587+ }
588+
556589class ResultStore {
557590 // Map of testName => YYYY-MM-DD => Result
558591 #results = new Map < string , Map < string , Result > > ( ) ;
0 commit comments