{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Video Model Training" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
   "#### NOTES: \n",
   "* It's assumed that there's a pretrained generator from the ColorizeTrainingStable notebook available under the checkpoint name `VideoModel_gen_0` (`pre_gen_name` in the Setup cell).\n",
   "* A critic checkpoint named `VideoModel_crit_0` is also required: the critic-pretraining step of the first GAN cycle loads it via `crit_old_checkpoint_name`.\n",
   "* This is \"NoGAN\" based training, described in the DeOldify readme." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "#NOTE: This must be the first call in order to work properly!\n",
   "from deoldify import device\n",
   "from deoldify.device_id import DeviceId\n",
   "#choices: CPU, GPU0...GPU7\n",
   "device.set(device=DeviceId.GPU0)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "import gc\n",
   "import os\n",
   "import shutil\n",
   "import fastai\n",
   "from fastai import *\n",
   "from fastai.vision import *\n",
   "from fastai.callbacks.tensorboard import *\n",
   "from fastai.vision.gan import *\n",
   "from deoldify.generators import *\n",
   "from deoldify.critics import *\n",
   "from deoldify.dataset import *\n",
   "from deoldify.loss import *\n",
   "from deoldify.save import *\n",
   "from PIL import Image, ImageDraw, ImageFont\n",
   "from PIL import ImageFile" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
   "path_hr = path\n",
   "path_lr = path/'bandw'\n",
   "\n",
   "proj_id = 'VideoModel'\n",
   "gen_name = proj_id + '_gen'\n",
   "pre_gen_name = gen_name + '_0'\n",
   "crit_name = proj_id + '_crit'\n",
   "\n",
   "name_gen = proj_id + '_image_gen'\n",
   "path_gen = path/name_gen\n",
   "\n",
   "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
   "\n",
   "nf_factor = 2\n",
   "xtra_tfms=[noisify(p=0.8)]\n",
   "pct_start = 1e-8" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "def get_data(bs:int, sz:int, keep_pct:float):\n",
   "    \"\"\"Build the generator DataBunch pairing `path_lr` inputs with `path_hr` targets, applying `xtra_tfms`.\"\"\"\n",
   "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr,\n",
   "                             random_seed=None, keep_pct=keep_pct, xtra_tfms=xtra_tfms)\n",
   "\n",
   "def get_crit_data(classes, bs, sz):\n",
   "    \"\"\"Build the critic DataBunch from the `classes` sub-folders of `path`, labelled by folder, with a seeded 10% validation split.\"\"\"\n",
   "    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)\n",
   "    ll = src.label_from_folder(classes=classes)\n",
   "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
   "              .databunch(bs=bs).normalize(imagenet_stats))\n",
   "    return data\n",
   "\n",
   "def save_preds(dl, learn=None):\n",
   "    \"\"\"Save the generator's prediction for every item of `dl` into `path_gen`, keeping each source file name.\n",
   "\n",
   "    `learn` defaults to the global `learn_gen`. Predictions are matched to `dl.dataset.items` by position,\n",
   "    so `dl` must yield its items in order (the caller below passes a `fix_dl`).\n",
   "    \"\"\"\n",
   "    if learn is None:\n",
   "        learn = learn_gen\n",
   "    names = dl.dataset.items\n",
   "    i = 0\n",
   "    for b in dl:\n",
   "        preds = learn.pred_batch(batch=b, reconstruct=True)\n",
   "        for o in preds:\n",
   "            o.save(path_gen/names[i].name)\n",
   "            i += 1\n",
   "\n",
   "def save_gen_images(keep_pct:float=0.085):\n",
   "    \"\"\"Rebuild `path_gen` from scratch with generator outputs for a `keep_pct` subset of the data.\n",
   "\n",
   "    Uses the global `bs`, `sz` and `learn_gen`. Returns the first saved image so the calling cell\n",
   "    displays it as a quick sanity check.\n",
   "    \"\"\"\n",
   "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
   "    path_gen.mkdir(exist_ok=True)\n",
   "    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\n",
   "    save_preds(data_gen.fix_dl)\n",
   "    return PIL.Image.open(path_gen.ls()[0])" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Finetune Generator With Noise Augmented Images." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "##### This helps the generator better deal with noisy/grainy video (which is pretty normal)."
] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=8\n", "sz=192\n", "keep_pct=0.25" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = learn_gen.load(pre_gen_name, with_opt=False)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.unfreeze()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# NOTE: this overwrites the pretrained checkpoint in place. While old_checkpoint_num is 0, the GAN cycle\n",
   "# below reloads this same checkpoint, since gen_old_checkpoint_name == pre_gen_name == gen_name + '_0'.\n",
   "learn_gen.save(pre_gen_name)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Repeatable GAN Cycle" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
   "#### NOTE\n",
   "Best results so far have come from doing only a single run of the cells below (repeated runs introduce glitches that are visible in video)." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# The cycle loads the `_<old_checkpoint_num>` generator/critic checkpoints and writes new ones under the `_<checkpoint_num>` names.\n",
   "old_checkpoint_num = 0\n",
   "checkpoint_num = old_checkpoint_num + 1\n",
   "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
   "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
   "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
   "crit_new_checkpoint_name = crit_name + '_' + str(checkpoint_num)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Save Generated Images" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=8\n", "sz=192" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_gen_images()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Pretrain Critic" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=16\n", "sz=192" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# Drop the generator learner and collect garbage before building the critic.\n",
   "learn_gen=None\n",
   "gc.collect()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# On the first cycle this needs an existing critic checkpoint named crit_old_checkpoint_name (VideoModel_crit_0).\n",
   "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_critic.fit_one_cycle(4, 1e-4)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_critic.save(crit_new_checkpoint_name)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### GAN" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# Release the critic trained above (learn_critic) plus any learners left over from a previous run\n",
   "# before building the GAN learners.\n",
   "learn_critic=None\n",
   "learn_crit=None\n",
   "learn_gen=None\n",
   "gc.collect()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr=5e-6\n", "sz=192\n", "bs=5" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
   "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
   "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
   "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
   "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))\n",
   "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
   "#### Instructions: \n",
   "Find the checkpoint just before where glitches start to be introduced. So far this point has been found after iterating through 1.4% of the data when using a learning rate of 1e-5, and after 2.2% of the data when using 5e-6." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
   "learn_gen.freeze_to(-1)\n",
   "learn.fit(1,lr)" ] }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } },
 "nbformat": 4, "nbformat_minor": 4 }