40 changes: 25 additions & 15 deletions openml/datasets/dataset.py
@@ -345,9 +345,10 @@ def _download_data(self) -> None:
         # import required here to avoid circular import.
         from .functions import _get_dataset_arff, _get_dataset_parquet

-        self.data_file = str(_get_dataset_arff(self))
         if self._parquet_url is not None:
             self.parquet_file = str(_get_dataset_parquet(self))
+        if self.parquet_file is None:
+            self.data_file = str(_get_dataset_arff(self))

     def _get_arff(self, format: str) -> dict:  # noqa: A002
         """Read ARFF file and return decoded arff.
@@ -535,18 +536,7 @@ def _cache_compressed_file_from_file(
             feather_attribute_file,
         ) = self._compressed_cache_file_paths(data_file)

-        if data_file.suffix == ".arff":
-            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
-        elif data_file.suffix == ".pq":
-            try:
-                data = pd.read_parquet(data_file)
-            except Exception as e:  # noqa: BLE001
-                raise Exception(f"File: {data_file}") from e
-
-            categorical = [data[c].dtype.name == "category" for c in data.columns]
-            attribute_names = list(data.columns)
-        else:
-            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        attribute_names, categorical, data = self._parse_data_from_file(data_file)

         # Feather format does not work for sparse datasets, so we use pickle for sparse datasets
         if scipy.sparse.issparse(data):
@@ -572,6 +562,24 @@ def _cache_compressed_file_from_file(

         return data, categorical, attribute_names

+    def _parse_data_from_file(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        if data_file.suffix == ".arff":
+            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
+        elif data_file.suffix == ".pq":
+            attribute_names, categorical, data = self._parse_data_from_pq(data_file)
+        else:
+            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        return attribute_names, categorical, data
+
+    def _parse_data_from_pq(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        try:
+            data = pd.read_parquet(data_file)
+        except Exception as e:  # noqa: BLE001
+            raise Exception(f"File: {data_file}") from e
+        categorical = [data[c].dtype.name == "category" for c in data.columns]
+        attribute_names = list(data.columns)
+        return attribute_names, categorical, data
+
     def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool], list[str]]:  # noqa: PLR0912, C901
         """Load data from compressed format or arff. Download data if not present on disk."""
         need_to_create_pickle = self.cache_format == "pickle" and self.data_pickle_file is None
Expand Down Expand Up @@ -636,8 +644,10 @@ def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool]
"Please manually delete the cache file if you want OpenML-Python "
"to attempt to reconstruct it.",
)
assert self.data_file is not None
data, categorical, attribute_names = self._parse_data_from_arff(Path(self.data_file))
file_to_load = self.data_file if self.parquet_file is None else self.parquet_file
assert file_to_load is not None
attr, cat, df = self._parse_data_from_file(Path(file_to_load))
return df, cat, attr

data_up_to_date = isinstance(data, pd.DataFrame) or scipy.sparse.issparse(data)
if self.cache_format == "pickle" and not data_up_to_date:
Expand Down
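Taken together, the dataset.py changes invert the old download order: _download_data now fetches the parquet copy first and touches the ARFF endpoint only when no parquet file is available, and the format-specific parsing is factored out into _parse_data_from_file / _parse_data_from_pq. A minimal sketch of the user-visible effect, assuming the public openml-python API (dataset 61, "iris", is used purely as an example):

import openml

dataset = openml.datasets.get_dataset(61, download_data=True)

# After this change, parquet is preferred: data_file (the ARFF path) stays
# None whenever the parquet download succeeded; ARFF is only a fallback.
if dataset.parquet_file is not None:
    assert dataset.data_file is None
else:
    assert dataset.data_file is not None

# get_data() behaves the same either way; parsing dispatches on file suffix.
X, y, categorical, attribute_names = dataset.get_data()
print(X.shape, sum(categorical), len(attribute_names))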
11 changes: 7 additions & 4 deletions openml/datasets/functions.py
@@ -450,7 +450,7 @@ def get_datasets(


 @openml.utils.thread_safe_if_oslo_installed
-def get_dataset(  # noqa: C901, PLR0912
+def get_dataset(  # noqa: C901, PLR0912, PLR0915
     dataset_id: int | str,
     download_data: bool | None = None,  # Optional for deprecation warning; later again only bool
     version: int | None = None,
@@ -589,7 +589,6 @@ def get_dataset(  # noqa: C901, PLR0912
         if download_qualities:
             qualities_file = _get_dataset_qualities_file(did_cache_dir, dataset_id)

-        arff_file = _get_dataset_arff(description) if download_data else None
         if "oml:parquet_url" in description and download_data:
             try:
                 parquet_file = _get_dataset_parquet(
@@ -598,10 +597,14 @@ def get_dataset(  # noqa: C901, PLR0912
                 )
             except urllib3.exceptions.MaxRetryError:
                 parquet_file = None
-            if parquet_file is None and arff_file:
-                logger.warning("Failed to download parquet, fallback on ARFF.")
         else:
             parquet_file = None
+
+        arff_file = None
+        if parquet_file is None and download_data:
+            logger.warning("Failed to download parquet, fallback on ARFF.")
+            arff_file = _get_dataset_arff(description)
+
         remove_dataset_cache = False
     except OpenMLServerException as e:
         # if there was an exception
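The functions.py hunks apply the same inversion inside get_dataset: previously arff_file was downloaded eagerly whenever download_data was set, even when the parquet download then succeeded; now _get_dataset_arff runs only after parquet is known to be missing. A self-contained sketch of the new control flow, with hypothetical stub downloaders standing in for _get_dataset_parquet / _get_dataset_arff:

from __future__ import annotations

from pathlib import Path


def stub_download_parquet(description: dict) -> Path | None:
    # Stand-in for _get_dataset_parquet; return None to simulate a failed download.
    return None


def stub_download_arff(description: dict) -> Path:
    # Stand-in for _get_dataset_arff; always "succeeds".
    return Path("dataset.arff")


def fetch_files(description: dict, download_data: bool) -> tuple[Path | None, Path | None]:
    parquet_file = None
    if "oml:parquet_url" in description and download_data:
        parquet_file = stub_download_parquet(description)

    # New ordering: ARFF is requested only once parquet is known to be missing.
    arff_file = None
    if parquet_file is None and download_data:
        arff_file = stub_download_arff(description)
    return parquet_file, arff_file


# Parquet fails -> ARFF fallback kicks in: (None, PosixPath('dataset.arff'))
print(fetch_files({"oml:parquet_url": "https://example.org/d.pq"}, download_data=True))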
1 change: 1 addition & 0 deletions tests/test_datasets/test_dataset_functions.py
@@ -1574,6 +1574,7 @@ def test_get_dataset_parquet(self):
         assert dataset._parquet_url is not None
         assert dataset.parquet_file is not None
         assert os.path.isfile(dataset.parquet_file)
+        assert dataset.data_file is None  # is alias for arff path

     @pytest.mark.production()
     def test_list_datasets_with_high_size_parameter(self):