Satellite image classification is crucial for many applications in agriculture, environmental monitoring, urban planning, and more. Applications such as crop monitoring and land and forest cover mapping are increasingly used by governments, companies, and research labs.
In this tutorial, you will learn how to build a satellite image classifier using the TensorFlow framework in Python.
We will be using the EuroSAT dataset, which is based on Sentinel-2 satellite images covering 13 spectral bands. It consists of 27,000 labeled samples across 10 classes: annual crop, permanent crop, forest, herbaceous vegetation, highway, industrial, pasture, residential, river, and sea/lake.
The EuroSAT dataset comes in two variants:
- rgb (default): contains only the R, G, and B bands, encoded as JPEG images.
- all: contains all 13 bands in their original value range.
To get started, let's install TensorFlow and some other helper tools:
$ pip install tensorflow tensorflow_addons tensorflow_datasets tensorflow_hub numpy matplotlib seaborn scikit-learn
We use tensorflow_addons to calculate the F1 score during the training of the model.
We will use the EfficientNetV2 model, a family of CNNs that was among the state of the art for image classification when this tutorial was written. We use tensorflow_hub to load this pre-trained model for fine-tuning.
Importing the necessary libraries:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
import tensorflow_addons as tfa
Downloading and loading the dataset:
# load the whole dataset, for data info
all_ds = tfds.load("eurosat", with_info=True)
# load training, testing & validation sets, splitting by 60%, 20% and 20% respectively
train_ds = tfds.load("eurosat", split="train[:60%]")
test_ds = tfds.load("eurosat", split="train[60%:80%]")
valid_ds = tfds.load("eurosat", split="train[80%:]")
We split our dataset into 60% for training, 20% for validation during training, and 20% for testing. The following code sets some variables we'll use later:
# the class names
class_names = all_ds[1].features["label"].names
# total number of classes (10)
num_classes = len(class_names)
num_examples = all_ds[1].splits["train"].num_examples
Since all_ds was loaded with with_info set to True, it is a (dataset, info) tuple. We take the list of class names and the total number of samples from the info object, all_ds[1].
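For reference, here is what the percentage slices passed to tfds.load() amount to for EuroSAT's 27,000 samples. This is a small plain-Python sketch. It assumes every slice boundary falls on a whole sample, which holds for 27,000. When a boundary does not, tfds applies its own rounding rule, and this sketch does not reproduce it. The helper name split_sizes is hypothetical and not part of tfds.

```python
def split_sizes(num_examples, boundaries):
    """Hypothetical helper: sizes of consecutive percentage slices.

    `boundaries` are cumulative percentages, e.g. [0, 60, 80, 100] for
    train[:60%], train[60%:80%] and train[80%:].
    """
    # integer arithmetic, so no floating-point rounding is involved
    cuts = [num_examples * pct // 100 for pct in boundaries]
    return [end - start for start, end in zip(cuts, cuts[1:])]

train_size, test_size, valid_size = split_sizes(27_000, [0, 60, 80, 100])
print(train_size, test_size, valid_size)  # 16200 5400 5400
```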
Next, I'm going to make a bar plot to see the number of samples in each class:
# make a plot for number of samples on each class
fig, ax = plt.subplots(1, 1, figsize=(14,10))
labels, counts = np.unique(np.fromiter(all_ds[0]["train"].map(lambda x: x["label"]), np.int32), return_counts=True)
plt.ylabel('Counts')
plt.xlabel('Labels')
sns.barplot(x = [class_names[l] for l in labels], y = counts, ax=ax)
for i, x_ in enumerate(labels):
    ax.text(x_-0.2, counts[i]+5, counts[i])
# set the title
ax.set_title("Bar Plot showing Number of Samples on Each Class")
# save the image
# plt.savefig("class_samples.png")
Output:
Half of the classes have 3,000 samples each and the remaining ones have 2,500, except pasture, which has only 2,000.
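The bar plot relies on np.unique(..., return_counts=True) to turn the stream of labels into per-class counts. Below is a toy sketch of that call. It uses a hypothetical label array built to match the distribution described above: five classes at 3,000, four at 2,500, and one at 2,000. Which class index gets which count is made up here.

```python
import numpy as np

# Hypothetical class sizes matching the distribution described above;
# the assignment of sizes to class indices is arbitrary.
sizes = [3000] * 5 + [2500] * 4 + [2000]
fake_labels = np.repeat(np.arange(len(sizes)), sizes)

labels, counts = np.unique(fake_labels, return_counts=True)
print(counts.sum())  # 27000
# ratio between the largest and the smallest class
print(counts.max() / counts.min())  # 1.5
```

The totals add up to the 27,000 samples in the dataset. A 1.5:1 ratio between the largest and smallest class is a mild imbalance.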
Now let's take our training and validation sets and prepare them before training:
def prepare_for_training(ds, cache=True, batch_size=64, shuffle_buffer_size=1000):
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()
    ds = ds.map(lambda d: (d["image"], tf.one_hot(d["label"], num_classes)))
    # shuffle the dataset
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    # Repeat forever
    ds = ds.repeat()
    # split to batches
    ds = ds.batch(batch_size)
    # `prefetch` lets the dataset fetch batches in the background while the model
    # is training.
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return ds
Here is what this function does:
- cache(): If cache is a string, the dataset is cached to that file on disk. Otherwise (the default, cache=True), it is cached in memory. Either way, the samples are only loaded and decoded the first time through, i.e., during the first epoch.
- map(): Maps each sample to a tuple of the image and its label, one-hot encoded with tf.one_hot().
- shuffle(): Shuffles the samples using a buffer of shuffle_buffer_size (1,000) elements, so they arrive in random order.
- repeat(): Repeats the dataset indefinitely, so iterating over it never runs out of samples. Because of this, we have to tell fit() how many steps make up an epoch.
- batch(): Groups the samples into batches of batch_size samples (64 here) per training step.
- prefetch(): Fetches upcoming batches in the background while the model is training.

Let's run it for the training and validation sets:
batch_size = 64
# preprocess training & validation sets
train_ds = prepare_for_training(train_ds, batch_size=batch_size)
valid_ds = prepare_for_training(valid_ds, batch_size=batch_size)
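The shuffle() step in prepare_for_training() draws from a fixed-size buffer rather than shuffling the whole dataset at once. The plain-Python generator below is a simplified sketch of that buffer idea. buffered_shuffle is a hypothetical name. tf.data's real implementation differs in details such as seeding and reshuffling on every iteration (its reshuffle_each_iteration option).

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=None):
    """Simplified sketch of buffer-based shuffling (hypothetical helper).

    Keeps up to `buffer_size` items (must be >= 1). Once the buffer is full,
    each incoming item replaces a randomly chosen buffered item, which is emitted.
    """
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]  # emit a random buffered item...
        buffer[i] = item  # ...and put the new item in its place
    rng.shuffle(buffer)  # input exhausted: drain the rest in random order
    yield from buffer

shuffled = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
print(shuffled)
```

With buffer_size=1 the output keeps the input order. The first emitted item always comes from the first buffer_size inputs. For these reasons, a buffer much smaller than the dataset only mixes samples locally.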
Let's see what our data looks like:
# validating shapes
for el in valid_ds.take(1):
    print(el[0].shape, el[1].shape)
for el in train_ds.take(1):
    print(el[0].shape, el[1].shape)
Output:
(64, 64, 64, 3) (64, 10)
(64, 64, 64, 3) (64, 10)
Fantastic, the training and validation batches have the same shape: the batch size is 64, and each image is (64, 64, 3). The targets have the shape (64, 10), since each of the 64 samples has a one-hot encoded label over 10 classes.
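To make the (64, 10) target shape concrete, here is a small NumPy sketch of one-hot encoding. It is similar in spirit to what tf.one_hot() does inside prepare_for_training(). The one_hot function is a hypothetical stand-in. It assumes a 1-D array of valid labels in [0, depth). By contrast, tf.one_hot returns an all-zero row for out-of-range indices.

```python
import numpy as np

def one_hot(labels, depth):
    """NumPy sketch of one-hot encoding for a 1-D array of valid integer labels."""
    labels = np.asarray(labels)
    encoded = np.zeros((labels.size, depth), dtype=np.float32)
    # set a single 1.0 per row, in the column given by the label
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

targets = one_hot([0, 3, 9], depth=10)
print(targets.shape)  # (3, 10)
# argmax recovers the original class indices
print(targets.argmax(axis=1))  # [0 3 9]
```

The argmax step is the inverse of the encoding. It is what turns one-hot targets or softmax outputs back into class indices.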
Let's visualize the first batch from the training dataset:
# take the first batch of the training set
batch = next(iter(train_ds))
def show_batch(batch):
    plt.figure(figsize=(16, 16))
    # show at most 32 images, in rows of 8
    n_images = min(32, batch_size)
    n_rows = (n_images + 7) // 8
    for n in range(n_images):
        ax = plt.subplot(n_rows, 8, n + 1)
        # show the image
        plt.imshow(batch[0][n])
        # and put the corresponding label as title upper to the image
        plt.title(class_names[int(tf.argmax(batch[1][n]))])
        plt.axis('off')
    plt.savefig("sample-images.png")

# showing a batch of images along with labels
show_batch(batch)
Output:
Right. Now that we have our data prepared for training, let's build our model. First, we download EfficientNetV2 and load it as a hub.KerasLayer:
model_url = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/feature_vector/2"
# download & load the layer as a feature vector
keras_layer = hub.KerasLayer(model_url, output_shape=[1280], trainable=True)
We pass model_url to hub.KerasLayer so that we get EfficientNetV2 as an image feature extractor. As the URL indicates, this is the large (L) variant trained on ImageNet-1k. We also set trainable to True, so the pre-trained weights are adjusted slightly for our dataset during training (i.e., fine-tuning).
Building the model:
m = tf.keras.Sequential([
    keras_layer,
    tf.keras.layers.Dense(num_classes, activation="softmax")
])
# build the model with input image shape as (64, 64, 3)
m.build([None, 64, 64, 3])
m.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy", tfa.metrics.F1Score(num_classes)]
)
m.summary()
We use Sequential(). The first layer is the pre-trained CNN, followed by a fully connected output layer with one unit per class and a softmax activation.
Finally, the model is built with an input shape of (64, 64, 3). It is compiled with the categorical cross-entropy loss, the Adam optimizer, and accuracy and F1 score as metrics. Here is the output of m.summary():
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 keras_layer (KerasLayer)    (None, 1280)              117746848

 dense (Dense)               (None, 10)                12810

=================================================================
Total params: 117,759,658
Trainable params: 117,247,082
Non-trainable params: 512,576
_________________________________________________________________
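The parameter counts in the summary can be checked by hand. A Dense layer has one weight per input-output pair plus one bias per output. Here is a quick arithmetic sketch using only the figures printed above:

```python
# Figures copied from the model summary above.
hub_layer_params = 117_746_848
feature_dim, num_classes = 1280, 10

# Dense layer: a (feature_dim x num_classes) weight matrix plus one bias per class
dense_params = feature_dim * num_classes + num_classes
total = hub_layer_params + dense_params

print(dense_params)  # 12810
print(total)  # 117759658
# trainable + non-trainable should add up to the total
print(117_247_082 + 512_576 == total)  # True
```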
With the data and model ready, let's begin fine-tuning our model:
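model_name = "satellite-classification"
# make sure the output folder exists before the checkpoint callback writes to it
os.makedirs("results", exist_ok=True)
model_path = os.path.join("results", model_name + ".h5")
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True, verbose=1)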
# set the training & validation steps since we're using .repeat() on our dataset
# number of training steps
n_training_steps = int(num_examples * 0.6) // batch_size
# number of validation steps
n_validation_steps = int(num_examples * 0.2) // batch_size
# train the model
history = m.fit(
    train_ds, validation_data=valid_ds,
    steps_per_epoch=n_training_steps,
    validation_steps=n_validation_steps,
    verbose=1, epochs=5,
    callbacks=[model_checkpoint]
)
The training will take several minutes, depending on your GPU. Here is the output:
Epoch 1/5
253/253 [==============================] - ETA: 0s - loss: 0.3780 - accuracy: 0.8859 - f1_score: 0.8832
Epoch 00001: val_loss improved from inf to 0.16415, saving model to results/satellite-classification.h5
253/253 [==============================] - 158s 438ms/step - loss: 0.3780 - accuracy: 0.8859 - f1_score: 0.8832 - val_loss: 0.1641 - val_accuracy: 0.9513 - val_f1_score: 0.9501
Epoch 2/5
253/253 [==============================] - ETA: 0s - loss: 0.1531 - accuracy: 0.9536 - f1_score: 0.9525
Epoch 00002: val_loss improved from 0.16415 to 0.12853, saving model to results/satellite-classification.h5
253/253 [==============================] - 106s 421ms/step - loss: 0.1531 - accuracy: 0.9536 - f1_score: 0.9525 - val_loss: 0.1285 - val_accuracy: 0.9568 - val_f1_score: 0.9559
Epoch 3/5
253/253 [==============================] - ETA: 0s - loss: 0.1092 - accuracy: 0.9660 - f1_score: 0.9654
Epoch 00003: val_loss improved from 0.12853 to 0.12095, saving model to results/satellite-classification.h5
253/253 [==============================] - 107s 424ms/step - loss: 0.1092 - accuracy: 0.9660 - f1_score: 0.9654 - val_loss: 0.1210 - val_accuracy: 0.9619 - val_f1_score: 0.9605
Epoch 4/5
253/253 [==============================] - ETA: 0s - loss: 0.1042 - accuracy: 0.9692 - f1_score: 0.9687
Epoch 00004: val_loss did not improve from 0.12095
253/253 [==============================] - 100s 394ms/step - loss: 0.1042 - accuracy: 0.9692 - f1_score: 0.9687 - val_loss: 0.1435 - val_accuracy: 0.9565 - val_f1_score: 0.9572
Epoch 5/5
253/253 [==============================] - ETA: 0s - loss: 0.1003 - accuracy: 0.9700 - f1_score: 0.9695
Epoch 00005: val_loss improved from 0.12095 to 0.09841, saving model to results/satellite-classification.h5
253/253 [==============================] - 107s 423ms/step - loss: 0.1003 - accuracy: 0.9700 - f1_score: 0.9695 - val_loss: 0.0984 - val_accuracy: 0.9702 - val_f1_score: 0.9687
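Both datasets repeat forever, so fit() relies on steps_per_epoch and validation_steps to know where an epoch ends. Here is the arithmetic behind those values for num_examples = 27,000 and batch_size = 64. This sketch uses pure integer operations. The tutorial's int(num_examples * 0.6) gives the same result for these numbers, and the 253/253 in the training log matches the training step count.

```python
num_examples = 27_000  # total EuroSAT samples, as reported by the dataset info
batch_size = 64

# integer-only version of int(num_examples * 0.6) // batch_size
n_training_steps = num_examples * 60 // 100 // batch_size
n_validation_steps = num_examples * 20 // 100 // batch_size

print(n_training_steps)  # 253
print(n_validation_steps)  # 84
# training samples beyond the last full batch in one "epoch"
print(num_examples * 60 // 100 - n_training_steps * batch_size)  # 8
```

Because the dataset repeats, those few leftover samples are not dropped. They simply flow into the next epoch's batches.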
As you can see, the model reached about 97% accuracy on the validation set by epoch 5. You can increase the number of epochs to see whether it improves further.
So far, we have only evaluated the model on the validation set during training. In this section, we use it to classify satellite images it has never seen before. First, let's load the best weights:
# load the best weights
m.load_weights(model_path)
Extracting all the testing images and labels individually from test_ds:
# number of testing samples (20% of the dataset)
n_testing_samples = int(all_ds[1].splits["train"].num_examples * 0.2)
# get all testing images as NumPy array
images = np.array([ d["image"] for d in test_ds.take(n_testing_samples) ])
print("images.shape:", images.shape)
# get all testing labels as NumPy array
labels = np.array([ d["label"] for d in test_ds.take(n_testing_samples) ])
print("labels.shape:", labels.shape)
Output:
images.shape: (5400, 64, 64, 3)
labels.shape: (5400,)
As expected, we have 5,400 images and labels. Let's use the model to predict these images and then compare the predictions with the true labels:
# feed the images to get predictions
predictions = m.predict(images)
# perform argmax to get class index
predictions = np.argmax(predictions, axis=1)
print("predictions.shape:", predictions.shape)
Output:
predictions.shape: (5400,)
Now let's compute the accuracy and the macro-averaged F1 score on the test set:
from sklearn.metrics import f1_score

accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(labels, predictions)
print("Accuracy:", accuracy.result().numpy())
print("F1 Score:", f1_score(labels, predictions, average="macro"))
Output:
Accuracy: 0.9677778
F1 Score: 0.9655686619720163
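To show what average="macro" means, here is a plain NumPy sketch of both metrics. Each class gets its own F1 score, and the macro score is their unweighted mean. accuracy_and_macro_f1 is a hypothetical helper and the labels are toy values. Like sklearn's default, it averages over the classes present in either the true labels or the predictions.

```python
import numpy as np

def accuracy_and_macro_f1(y_true, y_pred):
    """NumPy sketch of accuracy and macro-averaged F1 (hypothetical helper)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    accuracy = float(np.mean(y_true == y_pred))
    f1_per_class = []
    for c in np.union1d(y_true, y_pred):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        # F1 = 2 * precision * recall / (precision + recall) = 2TP / (2TP + FP + FN)
        f1_per_class.append(2 * tp / (2 * tp + fp + fn))
    return accuracy, float(np.mean(f1_per_class))

# hypothetical toy labels for three classes
acc, macro_f1 = accuracy_and_macro_f1([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0])
print(round(acc, 3), round(macro_f1, 3))  # 0.667 0.656
```

Macro averaging gives every class the same weight regardless of its size. That is why it can sit slightly below plain accuracy when the weaker classes are the smaller ones.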
That's good accuracy! Let's draw the confusion matrix for all the classes:
# compute the confusion matrix (rows: true labels, columns: predictions)
cmn = tf.math.confusion_matrix(labels, predictions).numpy()
# normalize each row by its total so it sums to 1 (fractions of each true class)
cmn = cmn.astype('float') / cmn.sum(axis=1)[:, np.newaxis]
# make a plot for the confusion matrix
fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(cmn, annot=True, fmt='.2f',
            xticklabels=[f"pred_{c}" for c in class_names],
            yticklabels=[f"true_{c}" for c in class_names],
            # cmap="Blues"
            cmap="rocket_r"
            )
plt.ylabel('Actual')
plt.xlabel('Predicted')
# plot the resulting confusion matrix
plt.savefig("confusion-matrix.png")
# plt.show()
Output:
As you can see, the model is accurate on most of the classes, especially on forest images, where it achieved 100%. However, it drops to 91% for pasture: the model sometimes predicts pasture as permanent crop or as herbaceous vegetation. Most of the confusion is between crop, pasture, and herbaceous vegetation, as they look similar and are mostly green from the satellite's point of view.
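Percentages like these are easiest to read when each row of the confusion matrix is divided by its own total. tf.math.confusion_matrix puts true labels on the rows, so each cell then reads as "fraction of this true class predicted as that class", and the diagonal is per-class recall. A minimal NumPy sketch on made-up counts (row_normalize is a hypothetical helper):

```python
import numpy as np

def row_normalize(cm):
    """Divide each row of a confusion matrix by its total.

    Rows are true classes, so each row then sums to 1 and the diagonal holds
    per-class recall. Rows with no samples stay zero instead of dividing by zero.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

# hypothetical 3-class counts: rows = true class, columns = predicted class
counts = [[9, 1, 0],
          [2, 8, 0],
          [0, 0, 5]]
cmn = row_normalize(counts)
print(cmn.diagonal())  # per-class recall: 0.9, 0.8 and 1.0
```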
Let's show some examples that the model predicted:
def show_predicted_samples():
    plt.figure(figsize=(14, 14))
    for n in range(64):
        ax = plt.subplot(8, 8, n + 1)
        # show the image
        plt.imshow(images[n])
        # and put the corresponding label as title upper to the image
        if predictions[n] == labels[n]:
            # correct prediction
            ax.set_title(class_names[predictions[n]], color="green")
        else:
            # wrong prediction
            ax.set_title(f"{class_names[predictions[n]]}/T:{class_names[labels[n]]}", color="red")
        plt.axis('off')
    plt.savefig("predicted-sample-images.png")

# showing a batch of images along with predictions labels
show_predicted_samples()
Output:
Of these 64 images, the model misclassified only one (the image with the red title). It predicted pasture, whereas the true class is permanent crop.
Alright! That's it for the tutorial. If you want to improve the results further, I highly advise you to explore TensorFlow Hub, where you can find state-of-the-art pre-trained CNN models and feature extractors.
I also suggest you try out different optimizers and increase the number of epochs to see if you can improve the results. You can use TensorBoard to track the accuracy of each change you make. Make sure you include the hyperparameter values in the model name so you can tell the runs apart.
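One way to follow the "include the variables in the model name" advice is to build the name from the hyperparameters themselves. make_run_name is a hypothetical helper, not part of TensorFlow. The resulting string could label both the checkpoint file and a TensorBoard log directory, for example one passed as tf.keras.callbacks.TensorBoard(log_dir=log_dir).

```python
import os

def make_run_name(base, **hparams):
    """Hypothetical helper: append sorted hyperparameter values to a base name."""
    parts = [f"{key}-{value}" for key, value in sorted(hparams.items())]
    return "_".join([base, *parts])

run_name = make_run_name("satellite-classification", optimizer="adam", epochs=5, batch_size=64)
# the same name can label both the checkpoint file and a TensorBoard log folder
model_path = os.path.join("results", run_name + ".h5")
log_dir = os.path.join("logs", run_name)
print(run_name)  # satellite-classification_batch_size-64_epochs-5_optimizer-adam
```

Sorting the keys keeps the name stable no matter in which order the hyperparameters are passed.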
If you want more in-depth information, I encourage you to check the EuroSAT paper, where they achieved 98.57% accuracy with the 13-band version of the dataset (1.93GB). You can also use this version of the dataset by passing "eurosat/all" instead of the standard "eurosat" to the tfds.load() method.
You can get the complete code of this tutorial here.
Happy learning ♥