Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions openml/setups/setup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# License: BSD 3-Clause
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import asdict, dataclass
from typing import Any

import openml.config
import openml.flows
from openml.base import OpenMLBase


@dataclass
class OpenMLSetup:
@dataclass(repr=False)
class OpenMLSetup(OpenMLBase):
"""Setup object (a.k.a. Configuration).

Parameters
Expand All @@ -36,7 +38,26 @@ def __post_init__(self) -> None:
if self.parameters is not None and not isinstance(self.parameters, dict):
raise ValueError("parameters should be dict")

def _to_dict(self) -> dict[str, Any]:
@property
def id(self) -> int | None:
"""The id of the entity, it is unique for its entity type."""
return self.setup_id

def _get_repr_body_fields(
self,
) -> Sequence[tuple[str, str | int | list[str] | None]]:
"""Collect all information to display in the __repr__ body."""
return [
("Setup ID", self.setup_id),
("Flow ID", self.flow_id),
("Flow URL", openml.flows.OpenMLFlow.url_for_id(self.flow_id)),
(
"# of Parameters",
len(self.parameters) if self.parameters is not None else None,
),
]

def _to_dict(self) -> dict[str, Any]: # type: ignore[override]
return {
"setup_id": self.setup_id,
"flow_id": self.flow_id,
Expand All @@ -45,27 +66,19 @@ def _to_dict(self) -> dict[str, Any]:
else None,
}

def __repr__(self) -> str:
header = "OpenML Setup"
header = f"{header}\n{'=' * len(header)}\n"

fields = {
"Setup ID": self.setup_id,
"Flow ID": self.flow_id,
"Flow URL": openml.flows.OpenMLFlow.url_for_id(self.flow_id),
"# of Parameters": (
len(self.parameters) if self.parameters is not None else float("nan")
),
}

# determines the order in which the information will be printed
order = ["Setup ID", "Flow ID", "Flow URL", "# of Parameters"]
_fields = [(key, fields[key]) for key in order if key in fields]

longest_field_name_length = max(len(name) for name, _ in _fields)
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
body = "\n".join(field_line_format.format(name, value) for name, value in _fields)
return header + body
def _parse_publish_response(self, xml_response: dict[str, str]) -> None:
"""Not supported for setups."""
raise NotImplementedError(
"Setups cannot be published directly. "
"They are created automatically when a run is published."
)

def publish(self) -> OpenMLBase:
"""Not supported for setups."""
raise NotImplementedError(
"Setups cannot be published directly. "
"They are created automatically when a run is published."
)


@dataclass
Expand Down
1 change: 1 addition & 0 deletions openml/utils/_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def _get_rest_api_type_alias(oml_object: OpenMLBase) -> str:
(openml.tasks.OpenMLTask, "task"),
(openml.runs.OpenMLRun, "run"),
((openml.study.OpenMLStudy, openml.study.OpenMLBenchmarkSuite), "study"),
(openml.setups.OpenMLSetup, "setup"),
]
_, api_type_alias = next(
(python_type, api_alias)
Expand Down
90 changes: 90 additions & 0 deletions tests/test_setups/test_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# License: BSD 3-Clause
from __future__ import annotations

import random

import pytest

import openml
from openml.base import OpenMLBase
from openml.setups.setup import OpenMLSetup
from openml.testing import TestBase
from openml.utils import _tag_entity


class TestOpenMLSetup(TestBase):
"""Tests for OpenMLSetup inheriting OpenMLBase and tagging support."""

def test_setup_is_openml_base(self):
"""OpenMLSetup should be a subclass of OpenMLBase."""
setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None)
assert isinstance(setup, OpenMLBase)

def test_setup_id_property(self):
"""The id property should return setup_id."""
setup = OpenMLSetup(setup_id=42, flow_id=100, parameters=None)
assert setup.id == 42
assert setup.id == setup.setup_id

def test_setup_repr(self):
"""The repr should use OpenMLBase format and contain expected fields."""
setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None)
repr_str = repr(setup)
assert "OpenML Setup" in repr_str
assert "Setup ID" in repr_str
assert "Flow ID" in repr_str

def test_setup_repr_with_parameters(self):
"""The repr should show parameter count when parameters are present."""
# Create a minimal mock parameter-like dict
setup = OpenMLSetup(setup_id=1, flow_id=100, parameters={1: "a", 2: "b"})
repr_str = repr(setup)
assert "# of Parameters" in repr_str

def test_setup_publish_raises(self):
"""Calling publish() on a setup should raise NotImplementedError."""
setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None)
with pytest.raises(NotImplementedError, match="Setups cannot be published"):
setup.publish()

def test_setup_parse_publish_response_raises(self):
"""Calling _parse_publish_response should raise NotImplementedError."""
setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None)
with pytest.raises(NotImplementedError, match="Setups cannot be published"):
setup._parse_publish_response({})

def test_setup_openml_url(self):
"""The openml_url property should return a valid URL."""
setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None)
url = setup.openml_url
assert url is not None
assert "/s/1" in url

def test_setup_validation(self):
"""Existing validation in __post_init__ should still work."""
with pytest.raises(ValueError, match="setup id should be int"):
OpenMLSetup(setup_id="not_an_int", flow_id=100, parameters=None)

with pytest.raises(ValueError, match="flow id should be int"):
OpenMLSetup(setup_id=1, flow_id="not_an_int", parameters=None)

with pytest.raises(ValueError, match="parameters should be dict"):
OpenMLSetup(setup_id=1, flow_id=100, parameters="not_a_dict")

@pytest.mark.test_server()
def test_tag_untag_setup_via_entity(self):
"""Test tagging and untagging a setup via _tag_entity."""
# Setup ID 1 should exist on the test server
tag = "test_setup_tag_%d" % random.randint(1, 1_000_000)
all_tags = _tag_entity("setup", 1, tag)
assert tag in all_tags
all_tags = _tag_entity("setup", 1, tag, untag=True)
assert tag not in all_tags

@pytest.mark.test_server()
def test_setup_push_tag_remove_tag(self):
"""Test push_tag and remove_tag on an OpenMLSetup object."""
setup = openml.setups.get_setup(1)
tag = "test_setup_tag_%d" % random.randint(1, 1_000_000)
setup.push_tag(tag)
setup.remove_tag(tag)