Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions tests/test_runs/test_run_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def _assert_predictions_equal(self, predictions, predictions_prime):
def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed, create_task_obj):
run = openml.runs.get_run(run_id)

# TODO: assert holdout task
if create_task_obj:
task = openml.tasks.get_task(run.task_id)
assert task.task_type_id == TaskType.SUPERVISED_CLASSIFICATION or task.task_type_id == TaskType.SUPERVISED_REGRESSION

# downloads the predictions of the old task
file_id = run.output_files["predictions"]
Expand Down Expand Up @@ -284,11 +286,15 @@ def _remove_random_state(flow):
assert isinstance(run.dataset_id, int)

# This is only a smoke check right now
# TODO add a few asserts here
assert run.run_id is not None
assert run.run_id is not None
# assert run.uploader is not None # uploader is not set on the local run object immediately after publish
assert run.flow_id == flow.flow_id
run._to_xml()
if run.trace is not None:
# This is only a smoke check right now
# TODO add a few asserts here
assert run.trace.run_id == run.run_id
assert len(run.trace.trace_iterations) > 0
run.trace.trace_to_arff()

# check arff output
Expand Down Expand Up @@ -338,6 +344,11 @@ def _remove_random_state(flow):
downloaded = openml.runs.get_run(run_.run_id)
assert "openml-python" in downloaded.tags

# attributes is not a property of OpenMLRun.
# Check basic properties instead to verify download integrity.
assert downloaded.uploader is not None
assert downloaded.task_id == run.task_id
assert downloaded.flow_id == run.flow_id
# TODO make sure that these attributes are instantiated when
# downloading a run? Or make sure that the trace object is created when
# running a flow on a task (and not only the arff object is created,
Expand Down Expand Up @@ -512,6 +523,7 @@ def determine_grid_size(param_grid):
# suboptimal (slow), and not guaranteed to work if evaluation
# engine is behind.
# TODO: mock this? We have the arff already on the server
assert run.output_files["predictions"] is not None
self._wait_for_processed_run(run.run_id, 600)
try:
model_prime = openml.runs.initialize_model_from_trace(
Expand Down Expand Up @@ -553,6 +565,9 @@ def determine_grid_size(param_grid):
)

# todo: check if runtime is present
# assert "usercpu_time_millis" in run.evaluations if run.evaluations else True
# For local runs, timing is verified through run.fold_evaluations in the check below.
pass
self._check_fold_timing_evaluations(
fold_evaluations=run.fold_evaluations,
num_repeats=1,
Expand Down Expand Up @@ -2032,3 +2047,9 @@ def test_joblib_backends(parallel_mock, n_jobs, backend, call_count):
assert len(res[2]["predictive_accuracy"][0]) == 10
assert len(res[3]["predictive_accuracy"][0]) == 10
assert parallel_mock.call_count == call_count

# Smoke check: res[1] must survive an ARFF dump/load round trip without raising.
# TODO: This only verifies that serialization does not raise; add assertions
# on the parsed content to ensure that the arff file is indeed valid.
arff.loads(arff.dumps(res[1]))