forked from jantic/DeOldify
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path dataset.py
48 lines (41 loc) · 1.26 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import fastai
from fastai import *
from fastai.core import *
from fastai.vision.transform import get_transforms
from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
from .augs import noisify
def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: int = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms=None,
    valid_pct: float = 0.1,
) -> ImageDataBunch:
    """Build an image-to-image DataBunch pairing degraded inputs with targets.

    Every image found under ``crappy_path`` becomes an input item. Its label
    is the file at the same relative location under ``good_path``.

    Args:
        sz: Size passed to ``.transform(size=...)``; applied to inputs and
            targets alike.
        bs: Batch size.
        crappy_path: Root folder of the input images (opened in RGB mode).
        good_path: Root folder of the target images. Its directory layout
            must mirror ``crappy_path``.
        random_seed: Seed forwarded to both the partial-data sampling and
            the random train/validation split.
        keep_pct: Fraction of the files to use (``use_partial_data``'s
            ``sample_pct``).
        num_workers: Number of DataLoader worker processes.
        stats: Normalization statistics, applied to inputs and targets.
        xtra_tfms: Extra transforms appended to the default augmentations.
            ``None`` (the default) means no extra transforms.
        valid_pct: Fraction of the items held out for validation
            (previously hard-coded to 0.1, which remains the default).

    Returns:
        The normalized ``ImageDataBunch``, with ``c`` set to 3.
    """
    # A None sentinel avoids sharing one mutable list across all calls.
    if xtra_tfms is None:
        xtra_tfms = []

    src = (
        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
        .use_partial_data(sample_pct=keep_pct, seed=random_seed)
        .split_by_rand_pct(valid_pct, seed=random_seed)
    )
    data = (
        # Target = same relative path, re-rooted under good_path.
        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
        .transform(
            get_transforms(
                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
            ),
            size=sz,
            # Apply the same transforms to the target so input/target stay aligned.
            tfm_y=True,
        )
        # no_check=True skips fastai's databunch sanity check.
        .databunch(bs=bs, num_workers=num_workers, no_check=True)
        .normalize(stats, do_y=True)
    )
    # Presumably the model's output channel count (RGB targets) — confirm with
    # the learner construction code.
    data.c = 3
    return data
def get_dummy_databunch() -> ImageDataBunch:
    """Build a minimal colorization databunch over the ``./dummy/`` folder.

    The same directory serves as both the input and the target root. The call
    requests a 1-pixel size, batch size 1, and ``keep_pct=0.001``, so only a
    tiny sample of the folder's files is used.

    NOTE(review): presumably used where an API needs a DataBunch but no real
    training data is required (e.g. constructing a learner for inference);
    confirm against callers.
    """
    dummy_dir = Path('./dummy/')
    return get_colorize_data(
        sz=1,
        bs=1,
        keep_pct=0.001,
        good_path=dummy_dir,
        crappy_path=dummy_dir,
    )