fix gsm8k and reposition average in summary table
mlabonne committed Mar 29, 2024
1 parent e3a9999 commit 53f10e5
Showing 3 changed files with 27 additions and 37 deletions.
55 changes: 21 additions & 34 deletions llm_autoeval/table.py
@@ -1,9 +1,8 @@
import os
import math
import os

from pytablewriter import MarkdownTableWriter


BENCHMARK = os.getenv("BENCHMARK")


@@ -41,9 +40,7 @@ def calculate_average(data, task):
        elif task == "winogrande":
            return data["results"]["winogrande"]["acc,none"] * 100
        elif task == "gsm8k":
            return (
                data["results"]["gsm8k"]["exact_match,get-answer"] * 100
            )
            return data["results"]["gsm8k"]["exact_match,strict-match"] * 100

    elif BENCHMARK == "nous":
        if task == "agieval":
@@ -61,51 +58,41 @@ def calculate_average(data, task):

def make_table(result_dict, task):
"""Generate table of results."""
from pytablewriter import MarkdownTableWriter

md_writer = MarkdownTableWriter()
md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
md_writer.headers = ["Task", "Average", "Version", "Metric", "Value", "", "Stderr"]

values = []

average = round(calculate_average(result_dict, task), 2)

for k, dic in sorted(result_dict["results"].items()):
# Correctly use get() to safely access the dictionary
version = result_dict["versions"].get(k, "N/A") # Use get() on the versions dictionary

version = result_dict["versions"].get(k, "N/A")
percent = k == "squad2"

for m, v in dic.items():
if m.endswith("_stderr"):
continue

if m + "_stderr" in dic:
se = dic[m + "_stderr"]
if percent or m == "ppl":
values.append([k, version, m, "%.2f" % v, "±", "%.2f" % se])
else:
values.append(
[k, version, m, "%.2f" % (v * 100), "±", "%.2f" % (se * 100)]
)
else:
if percent or m == "ppl":
values.append([k, version, m, "%.2f" % v, "", ""])
else:
try:
# Attempt to convert v to a float
v_converted = float(v)
v_formatted = "%.2f" % v_converted
except ValueError:
# If conversion fails, use the original string value
v_formatted = v

values.append([k, version, m, v_formatted, "", ""])
stderr = dic.get(m + "_stderr", "")
stderr_display = "±%.2f" % stderr if stderr else ""
metric_value = "%.2f" % v if percent or m == "ppl" else "%.2f" % (v * 100)

# Adjusted to skip empty row insertion logic for simplicity
values.append([k, "", version, m, metric_value, "", stderr_display])

# Reset k and version to avoid repetition in the table
k = ""
version = ""

md_writer.value_matrix = values
# Add a row for the average if desired, here shown at the start of values for demonstration
values.insert(0, ["Average", average, "", "", "", "", ""])

# Get average score
average = round(calculate_average(result_dict, task), 2)
md_writer.value_matrix = values
table_output = md_writer.dumps()

return md_writer.dumps(), average
return table_output


def make_final_table(result_dict, model_name):
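For illustration, here is a minimal sketch of the reworked layout on a toy payload shaped like lm-evaluation-harness output (the scores and version number are made up): the gsm8k result is read from the `exact_match,strict-match` key, and the average now travels inside the table itself as a dedicated column plus a leading row, instead of being returned alongside it.

```python
# Minimal sketch of the new table layout (toy payload; values are made up).
from pytablewriter import MarkdownTableWriter

result_dict = {
    "results": {"gsm8k": {"exact_match,strict-match": 0.5231}},
    "versions": {"gsm8k": 3},
}

# gsm8k is now read from the "exact_match,strict-match" key
score = result_dict["results"]["gsm8k"]["exact_match,strict-match"] * 100
average = round(score, 2)  # only one task here, so the average equals its score

md_writer = MarkdownTableWriter()
md_writer.headers = ["Task", "Average", "Version", "Metric", "Value", "", "Stderr"]
values = [["gsm8k", "", 3, "exact_match,strict-match", "%.2f" % score, "", ""]]
values.insert(0, ["Average", average, "", "", "", "", ""])  # leading average row
md_writer.value_matrix = values
print(md_writer.dumps())
```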
1 change: 1 addition & 0 deletions llm_autoeval/upload.py
@@ -1,4 +1,5 @@
import os

import requests


8 changes: 5 additions & 3 deletions main.py
@@ -1,10 +1,10 @@
import argparse
import json
import logging
import os
import argparse
import time

from llm_autoeval.table import make_table, make_final_table
from llm_autoeval.table import make_final_table, make_table
from llm_autoeval.upload import upload_to_github_gist

logging.basicConfig(level=logging.INFO)
@@ -76,18 +76,20 @@ def _get_result_dict(directory: str) -> dict:

def _make_lighteval_summary(directory: str, elapsed_time: float) -> str:
    from lighteval.evaluator import make_results_table

    result_dict = _get_result_dict(directory)
    final_table = make_results_table(result_dict)
    summary = f"## {MODEL_ID.split('/')[-1]} - {BENCHMARK.capitalize()}\n\n"
    summary += final_table
    return summary


def _make_eqbench_summary(directory: str, elapsed_time: float) -> str:
    result_dict = _get_result_dict(directory)
    summary = f"## {MODEL_ID.split('/')[-1]} - {BENCHMARK.capitalize()}\n\n"
    return summary


def main(directory: str, elapsed_time: float) -> None:
    # Tasks
    if BENCHMARK == "openllm" or BENCHMARK == "nous":
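Note that `make_table` now returns only the rendered table, since the average is embedded in it rather than returned as a second value. A hypothetical caller sketch (the helper name `_make_openllm_summary` is assumed for illustration and does not appear in this diff):

```python
# Hypothetical caller: make_table() is assumed to return just the Markdown
# string, with the average carried inside the table itself.
from llm_autoeval.table import make_table


def _make_openllm_summary(result_dict: dict, task: str, model_id: str) -> str:
    table = make_table(result_dict, task)
    return f"## {model_id.split('/')[-1]} - {task.capitalize()}\n\n{table}"
```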
