Generating RGB imagery from a digital surface model using Pix2Pix

  • 🔬 Data Science
  • 🥠 Deep Learning and image translation

Introduction

In this notebook, we will focus on using Pix2Pix [1], one of the best-known and most successful deep learning models for paired image-to-image translation. In geospatial sciences, this approach can help in a wide range of applications that were traditionally not possible, where we want to go from one domain of images to another.

The aim of this notebook is to use the arcgis.learn Pix2Pix model to translate the gray-scale DSM into RGB imagery. For more details about the model and how it works, refer to How Pix2Pix works? in the guide section.
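As a brief pointer ahead of that guide, the objective given in [1] combines a conditional GAN loss with an L1 term that pulls the generated image toward the target. Here x is the input chip (a DSM tile in this notebook), y is the target chip (the matching RGB tile), z is random noise, G is the generator and D is the discriminator:

```latex
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z) \rVert_1\right]

G^{*} = \arg\min_{G} \max_{D} \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)
```

[1] sets λ = 100 in its experiments; the value used by arcgis.learn is not stated in this notebook.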

Necessary imports

import os, zipfile
from pathlib import Path
from os import listdir
from os.path import isfile, join

from arcgis import GIS
from arcgis.learn import Pix2Pix, prepare_data

Connect to your GIS

gis = GIS('home')  # used below to download the training and validation samples
ent_gis = GIS('https://pythonapi.playground.esri.com/portal', 'arcgis_python', 'amazing_arcgis_123')

Export image domain data

For this use case, we have high-resolution NAIP airborne imagery in the form of IR-G-B tiles and lidar data converted into a DSM, collected over St. George, Utah, by the State of Utah and partners [5], both at the same spatial resolution of 0.5 m. We export these using the 'Export_Tiles' metadata format available in the Export Training Data For Deep Learning tool. This tool is available in ArcGIS Pro as well as ArcGIS Image Server. The inputs required by the tool are described below.

  • Input Raster: DSM
  • Additional Input Raster: NAIP airborne imagery
  • Tile Size X & Tile Size Y: 256
  • Stride X & Stride Y: 128
  • Meta Data Format: 'Export_Tiles' as we are training a Pix2Pix model.
  • Environments: Set the optimum Cell Size and Processing Extent.
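To get a feel for what the tile size and stride settings imply, the sketch below estimates how many tiles an extent produces. With 256-pixel tiles at 0.5 m, each tile covers 128 m on the ground, and a 128-pixel stride makes neighboring tiles overlap by half. The sketch assumes only complete tiles are written and ignores NoData handling, so the tool's actual count may differ; the function names and the 1 km extent are hypothetical.

```python
def tiles_along_axis(n_pixels: int, tile_size: int = 256, stride: int = 128) -> int:
    """Number of complete tiles that fit along one raster axis."""
    if n_pixels < tile_size:
        return 0
    return (n_pixels - tile_size) // stride + 1


def estimate_tile_count(width_m: float, height_m: float, cell_size_m: float = 0.5,
                        tile_size: int = 256, stride: int = 128) -> int:
    """Estimate how many tiles a width_m x height_m extent yields (complete tiles only)."""
    cols = int(width_m / cell_size_m)
    rows = int(height_m / cell_size_m)
    return (tiles_along_axis(cols, tile_size, stride)
            * tiles_along_axis(rows, tile_size, stride))


# Hypothetical 1 km x 1 km processing extent at 0.5 m cell size (2000 x 2000 pixels)
print(estimate_tile_count(1000, 1000))  # 14 tiles per axis, 196 in total
```

Halving the stride roughly quadruples the number of tiles, which is why overlap is a trade-off between more training samples and longer export and training times.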

The rasters used for exporting the training dataset are provided below.

naip_domain_b_raster = ent_gis.content.get('319726e4cb4f4d69b471d65cc461b0a8')
naip_domain_b_raster
naip_train_area_domain_b: naip raster or domain b (Imagery Layer by api_data_owner, last modified March 12, 2021)
dsm_domain_a_raster = ent_gis.content.get('d8f21b09b7774a8f91f3152077eceffd')
dsm_domain_a_raster
dsm_train_area_domain_a: dsm raster or domain a (Imagery Layer by api_data_owner, last modified January 08, 2021)

Inside the exported data folder, the 'Images' and 'Images2' folders contain the image tiles of the two domains, exported from the DSM and the NAIP imagery respectively. Now we are ready to train the Pix2Pix model.
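Pix2Pix needs each DSM tile to be paired with the RGB tile covering the same location. A minimal sketch for checking this, using the listdir, isfile and join helpers imported at the top of the notebook, assuming (not confirmed here) that matching tiles share the same file name in both folders; paired_tiles is a hypothetical helper:

```python
from os import listdir
from os.path import isfile, join


def paired_tiles(data_dir):
    """Return (domain_a, domain_b) path pairs for tiles present in both folders."""
    dir_a = join(data_dir, "Images")   # DSM tiles (domain a)
    dir_b = join(data_dir, "Images2")  # NAIP tiles (domain b)
    names_a = {f for f in listdir(dir_a) if isfile(join(dir_a, f))}
    names_b = {f for f in listdir(dir_b) if isfile(join(dir_b, f))}
    unmatched = names_a ^ names_b  # names found in only one of the two folders
    if unmatched:
        print(f"{len(unmatched)} tile(s) have no counterpart in the other domain")
    return [(join(dir_a, n), join(dir_b, n)) for n in sorted(names_a & names_b)]


# pairs = paired_tiles(r"C:\path\to\exported_data")  # hypothetical path
```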

Model training

Alternatively, we have provided a subset of the training data containing a few samples that follow the same directory structure described above, along with the rasters used for exporting the training dataset. You can use this data directly to run the experiments.

training_data = gis.content.get('2a3dad36569b48ed99858e8579611a80')
training_data
data_for_pix2pix_with_trained_model (Image Collection by api_data_owner, last modified January 08, 2021)
filepath = training_data.download(file_name=training_data.name)
#Extract the data from the zipped image collection

with zipfile.ZipFile(filepath, 'r') as zip_ref:
    zip_ref.extractall(Path(filepath).parent)

Prepare data

output_path = Path(os.path.splitext(filepath)[0])  # extracted folder: the archive path without its extension
data = prepare_data(output_path, dataset_type="Pix2Pix", batch_size=5)

Visualize a few samples from your training data

To get a sense of what the training data looks like, the show_batch() method randomly picks a few training chips and visualizes them. On the left are DSM (digital surface model) chips, with the corresponding RGB imagery of various locations on the right.

data.show_batch()

Load Pix2Pix model architecture

model = Pix2Pix(data)

Tuning for optimal learning rate

The learning rate is one of the most important hyperparameters in model training. The ArcGIS API for Python provides a learning rate finder, lr_find(), that plots the loss over a range of learning rates and returns a suggested value.

lr = model.lr_find()

2.5118864315095795e-05
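For intuition, the sketch below shows one common way a learning rate finder turns its sweep into a single suggestion: pick the learning rate at which the loss falls most steeply with respect to log(learning rate). This is the heuristic fastai v1 uses for its suggestion; whether lr_find() in arcgis.learn applies exactly this rule is an assumption, and the function suggest_lr and the synthetic loss curve are hypothetical.

```python
import numpy as np


def suggest_lr(lrs, losses):
    """Pick the learning rate where the loss falls fastest (most negative slope)."""
    lrs = np.asarray(lrs, dtype=float)
    losses = np.asarray(losses, dtype=float)
    # Slope of the loss with respect to log10(learning rate)
    slopes = np.gradient(losses, np.log10(lrs))
    return float(lrs[np.argmin(slopes)])


# Synthetic sweep in which the loss drops most sharply around lr = 1e-4
lrs = np.logspace(-7, 0, 100)
losses = 1.0 - np.tanh(2.0 * (np.log10(lrs) + 4.0))
print(suggest_lr(lrs, losses))  # a value close to 1e-4
```

Real loss curves are noisy, so practical implementations usually smooth the losses before taking the gradient.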

Fit the model

We train the model for 30 epochs with the suggested learning rate.

model.fit(30, lr)
epoch  train_loss  valid_loss  gen_loss  l1_loss   D_loss    time
0      13.203547   13.980255   0.576110  0.126274  0.412447  01:01
1      12.675353   13.891787   0.573363  0.121020  0.411131  01:02
2      12.830377   13.652339   0.577334  0.122530  0.410224  01:00
3      12.826028   13.478950   0.578673  0.122474  0.410028  01:01
4      12.830496   13.464501   0.579446  0.122510  0.407034  01:01
5      12.978190   13.808777   0.581329  0.123969  0.405155  01:01
6      12.933887   14.188525   0.579817  0.123541  0.402280  01:01
7      12.660383   13.273459   0.583129  0.120773  0.398041  01:01
8      12.493378   13.234705   0.584513  0.119089  0.395928  01:02
9      12.704373   14.314936   0.583671  0.121207  0.393755  01:01
10     12.283652   12.872752   0.586115  0.116975  0.391496  01:01
11     12.008025   12.989032   0.585851  0.114222  0.386542  01:02
12     11.848214   12.356230   0.586706  0.112615  0.385120  01:01
13     11.648248   12.387824   0.586294  0.110620  0.383345  01:01
14     11.220642   12.051290   0.586354  0.106343  0.380747  01:01
15     11.065363   11.816018   0.587154  0.104782  0.379417  01:01
16     11.107099   11.579307   0.587886  0.105192  0.377144  01:02
17     10.680603   11.504006   0.587307  0.100933  0.375779  01:03
18     10.604408   11.234290   0.587380  0.100170  0.373917  01:03
19     10.459021   11.162776   0.586817  0.098722  0.372892  01:05
20     10.251445   10.944400   0.587671  0.096638  0.371933  01:02
21     10.173382   10.966841   0.587322  0.095861  0.371821  01:01
22     9.945634    10.783834   0.587247  0.093584  0.371387  01:01
23     9.681182    10.716444   0.587864  0.090933  0.369668  01:01
24     9.872039    10.600616   0.588303  0.092837  0.369563  01:00
25     9.786720    10.603912   0.588364  0.091984  0.369715  01:02
26     9.680658    10.506352   0.587878  0.090928  0.369863  01:02
27     9.386904    10.502596   0.587328  0.087996  0.368502  01:01
28     9.835923    10.505837   0.588324  0.092476  0.369696  01:01
29     9.630071    10.498654   0.586929  0.090431  0.368856  01:00

Here, after 30 epochs, we can see reasonable results: the training loss fell from 13.20 to 9.63 and the validation loss from 13.98 to 10.50, indicating that the model is learning to translate between the two image domains.
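The logged columns are consistent with the objective in [1], where the generator is trained on its adversarial loss plus an L1 term weighted by λ = 100: for each epoch, train_loss is approximately gen_loss + 100 * l1_loss. The check below uses a few rows copied from the table; treating this as the exact formula inside arcgis.learn is an assumption inferred from the logged numbers, not something the library documents here.

```python
# A few rows copied from the training log: (epoch, train_loss, gen_loss, l1_loss)
rows = [
    (0, 13.203547, 0.576110, 0.126274),
    (15, 11.065363, 0.587154, 0.104782),
    (29, 9.630071, 0.586929, 0.090431),
]

LAMBDA_L1 = 100  # L1 weight used in [1]; assumed, not confirmed for arcgis.learn

for epoch, train_loss, gen_loss, l1_loss in rows:
    reconstructed = gen_loss + LAMBDA_L1 * l1_loss
    print(f"epoch {epoch}: train_loss={train_loss:.4f}, "
          f"gen_loss + 100 * l1_loss={reconstructed:.4f}")
```

Read this way, the adversarial part (gen_loss) stays roughly flat near 0.58 while l1_loss falls from about 0.126 to 0.090, so most of the drop in train_loss comes from the generated RGB chips getting closer, pixel by pixel, to the real ones.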

Save the model

We will save the trained model as a 'Deep Learning Package' ('.dlpk' format). The Deep Learning Package is the standard format used to deploy deep learning models on the ArcGIS platform.

We will use the save() method to save the trained model. By default, it is saved to the 'models' sub-folder within the training data folder. Setting publish=True also publishes the package as an item on the active GIS.

model.save("pix2pix_model_e30", publish=True)
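The Classify Pixels Using Deep Learning step later in this notebook takes a path to the model's .emd file. Assuming the .dlpk written by save() is a standard zip archive containing that .emd alongside the weights, a small helper can locate it; find_emd and the path in the commented usage line are hypothetical.

```python
import zipfile


def find_emd(dlpk_path):
    """Return the name of the .emd model definition inside a .dlpk (zip) archive."""
    with zipfile.ZipFile(dlpk_path) as archive:
        emd_files = [n for n in archive.namelist() if n.lower().endswith(".emd")]
    if not emd_files:
        raise FileNotFoundError(f"No .emd file found in {dlpk_path}")
    return emd_files[0]


# Hypothetical location; check where save() actually wrote the package
# print(find_emd(r"C:\path\to\models\pix2pix_model_e30\pix2pix_model_e30.dlpk"))
```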

Visualize results in validation set

It is good practice to compare the model's results against the ground truth. The code below picks random samples and shows the ground truth and model predictions side by side, letting us preview the results of the model within the notebook.

model.show_results()

Compute evaluation metrics

The Fréchet Inception Distance score, or FID for short, is a metric that measures the distance between feature vectors computed for real and generated images. Lower scores indicate that the two groups of images are more similar, or have more similar statistics; a perfect score of 0.0 indicates that the two groups of images are identical.

model.compute_metrics()
263.63128885232044
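The metric is usually written as FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^(1/2)), where μ and Σ are the mean and covariance of the feature vectors of real (r) and generated (g) images. A minimal NumPy sketch of that formula follows. It assumes the feature vectors have already been extracted (the standard score uses Inception-v3 pooling features, which are not computed here), and it is not the internal code behind compute_metrics().

```python
import numpy as np


def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(feats_real, feats_fake):
    """FID between two (n_samples, n_features) arrays of image feature vectors."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_fake, rowvar=False)
    # Tr((cov_r cov_g)^(1/2)) equals the sum of square roots of the eigenvalues of
    # cov_r^(1/2) cov_g cov_r^(1/2), a symmetric matrix that eigvalsh can handle
    root_r = _sqrtm_psd(cov_r)
    cross = root_r @ cov_g @ root_r
    cross = (cross + cross.T) / 2.0  # restore symmetry lost to floating-point error
    eigvals = np.clip(np.linalg.eigvalsh(cross), 0.0, None)
    trace_term = np.trace(cov_r) + np.trace(cov_g) - 2.0 * np.sqrt(eigvals).sum()
    return float(np.sum((mu_r - mu_g) ** 2) + trace_term)


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
print(frechet_distance(real, real))        # approximately 0: identical sets
print(frechet_distance(real, real + 1.0))  # approximately 8: each of 8 means shifted by 1
```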

Model inferencing

Inference on a single imagery chip

We can translate a DSM chip to RGB imagery with the predict() method, applying the trained model to an image chip kept aside for validation. The method takes the following parameter:

  • img_path: path to the image file.
valid_data = gis.content.get('f682b16bcc6d40419a775ea2cad8f861')
valid_data
dsm raster chip for inferencing (Image by api_data_owner, last modified January 08, 2021)
filepath2 = valid_data.download(file_name=valid_data.name)
# Visualize the image chip used for inferencing 
from fastai.vision import open_image
open_image(filepath2)
Image (3, 256, 256)
# Run inference on a single image chip
model.predict(filepath2)
<PIL.Image.Image image mode=RGB size=256x256>

Generate a raster using the Classify Pixels Using Deep Learning tool

After training the Pix2Pix model and saving its weights, we can use the Classify Pixels Using Deep Learning tool, available in both ArcGIS Pro and ArcGIS Enterprise, for inferencing at scale.

test_data = ent_gis.content.get('ee2729b5b0f845d291a9866696cdd33a')
test_data
dsm_test_area: Test area dsm for large scale inferencing (Imagery Layer by api_data_owner, last modified January 22, 2021)

import arcpy  # requires ArcGIS Pro with the Image Analyst extension
out_classified_raster = arcpy.ia.ClassifyPixelsUsingDeepLearning("Imagery", r"C:\path\to\model.emd", "padding 64;batch_size 2")
out_classified_raster.save(r"C:\sample\sample.gdb\predicted_img2dsm")
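The third argument in the call above packs model arguments into a single string of name value pairs separated by semicolons. A minimal sketch of building and parsing such a string from Python; the helper names are hypothetical, and the meaning of each argument is defined by the model, not by these helpers:

```python
def format_model_arguments(args):
    """Join {'name': value} pairs into a 'name value;name value' string."""
    return ";".join(f"{name} {value}" for name, value in args.items())


def parse_model_arguments(text):
    """Split a 'name value;name value' string back into a dict of strings."""
    pairs = (item.strip().split(" ", 1) for item in text.split(";") if item.strip())
    return {name: value for name, value in pairs}


arguments = format_model_arguments({"padding": 64, "batch_size": 2})
print(arguments)                         # padding 64;batch_size 2
print(parse_model_arguments(arguments))  # {'padding': '64', 'batch_size': '2'}
```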

Results visualization

The RGB output raster is generated using ArcGIS Pro. The output raster is published on the portal for visualization.

inferenced_results = ent_gis.content.get('4c0a3d149ece42559b06d82fdb204898')
inferenced_results
predicted_rgb_imagery: Inferenced rgb imagery (Imagery Layer by api_data_owner, last modified January 22, 2021)

Create map widgets

Two map widgets are created: one showing the DSM and one showing the inferenced RGB raster.

map1 = ent_gis.map('Washington Fields', 13)
map1.add_layer(test_data)
map2 = ent_gis.map('Washington Fields', 13)
map2.add_layer(inferenced_results)

Synchronize web maps

The maps are synchronized with each other using MapView.sync_navigation, which helps in comparing the inferenced results with the DSM. A detailed description of advanced map widget options is available in the ArcGIS API for Python guide.

map2.sync_navigation(map1)

Set the map layout

from ipywidgets import HBox, VBox, Label, Layout

HBox and VBox are used to set the layout of the map widgets.

hbox_layout = Layout()
hbox_layout.justify_content = 'space-around'

hb1=HBox([Label('DSM'),Label('RGB results')])
hb1.layout=hbox_layout

Results

The predictions are provided as a map for better visualization.

VBox([hb1,HBox([map1,map2])])
map2.zoom_to_layer(inferenced_results)

Conclusion

In this notebook, we demonstrated how to use the Pix2Pix model in the ArcGIS API for Python to translate imagery from one domain to another.

References

  • [1]. Isola, Phillip, Jun-Yan Zhu, Tinghui Zhou, and Alexei A. Efros. "Image-to-image translation with conditional adversarial networks." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1125-1134. 2017.
  • [2]. Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. "Generative adversarial nets." In Advances in neural information processing systems, pp. 2672-2680. 2014.
  • [3]. https://stephan-osterburg.gitbook.io/coding/coding/ml-dl/tensorfow/chapter-4-conditional-generative-adversarial-network/acgan-architectural-design
  • [4]. Kang, Yuhao, Song Gao, and Robert E. Roth. "Transferring multiscale map styles using generative adversarial networks." International Journal of Cartography 5, no. 2-3 (2019): 115-141.
  • [5]. State of Utah and Partners, 2019, Regional Utah high-resolution lidar data 2015 - 2017: Collected by Quantum Spatial, Inc., Digital Mapping, Inc., and Aero-Graphics, Inc. and distributed by OpenTopography, https://doi.org/10.5069/G9RV0KSQ. Accessed: 2020-12-08
