diff --git a/.gitignore b/.gitignore index f232dc1..9865364 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,133 @@ ckt_logs backup .DS_Store + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don’t work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..66624b1 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.pythonPath": ".env/bin/python2.7" +} \ No newline at end of file diff --git a/README.md b/README.md index 4800bd2..37438b9 100755 --- a/README.md +++ b/README.md @@ -42,8 +42,8 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f **Training** - The steps to train a StackGAN model on the CUB dataset using our preprocessed data for birds. - - Step 1: train Stage-I GAN (e.g., for 600 epochs) `python stageI/run_exp.py --cfg stageI/cfg/birds.yml --gpu 0` - - Step 2: train Stage-II GAN (e.g., for another 600 epochs) `python stageII/run_exp.py --cfg stageII/cfg/birds.yml --gpu 1` + - Step 1: train Stage-I GAN (e.g., for 600 epochs) `python -m stageI.run_exp --cfg stageI/cfg/birds.yml --gpu 0` + - Step 2: train Stage-II GAN (e.g., for another 600 epochs) `python -m stageII.run_exp --cfg stageII/cfg/birds.yml --gpu 1` - Change `birds.yml` to `flowers.yml` to train a StackGAN model on Oxford-102 dataset using our preprocessed data for flowers. - `*.yml` files are example configuration files for training/testing our models. - If you want to try your own datasets, [here](https://github.com/soumith/ganhacks) are some good tips about how to train GAN. Also, we encourage to try different hyper-parameters and architectures, especially for more complex datasets. 
@@ -100,3 +100,6 @@ booktitle = {{ICCV}}, - Generative Adversarial Text-to-Image Synthesis [Paper](https://arxiv.org/abs/1605.05396) [Code](https://github.com/reedscot/icml2016) - Learning Deep Representations of Fine-grained Visual Descriptions [Paper](https://arxiv.org/abs/1605.05395) [Code](https://github.com/reedscot/cvpr2016) + +**Docker** +nvidia-docker run -it -v $(pwd):/notebooks -v ~:/root --name tf-gpu-0.12 tensorflow/tensorflow:0.12.0-gpu bash \ No newline at end of file diff --git a/misc/datasets.py b/misc/datasets.py index 624e6af..31ecfa0 100644 --- a/misc/datasets.py +++ b/misc/datasets.py @@ -75,7 +75,7 @@ def transform(self, images): h1 = np.floor((ori_size - self._imsize) * np.random.random()) w1 = np.floor((ori_size - self._imsize) * np.random.random()) cropped_image =\ - images[i][w1: w1 + self._imsize, h1: h1 + self._imsize, :] + images[i][int(w1): int(w1) + self._imsize, int(h1): int(h1) + self._imsize, :] if random.random() > 0.5: transformed_images[i] = np.fliplr(cropped_image) else: diff --git a/misc/preprocess_birds.py b/misc/preprocess_birds.py index c93736a..a66f55a 100644 --- a/misc/preprocess_birds.py +++ b/misc/preprocess_birds.py @@ -7,7 +7,7 @@ import numpy as np import os import pickle -from misc.utils import get_image +from utils import get_image import scipy.misc import pandas as pd diff --git a/misc/utils.py b/misc/utils.py index 961954d..52823db 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -6,11 +6,10 @@ from __future__ import print_function import numpy as np -import scipy.misc +import scipy import os import errno - def get_image(image_path, image_size, is_crop=False, bbox=None): global index out = transform(imread(image_path), image_size, is_crop, bbox) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7cde56f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,66 @@ +backports-abc==0.5 +backports.shutil-get-terminal-size==1.0.0 +backports.ssl-match-hostname==3.5.0.1 +bleach==1.5.0 
+certifi==2016.9.26 +configparser==3.5.0 +cycler==0.10.0 +decorator==4.0.10 +easydict==1.9 +entrypoints==0.2.2 +enum34==1.1.6 +funcsigs==1.0.2 +functools32==3.2.3.post2 +html5lib==0.9999999 +ipykernel==4.5.2 +ipython==5.1.0 +ipython-genutils==0.1.0 +ipywidgets==5.2.2 +Jinja2==2.8 +jsonschema==2.5.1 +jupyter==1.0.0 +jupyter-client==4.4.0 +jupyter-console==5.0.0 +jupyter-core==4.2.1 +MarkupSafe==0.23 +matplotlib==1.5.3 +mistune==0.7.3 +mock==2.0.0 +nbconvert==5.0.0 +nbformat==4.2.0 +notebook==4.3.0 +numpy==1.16.2 +pandas==0.24.2 +pandocfilters==1.4.1 +pathlib2==2.1.0 +pbr==5.1.3 +pexpect==4.2.1 +pickleshare==0.7.4 +Pillow==6.0.0 +prettytensor==0.7.2 +progressbar==2.5 +prompt-toolkit==1.0.9 +protobuf==3.1.0 +ptyprocess==0.5.1 +Pygments==2.1.3 +pyparsing==2.1.10 +python-dateutil==2.8.0 +pytz==2019.1 +PyYAML==5.1 +pyzmq==16.0.2 +qtconsole==4.2.1 +scikit-learn==0.18.1 +scipy==1.2.1 +simplegeneric==0.8.1 +singledispatch==3.4.0.3 +six==1.12.0 +sklearn==0.0 +tensorflow-gpu==0.12.1 +terminado==0.6 +testpath==0.3 +torchfile==0.1.0 +tornado==4.4.2 +traitlets==4.3.1 +virtualenv==16.4.3 +wcwidth==0.1.7 +widgetsnbextension==1.2.6 diff --git a/stageI/trainer.py b/stageI/trainer.py index 001666a..950fbfc 100644 --- a/stageI/trainer.py +++ b/stageI/trainer.py @@ -301,6 +301,8 @@ def build_model(self, sess): def train(self): config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth=True + # config.gpu_options.per_process_gpu_memory_fraction = 0.5 with tf.Session(config=config) as sess: with tf.device("/gpu:%d" % cfg.GPU_ID): counter = self.build_model(sess) @@ -437,6 +439,7 @@ def eval_one_dataset(self, sess, dataset, save_dir, subset='train'): def evaluate(self): config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.3 with tf.Session(config=config) as sess: with tf.device("/gpu:%d" % cfg.GPU_ID): if self.model_path.find('.ckpt') != -1: