diff --git a/openml/setups/setup.py b/openml/setups/setup.py index 170838138..3145e660f 100644 --- a/openml/setups/setup.py +++ b/openml/setups/setup.py @@ -1,15 +1,17 @@ # License: BSD 3-Clause from __future__ import annotations +from collections.abc import Sequence from dataclasses import asdict, dataclass from typing import Any import openml.config import openml.flows +from openml.base import OpenMLBase -@dataclass -class OpenMLSetup: +@dataclass(repr=False) +class OpenMLSetup(OpenMLBase): """Setup object (a.k.a. Configuration). Parameters @@ -36,7 +38,26 @@ def __post_init__(self) -> None: if self.parameters is not None and not isinstance(self.parameters, dict): raise ValueError("parameters should be dict") - def _to_dict(self) -> dict[str, Any]: + @property + def id(self) -> int | None: + """The id of the entity, it is unique for its entity type.""" + return self.setup_id + + def _get_repr_body_fields( + self, + ) -> Sequence[tuple[str, str | int | list[str] | None]]: + """Collect all information to display in the __repr__ body.""" + return [ + ("Setup ID", self.setup_id), + ("Flow ID", self.flow_id), + ("Flow URL", openml.flows.OpenMLFlow.url_for_id(self.flow_id)), + ( + "# of Parameters", + len(self.parameters) if self.parameters is not None else None, + ), + ] + + def _to_dict(self) -> dict[str, Any]: # type: ignore[override] return { "setup_id": self.setup_id, "flow_id": self.flow_id, @@ -45,27 +66,19 @@ def _to_dict(self) -> dict[str, Any]: else None, } - def __repr__(self) -> str: - header = "OpenML Setup" - header = f"{header}\n{'=' * len(header)}\n" - - fields = { - "Setup ID": self.setup_id, - "Flow ID": self.flow_id, - "Flow URL": openml.flows.OpenMLFlow.url_for_id(self.flow_id), - "# of Parameters": ( - len(self.parameters) if self.parameters is not None else float("nan") - ), - } - - # determines the order in which the information will be printed - order = ["Setup ID", "Flow ID", "Flow URL", "# of Parameters"] - _fields = [(key, fields[key]) for key in order if key in fields] - - longest_field_name_length = max(len(name) for name, _ in _fields) - field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}" - body = "\n".join(field_line_format.format(name, value) for name, value in _fields) - return header + body + def _parse_publish_response(self, xml_response: dict[str, str]) -> None: + """Not supported for setups.""" + raise NotImplementedError( + "Setups cannot be published directly. " + "They are created automatically when a run is published." + ) + + def publish(self) -> OpenMLBase: + """Not supported for setups.""" + raise NotImplementedError( + "Setups cannot be published directly. " + "They are created automatically when a run is published." + ) @dataclass diff --git a/openml/utils/_openml.py b/openml/utils/_openml.py index f18dbe3e0..59c4fe2db 100644 --- a/openml/utils/_openml.py +++ b/openml/utils/_openml.py @@ -102,6 +102,7 @@ def _get_rest_api_type_alias(oml_object: OpenMLBase) -> str: (openml.tasks.OpenMLTask, "task"), (openml.runs.OpenMLRun, "run"), ((openml.study.OpenMLStudy, openml.study.OpenMLBenchmarkSuite), "study"), + (openml.setups.OpenMLSetup, "setup"), ] _, api_type_alias = next( (python_type, api_alias) diff --git a/tests/test_setups/test_setup.py b/tests/test_setups/test_setup.py new file mode 100644 index 000000000..8d4657730 --- /dev/null +++ b/tests/test_setups/test_setup.py @@ -0,0 +1,90 @@ +# License: BSD 3-Clause +from __future__ import annotations + +import random + +import pytest + +import openml +from openml.base import OpenMLBase +from openml.setups.setup import OpenMLSetup +from openml.testing import TestBase +from openml.utils import _tag_entity + + +class TestOpenMLSetup(TestBase): + """Tests for OpenMLSetup inheriting OpenMLBase and tagging support.""" + + def test_setup_is_openml_base(self): + """OpenMLSetup should be a subclass of OpenMLBase.""" + setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None) + assert isinstance(setup, OpenMLBase) + + def test_setup_id_property(self): + """The id property should return setup_id.""" + setup = OpenMLSetup(setup_id=42, flow_id=100, parameters=None) + assert setup.id == 42 + assert setup.id == setup.setup_id + + def test_setup_repr(self): + """The repr should use OpenMLBase format and contain expected fields.""" + setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None) + repr_str = repr(setup) + assert "OpenML Setup" in repr_str + assert "Setup ID" in repr_str + assert "Flow ID" in repr_str + + def test_setup_repr_with_parameters(self): + """The repr should show parameter count when parameters are present.""" + # Create a minimal mock parameter-like dict + setup = OpenMLSetup(setup_id=1, flow_id=100, parameters={1: "a", 2: "b"}) + repr_str = repr(setup) + assert "# of Parameters" in repr_str + + def test_setup_publish_raises(self): + """Calling publish() on a setup should raise NotImplementedError.""" + setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None) + with pytest.raises(NotImplementedError, match="Setups cannot be published"): + setup.publish() + + def test_setup_parse_publish_response_raises(self): + """Calling _parse_publish_response should raise NotImplementedError.""" + setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None) + with pytest.raises(NotImplementedError, match="Setups cannot be published"): + setup._parse_publish_response({}) + + def test_setup_openml_url(self): + """The openml_url property should return a valid URL.""" + setup = OpenMLSetup(setup_id=1, flow_id=100, parameters=None) + url = setup.openml_url + assert url is not None + assert "/s/1" in url + + def test_setup_validation(self): + """Existing validation in __post_init__ should still work.""" + with pytest.raises(ValueError, match="setup id should be int"): + OpenMLSetup(setup_id="not_an_int", flow_id=100, parameters=None) + + with pytest.raises(ValueError, match="flow id should be int"): + OpenMLSetup(setup_id=1, flow_id="not_an_int", parameters=None) + + with pytest.raises(ValueError, match="parameters should be dict"): + OpenMLSetup(setup_id=1, flow_id=100, parameters="not_a_dict") + + @pytest.mark.test_server() + def test_tag_untag_setup_via_entity(self): + """Test tagging and untagging a setup via _tag_entity.""" + # Setup ID 1 should exist on the test server + tag = "test_setup_tag_%d" % random.randint(1, 1_000_000) + all_tags = _tag_entity("setup", 1, tag) + assert tag in all_tags + all_tags = _tag_entity("setup", 1, tag, untag=True) + assert tag not in all_tags + + @pytest.mark.test_server() + def test_setup_push_tag_remove_tag(self): + """Test push_tag and remove_tag on an OpenMLSetup object.""" + setup = openml.setups.get_setup(1) + tag = "test_setup_tag_%d" % random.randint(1, 1_000_000) + setup.push_tag(tag) + setup.remove_tag(tag)