diff --git a/src/database/setups.py b/src/database/setups.py new file mode 100644 index 00000000..59601ae4 --- /dev/null +++ b/src/database/setups.py @@ -0,0 +1,42 @@ +from sqlalchemy import Connection, text +from sqlalchemy.engine import Row + + +def get(setup_id: int, connection: Connection) -> Row | None: + row = connection.execute( + text( + """ + SELECT * + FROM algorithm_setup + WHERE sid = :setup_id + """, + ), + parameters={"setup_id": setup_id}, + ) + return row.first() + + +def get_tags(setup_id: int, connection: Connection) -> list[Row]: + rows = connection.execute( + text( + """ + SELECT * + FROM setup_tag + WHERE id = :setup_id + """, + ), + parameters={"setup_id": setup_id}, + ) + return list(rows.all()) + + +def untag(setup_id: int, tag: str, connection: Connection) -> None: + connection.execute( + text( + """ + DELETE FROM setup_tag + WHERE id = :setup_id AND tag = :tag + """, + ), + parameters={"setup_id": setup_id, "tag": tag}, + ) diff --git a/src/main.py b/src/main.py index 560b4c50..07e14510 100644 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,7 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router +from routers.openml.setups import router as setup_router from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router @@ -55,6 +56,7 @@ def create_api() -> FastAPI: app.include_router(task_router) app.include_router(flows_router) app.include_router(study_router) + app.include_router(setup_router) return app diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index 2ddccf83..e4de107f 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -1,6 +1,7 @@ +from http import HTTPStatus from typing import Annotated -from fastapi import Depends +from fastapi import Depends, HTTPException from pydantic import BaseModel from sqlalchemy import Connection @@ -29,6 +30,17 @@ def fetch_user( return User.fetch(api_key, user_data) if api_key else None +def fetch_user_or_raise( + user: Annotated[User | None, Depends(fetch_user)] = None, +) -> User: + if user is None: + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "103", "message": "Authentication failed"}, + ) + return user + + class Pagination(BaseModel): offset: int = 0 limit: int = 100 diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index dda25117..2bc1548d 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -19,7 +19,13 @@ _format_parquet_url, ) from database.users import User, UserGroup -from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection +from routers.dependencies import ( + Pagination, + expdb_connection, + fetch_user, + fetch_user_or_raise, + userdb_connection, +) from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType @@ -32,29 +38,19 @@ def tag_dataset( data_id: Annotated[int, Body()], tag: Annotated[str, SystemString64], - user: Annotated[User | None, Depends(fetch_user)] = None, + user: Annotated[User, Depends(fetch_user_or_raise)], expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, ) -> dict[str, dict[str, Any]]: tags = database.datasets.get_tags_for(data_id, expdb_db) if tag.casefold() in [t.casefold() for t in tags]: raise create_tag_exists_error(data_id, tag) - if user is None: - raise create_authentication_failed_error() - database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db) return { "data_tag": {"id": str(data_id), "tag": [*tags, tag]}, } -def create_authentication_failed_error() -> HTTPException: - return HTTPException( - status_code=HTTPStatus.PRECONDITION_FAILED, - detail={"code": "103", "message": "Authentication failed"}, - ) - - def create_tag_exists_error(data_id: int, tag: str) -> HTTPException: return HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, diff --git a/src/routers/openml/setups.py b/src/routers/openml/setups.py new file mode 100644 index 00000000..71304e38 --- /dev/null +++ b/src/routers/openml/setups.py @@ -0,0 +1,45 @@ +from http import HTTPStatus +from typing import Annotated + +from fastapi import APIRouter, Body, Depends, HTTPException +from sqlalchemy import Connection + +import database.setups +from database.users import User, UserGroup +from routers.dependencies import expdb_connection, fetch_user_or_raise +from routers.types import SystemString64 + +router = APIRouter(prefix="/setup", tags=["setup"]) + + +@router.post(path="/untag") +def untag_setup( + setup_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb_db: Annotated[Connection, Depends(expdb_connection)] = None, +) -> dict[str, dict[str, str]]: + if not database.setups.get(setup_id, expdb_db): + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "472", "message": "Entity not found."}, + ) + + setup_tags = database.setups.get_tags(setup_id, expdb_db) + matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None) + + if not matched_tag_row: + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "475", "message": "Tag not found."}, + ) + + if matched_tag_row.uploader != user.user_id and UserGroup.ADMIN not in user.groups: + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "476", "message": "Tag is not owned by you"}, + ) + + database.setups.untag(setup_id, matched_tag_row.tag, expdb_db) + + return {"setup_untag": {"id": str(setup_id)}} diff --git a/tests/routers/openml/migration/setups_migration_test.py b/tests/routers/openml/migration/setups_migration_test.py new file mode 100644 index 00000000..a744dc39 --- /dev/null +++ b/tests/routers/openml/migration/setups_migration_test.py @@ -0,0 +1,62 @@ +from http import HTTPStatus + +import httpx +import pytest +from starlette.testclient import TestClient + +from tests.users import ApiKey + + +@pytest.mark.parametrize( + "setup_id", + [1, 999999], + ids=["existing setup", "unknown setup"], +) +@pytest.mark.parametrize( + "api_key", + [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER], + ids=["Administrator", "regular user", "possible owner"], +) +@pytest.mark.parametrize( + "tag", + ["totally_new_tag_for_migration_testing"], +) +def test_setup_untag_response_is_identical( + setup_id: int, + tag: str, + api_key: str, + py_api: TestClient, + php_api: httpx.Client, +) -> None: + if setup_id == 1: + php_api.post( + "/setup/tag", + data={"api_key": ApiKey.SOME_USER, "tag": tag, "setup_id": setup_id}, + ) + + original = php_api.post( + "/setup/untag", + data={"api_key": api_key, "tag": tag, "setup_id": setup_id}, + ) + + if original.status_code == HTTPStatus.OK: + php_api.post( + "/setup/tag", + data={"api_key": ApiKey.SOME_USER, "tag": tag, "setup_id": setup_id}, + ) + + new = py_api.post( + f"/setup/untag?api_key={api_key}", + json={"setup_id": setup_id, "tag": tag}, + ) + + assert original.status_code == new.status_code + + if new.status_code != HTTPStatus.OK: + assert original.json()["error"] == new.json()["detail"] + return + + original_json = original.json() + new_json = new.json() + + assert original_json == new_json diff --git a/tests/routers/openml/setups_test.py b/tests/routers/openml/setups_test.py new file mode 100644 index 00000000..50b6b843 --- /dev/null +++ b/tests/routers/openml/setups_test.py @@ -0,0 +1,88 @@ +from collections.abc import Iterator +from http import HTTPStatus + +import pytest +from sqlalchemy import Connection, text +from starlette.testclient import TestClient + +from tests.users import ApiKey + + +@pytest.fixture +def mock_setup_tag(expdb_test: Connection) -> Iterator[None]: + expdb_test.execute( + text("DELETE FROM setup_tag WHERE id = 1 AND tag = 'test_unit_tag_123'"), + ) + expdb_test.execute( + text("INSERT INTO setup_tag (id, tag, uploader) VALUES (1, 'test_unit_tag_123', 2)") + ) + expdb_test.commit() + + yield + + expdb_test.execute( + text("DELETE FROM setup_tag WHERE id = 1 AND tag = 'test_unit_tag_123'"), + ) + expdb_test.commit() + + +def test_setup_untag_missing_auth(py_api: TestClient) -> None: + response = py_api.post("/setup/untag", json={"setup_id": 1, "tag": "test_tag"}) + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"] == {"code": "103", "message": "Authentication failed"} + + +def test_setup_untag_unknown_setup(py_api: TestClient) -> None: + response = py_api.post( + f"/setup/untag?api_key={ApiKey.SOME_USER}", + json={"setup_id": 999999, "tag": "test_tag"}, + ) + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"] == {"code": "472", "message": "Entity not found."} + + +def test_setup_untag_tag_not_found(py_api: TestClient) -> None: + response = py_api.post( + f"/setup/untag?api_key={ApiKey.SOME_USER}", + json={"setup_id": 1, "tag": "non_existent_tag_12345"}, + ) + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"] == {"code": "475", "message": "Tag not found."} + + +@pytest.mark.mut +@pytest.mark.usefixtures("mock_setup_tag") +def test_setup_untag_not_owned_by_you(py_api: TestClient) -> None: + response = py_api.post( + f"/setup/untag?api_key={ApiKey.OWNER_USER}", + json={"setup_id": 1, "tag": "test_unit_tag_123"}, + ) + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"] == {"code": "476", "message": "Tag is not owned by you"} + + +@pytest.mark.mut +@pytest.mark.parametrize( + "api_key", + [ApiKey.SOME_USER, ApiKey.ADMIN], + ids=["Owner", "Administrator"], +) +def test_setup_untag_success(api_key: str, py_api: TestClient, expdb_test: Connection) -> None: + expdb_test.execute(text("DELETE FROM setup_tag WHERE id = 1 AND tag = 'test_success_tag'")) + expdb_test.execute( + text("INSERT INTO setup_tag (id, tag, uploader) VALUES (1, 'test_success_tag', 2)") + ) + expdb_test.commit() + + response = py_api.post( + f"/setup/untag?api_key={api_key}", + json={"setup_id": 1, "tag": "test_success_tag"}, + ) + + assert response.status_code == HTTPStatus.OK + assert response.json() == {"setup_untag": {"id": "1"}} + + rows = expdb_test.execute( + text("SELECT * FROM setup_tag WHERE id = 1 AND tag = 'test_success_tag'") + ).all() + assert len(rows) == 0