This repository is the official implementation of Effective Dataset Distillation for Spatio-Temporal Forecasting with Bi-dimensional Compression, Taehyung Kwon*, Yeonje Choi*, Yeongho Kim, and Kijung Shin, ICDE 2026 (to appear).
Please see `requirements.txt` for the required packages:
- fast_pytorch_kmeans==0.2.2
- matplotlib==3.10.7
- numpy==2.3.4
- scikit_learn==1.7.2
- scipy==1.16.3
- torch==2.2.0
- tqdm==4.65.0
Please download and check the datasets below for more details.
- There are three npz files (`train.npz`, `val.npz`, `test.npz`) per dataset.
- Each file contains two arrays, `x` and `y`. `x` is an array of input time series, and `y` is an array of target time series.
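As a quick sanity check after downloading, the two arrays in a split file can be inspected with NumPy. The snippet below is a minimal sketch: `describe_split` is a hypothetical helper, and the file it reads is a tiny stand-in `train.npz` filled with random values so that the example runs on its own. The array shapes are placeholders, not the shapes of the released datasets.

```python
import os
import tempfile

import numpy as np


def describe_split(path):
    """Return {array name: shape} for one npz split file."""
    with np.load(path) as data:
        return {name: data[name].shape for name in data.files}


# Stand-in file so the example is self-contained; the shapes are placeholders,
# not the shapes of the released datasets.
tmp_dir = tempfile.mkdtemp()
demo_path = os.path.join(tmp_dir, "train.npz")
rng = np.random.default_rng(0)
np.savez(demo_path, x=rng.random((8, 12, 5, 1)), y=rng.random((8, 12, 5, 1)))

shapes = describe_split(demo_path)
print(shapes)
```

For a downloaded dataset, the same helper could be pointed at a path such as `../data/GBA/train.npz`, assuming the three files sit directly inside the dataset folder.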
The distillation process of STemDist is implemented in stemdist.py.
- `-de`, `--device`: GPU id for execution.
- `-d`, `--data`: Location of the dataset folder.
- `-b`, `--batch_size`: Batch size for the distillation process.
- `-lrs`, `--lr_syn`: Learning rate for the surrogate model, which is trained on the synthetic dataset.
- `-lrf`, `--lr_feat`: Learning rate for the synthetic dataset.
- `-nrr`, `--node_reduce_rate`: Compression ratio for the spatial dimension.
- `-srr`, `--series_reduce_rate`: Compression ratio for the temporal dimension.
- `-e`, `--epoch`: Number of outer iterations.
- `-ned`, `--ne_dim`: Hidden dimension of the location embedding model.
- `-s`, `--seed`: Seed of execution.
- `-sp`, `--save_path`: Path for saving the result files.
- `-c`, `--check_freq`: Period in outer iterations for checking the performance of the distilled dataset.
python -m stemdist -de 0 -d ../data/GBA -e 100 -sp results/stemdist_gba_1e-3_1e-3 -lrf 1e-3 -lrs 1e-3 -srr 0.1 -nrr 0.1 -b 256 -ned 32 -s 0 -c 5
- `stemdist_gba_1e-3_1e-3.pt`: Saves the distilled dataset.
- `stemdist_gba_1e-3_1e-3.txt`: Saves the performance of the distilled dataset, recorded every `check_freq` outer iterations.
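To see how the short and long option names documented for `stemdist.py` fit together, the sketch below declares them with `argparse` and parses the example command. It is a hypothetical reconstruction for illustration only: the argument types are assumptions inferred from the example values, no defaults are set, and the parser in the repository may be written differently.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented options of stemdist.py."""
    p = argparse.ArgumentParser(prog="stemdist")
    # Types below are assumptions inferred from the example command.
    p.add_argument("-de", "--device", type=int, help="GPU id for execution")
    p.add_argument("-d", "--data", type=str, help="dataset folder")
    p.add_argument("-b", "--batch_size", type=int)
    p.add_argument("-lrs", "--lr_syn", type=float)
    p.add_argument("-lrf", "--lr_feat", type=float)
    p.add_argument("-nrr", "--node_reduce_rate", type=float)
    p.add_argument("-srr", "--series_reduce_rate", type=float)
    p.add_argument("-e", "--epoch", type=int)
    p.add_argument("-ned", "--ne_dim", type=int)
    p.add_argument("-s", "--seed", type=int)
    p.add_argument("-sp", "--save_path", type=str)
    p.add_argument("-c", "--check_freq", type=int)
    return p


# The option values from the example distillation command.
example = (
    "-de 0 -d ../data/GBA -e 100 -sp results/stemdist_gba_1e-3_1e-3 "
    "-lrf 1e-3 -lrs 1e-3 -srr 0.1 -nrr 0.1 -b 256 -ned 32 -s 0 -c 5"
)
args = build_parser().parse_args(example.split())
print(vars(args))
```

Each option is then available under its long name, for example `args.lr_feat` for `-lrf` and `args.node_reduce_rate` for `-nrr`.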
Checking the performance of the distilled dataset is implemented in model/load_mtgnn_stemdist.py.
- `-de`, `-d`, `-s`, `-ned`, `-sp`: Same as when running `stemdist.py`.
- `-b`, `--batch_size`: Batch size for the validation of the trained model.
- `-lr`, `--lr`: Learning rate for training the model.
- `-e`, `--epochs`: Number of training epochs for the model.
- `-c`, `--check`: Period in epochs for checking the performance of the trained model.
- `-lp`, `--load_path`: Path to the saved distilled dataset.
- `-a`, `--ae`: Compute error in Relative MSE when given.
python -m model.load_mtgnn_stemdist -de 0 -d ../data/GBA -lr 0.01 -e 400 -b 128 -lp results/stemdist_gba_1e-3_1e-3.pt -s 0
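The two example commands are linked by the output path: distillation is given `-sp results/stemdist_gba_1e-3_1e-3`, and evaluation loads `results/stemdist_gba_1e-3_1e-3.pt` via `-lp`. The sketch below builds both commands from one run name so the two paths stay in sync. `build_commands` is a hypothetical helper, the hyperparameters are copied from the example commands, and the `.pt` suffix rule is inferred from those examples rather than confirmed from the code.

```python
import shlex


def build_commands(data_dir, run_name, device=0, seed=0):
    """Build the distillation and evaluation commands for one run.

    Assumes (from the example commands) that distillation writes
    `<save_path>.pt` and that evaluation reads that file via -lp.
    """
    save_path = f"results/{run_name}"
    distill = [
        "python", "-m", "stemdist",
        "-de", str(device), "-d", data_dir, "-e", "100",
        "-sp", save_path, "-lrf", "1e-3", "-lrs", "1e-3",
        "-srr", "0.1", "-nrr", "0.1", "-b", "256",
        "-ned", "32", "-s", str(seed), "-c", "5",
    ]
    evaluate = [
        "python", "-m", "model.load_mtgnn_stemdist",
        "-de", str(device), "-d", data_dir, "-lr", "0.01",
        "-e", "400", "-b", "128", "-lp", save_path + ".pt",
        "-s", str(seed),
    ]
    return distill, evaluate


distill_cmd, eval_cmd = build_commands("../data/GBA", "stemdist_gba_1e-3_1e-3")
print(shlex.join(distill_cmd))
print(shlex.join(eval_cmd))
```

The returned lists can be passed to `subprocess.run(cmd, check=True)` one after the other, run from the same working directory as the commands above.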
| Name | M (# time series) | N (# locations) | F (# features) | # Total data points | Source | Download Link |
|---|---|---|---|---|---|---|
| GBA | 1,977 | 2,352 | 1 | 4,649,904 | PatchSTG | Link |
| GLA | 1,977 | 3,834 | 1 | 7,579,818 | PatchSTG | Link |
| ERA5 | 2,137 | 6,561 | 6 | 14,020,857 | Climate Data Store | Link |
| CAMS | 2,556 | 7,070 | 6 | 108,425,520 | ECMWF | Link |
| CA | 1,977 | 8,600 | 1 | 17,002,200 | PatchSTG | Link |