Satellite Image Classification using TensorFlow in Python

Learn how to fine-tune the current state-of-the-art EfficientNetV2 model to perform image classification on satellite data (EuroSAT) using TensorFlow in Python.
  · 10 min read · Updated May 2024 · Machine Learning · Computer Vision


Open In Colab

Satellite image classification is crucial for many applications in agriculture, environmental monitoring, urban planning, and more. Applications such as crop monitoring and land and forest cover mapping are increasingly being adopted by governments, companies, and research labs for real-world use.

In this tutorial, you will learn how to build a satellite image classifier using the TensorFlow framework in Python.

We will be using the EuroSAT dataset, which is based on Sentinel-2 satellite images covering 13 spectral bands. It consists of 27,000 labeled samples of 10 different classes: annual crop, permanent crop, forest, herbaceous vegetation, highway, industrial, pasture, residential, river, and sea lake.

The EuroSAT dataset comes in two varieties:

  • rgb (default): contains only the optical R, G, and B frequency bands, encoded as JPEG images.
  • all: contains all 13 bands in the original value range.
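The 13-band variety can be loaded the same way by changing the dataset name. Here is a minimal sketch, assuming you want to inspect it before committing to the much larger download; printing the feature dictionary shows the exact feature keys and band tensor shape:

import tensorflow_datasets as tfds

# load the 13-band variant instead of the RGB default
# (the "eurosat/all" config name is mentioned later in this tutorial)
ds_all, info_all = tfds.load("eurosat/all", split="train", with_info=True)
# inspect the feature keys and the band tensor shape
print(info_all.features)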

Related: Image Captioning using PyTorch and Transformers in Python.

Getting Started

To get started, let's install TensorFlow and some other helper tools:

$ pip install tensorflow tensorflow_addons tensorflow_datasets tensorflow_hub numpy matplotlib seaborn scikit-learn

We use tensorflow_addons to calculate the F1 score during the training of the model.
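As a quick illustration of what this metric computes, here is a small self-contained sketch of tfa.metrics.F1Score on dummy one-hot labels (all the numbers are made up for the example):

import tensorflow as tf
import tensorflow_addons as tfa

# dummy one-hot labels and softmax-like predictions, purely for illustration
y_true = tf.constant([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
y_pred = tf.constant([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.3, 0.4, 0.3]])
f1 = tfa.metrics.F1Score(num_classes=3, average="macro")
f1.update_state(y_true, y_pred)
print(f1.result().numpy())  # macro-averaged F1 over the 3 classes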

We will use the EfficientNetV2 model, which is the current state of the art on most image classification tasks. We use tensorflow_hub to load this pre-trained CNN model for fine-tuning.

Preparing the Dataset

Importing the necessary libraries:

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
import tensorflow_addons as tfa

Downloading and loading the dataset:

# load the whole dataset, for data info
all_ds   = tfds.load("eurosat", with_info=True)
# load training, testing & validation sets, splitting by 60%, 20% and 20% respectively
train_ds = tfds.load("eurosat", split="train[:60%]")
test_ds  = tfds.load("eurosat", split="train[60%:80%]")
valid_ds = tfds.load("eurosat", split="train[80%:]")

We split our dataset into 60% training, 20% validation during training, and 20% for testing. The code below sets some variables we'll use later:

# the class names
class_names = all_ds[1].features["label"].names
# total number of classes (10)
num_classes = len(class_names)
num_examples = all_ds[1].splits["train"].num_examples

We grab the list of class names from the all_ds dataset, as it was loaded with with_info set to True; we also get the total number of samples from it.

Next, I'm going to make a bar plot to see the number of samples in each class:

# make a plot for number of samples on each class
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
labels, counts = np.unique(np.fromiter(all_ds[0]["train"].map(lambda x: x["label"]), np.int32),
                           return_counts=True)
plt.ylabel('Counts')
plt.xlabel('Labels')
sns.barplot(x=[class_names[l] for l in labels], y=counts, ax=ax)
for i, x_ in enumerate(labels):
  ax.text(x_ - 0.2, counts[i] + 5, counts[i])
# set the title
ax.set_title("Bar Plot showing Number of Samples on Each Class")
# save the image
# plt.savefig("class_samples.png")

Output:

Bar plot showing the number of samples on each class in the EuroSAT dataset

Half of the classes have 3,000 samples each, most of the others have 2,500, while pasture has only 2,000 samples.

Now let's take our training and validation sets and prepare them before training:

def prepare_for_training(ds, cache=True, batch_size=64, shuffle_buffer_size=1000):
  if cache:
    if isinstance(cache, str):
      ds = ds.cache(cache)
    else:
      ds = ds.cache()
  ds = ds.map(lambda d: (d["image"], tf.one_hot(d["label"], num_classes)))
  # shuffle the dataset
  ds = ds.shuffle(buffer_size=shuffle_buffer_size)
  # Repeat forever
  ds = ds.repeat()
  # split to batches
  ds = ds.batch(batch_size)
  # `prefetch` lets the dataset fetch batches in the background while the model
  # is training.
  ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  return ds

Here is what this function does:

  • cache(): This method saves the preprocessed dataset into a local cache file (or into memory). This way, it's only preprocessed the very first time (in the first epoch during training); see the sketch after this list.
  • map(): We map our dataset so each sample will be a tuple of an image and its corresponding label one-hot encoded with tf.one_hot().
  • shuffle(): To shuffle the dataset so the samples are in random order.
  • repeat(): Every time we iterate over the dataset, it'll repeatedly generate samples for us; this will help us during the training.
  • batch(): We batch our dataset into batch_size samples (64 by default) per training step.
  • prefetch(): This will enable us to fetch batches in the background while the model is training.
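As an aside, since cache accepts a file path, you can cache the preprocessed dataset on disk instead of in memory. A sketch, assuming you'd rather trade RAM for disk (the cache file name is arbitrary); this would replace, not complement, the in-memory call used in the next step:

# cache the preprocessed training set on disk instead of in memory
# ("eurosat_train.tfcache" is an arbitrary file name for this sketch)
train_ds = prepare_for_training(train_ds, cache="eurosat_train.tfcache", batch_size=64)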

Let's run it for the training and validation sets:

batch_size = 64
# preprocess training & validation sets
train_ds = prepare_for_training(train_ds, batch_size=batch_size)
valid_ds = prepare_for_training(valid_ds, batch_size=batch_size)

Let's see what our data looks like:

# validating shapes
for el in valid_ds.take(1):
  print(el[0].shape, el[1].shape)
for el in train_ds.take(1):
  print(el[0].shape, el[1].shape)

Output:

(64, 64, 64, 3) (64, 10)
(64, 64, 64, 3) (64, 10)

Fantastic! Both the training and validation batches have the same shape: the batch size is 64, and the image shape is (64, 64, 3). The targets have the shape (64, 10), as it's 64 samples with 10 classes one-hot encoded.

Let's visualize the first batch from the training dataset:

# take the first batch of the training set
batch = next(iter(train_ds))

def show_batch(batch):
  plt.figure(figsize=(16, 16))
  for n in range(min(32, batch_size)):
      ax = plt.subplot(batch_size//8, 8, n + 1)
      # show the image
      plt.imshow(batch[0][n])
      # and put the corresponding label as title above the image
      plt.title(class_names[tf.argmax(batch[1][n].numpy())])
      plt.axis('off')
  # save the figure once, after all the images are drawn
  plt.savefig("sample-images.png")

# showing a batch of images along with labels
show_batch(batch)

Output:

Sample images

Building the Model

Right. Now that we have our data prepared for training, let's build our model. First, we download EfficientNetV2 and load it as a hub.KerasLayer:

model_url = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/feature_vector/2"
# download & load the layer as a feature vector
keras_layer = hub.KerasLayer(model_url, output_shape=[1280], trainable=True)

We pass model_url to hub.KerasLayer so we get EfficientNetV2 as an image feature extractor, and we set trainable to True so we're adjusting the pre-trained weights a bit for our dataset (i.e., fine-tuning).
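For comparison, if you only wanted a frozen feature extractor (no fine-tuning), you would set trainable to False. A minimal sketch; this usually trains faster but tends to score lower than fine-tuning:

# frozen feature extractor: the pre-trained weights stay fixed during training
frozen_layer = hub.KerasLayer(model_url, output_shape=[1280], trainable=False)
# a dummy batch with one 64x64 RGB image maps to a (1, 1280) feature vector
features = frozen_layer(tf.zeros([1, 64, 64, 3]))
print(features.shape)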

Building the model:

m = tf.keras.Sequential([
  keras_layer,
  tf.keras.layers.Dense(num_classes, activation="softmax")
])
# build the model with input image shape as (64, 64, 3)
m.build([None, 64, 64, 3])
m.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy", tfa.metrics.F1Score(num_classes)]
)
m.summary()

We use Sequential(): the first layer is the pre-trained CNN model, and we add a fully connected layer with the size of the number of classes as an output layer.

Finally, the model is built and compiled with categorical cross-entropy, the Adam optimizer, and accuracy and F1 score as metrics. Output:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 keras_layer (KerasLayer)    (None, 1280)              117746848

 dense (Dense)               (None, 10)                12810

=================================================================
Total params: 117,759,658
Trainable params: 117,247,082
Non-trainable params: 512,576
_________________________________________________________________

Fine-tuning the Model

Now that we have the data and model ready, let's begin fine-tuning our model:

model_name = "satellite-classification"
model_path = os.path.join("results", model_name + ".h5")
# make sure the results folder exists before the checkpoint callback writes to it
os.makedirs("results", exist_ok=True)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=True, verbose=1)
# set the training & validation steps since we're using .repeat() on our dataset
# number of training steps
n_training_steps   = int(num_examples * 0.6) // batch_size
# number of validation steps
n_validation_steps = int(num_examples * 0.2) // batch_size
# train the model
history = m.fit(
    train_ds, validation_data=valid_ds,
    steps_per_epoch=n_training_steps,
    validation_steps=n_validation_steps,
    verbose=1, epochs=5,
    callbacks=[model_checkpoint]
)

The training will take several minutes, depending on your GPU. Here is the output:

Epoch 1/5
253/253 [==============================] - ETA: 0s - loss: 0.3780 - accuracy: 0.8859 - f1_score: 0.8832
Epoch 00001: val_loss improved from inf to 0.16415, saving model to results/satellite-classification.h5
253/253 [==============================] - 158s 438ms/step - loss: 0.3780 - accuracy: 0.8859 - f1_score: 0.8832 - val_loss: 0.1641 - val_accuracy: 0.9513 - val_f1_score: 0.9501
Epoch 2/5
253/253 [==============================] - ETA: 0s - loss: 0.1531 - accuracy: 0.9536 - f1_score: 0.9525
Epoch 00002: val_loss improved from 0.16415 to 0.12853, saving model to results/satellite-classification.h5
253/253 [==============================] - 106s 421ms/step - loss: 0.1531 - accuracy: 0.9536 - f1_score: 0.9525 - val_loss: 0.1285 - val_accuracy: 0.9568 - val_f1_score: 0.9559
Epoch 3/5
253/253 [==============================] - ETA: 0s - loss: 0.1092 - accuracy: 0.9660 - f1_score: 0.9654
Epoch 00003: val_loss improved from 0.12853 to 0.12095, saving model to results/satellite-classification.h5
253/253 [==============================] - 107s 424ms/step - loss: 0.1092 - accuracy: 0.9660 - f1_score: 0.9654 - val_loss: 0.1210 - val_accuracy: 0.9619 - val_f1_score: 0.9605
Epoch 4/5
253/253 [==============================] - ETA: 0s - loss: 0.1042 - accuracy: 0.9692 - f1_score: 0.9687
Epoch 00004: val_loss did not improve from 0.12095
253/253 [==============================] - 100s 394ms/step - loss: 0.1042 - accuracy: 0.9692 - f1_score: 0.9687 - val_loss: 0.1435 - val_accuracy: 0.9565 - val_f1_score: 0.9572
Epoch 5/5
253/253 [==============================] - ETA: 0s - loss: 0.1003 - accuracy: 0.9700 - f1_score: 0.9695
Epoch 00005: val_loss improved from 0.12095 to 0.09841, saving model to results/satellite-classification.h5
253/253 [==============================] - 107s 423ms/step - loss: 0.1003 - accuracy: 0.9700 - f1_score: 0.9695 - val_loss: 0.0984 - val_accuracy: 0.9702 - val_f1_score: 0.9687

As you can see, the model improved to about 97% accuracy on the validation set on epoch 5. You can increase the number of epochs to see whether it can improve further.
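If you do train for longer, it's worth guarding against overfitting. Here is a sketch that adds tf.keras.callbacks.EarlyStopping next to the checkpoint callback (the patience value and epoch count are arbitrary choices for this example):

# stop training once val_loss stops improving for a few consecutive epochs
early_stop = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                              restore_best_weights=True)
history = m.fit(
    train_ds, validation_data=valid_ds,
    steps_per_epoch=n_training_steps,
    validation_steps=n_validation_steps,
    verbose=1, epochs=20,  # more epochs; EarlyStopping may cut training short
    callbacks=[model_checkpoint, early_stop]
)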

Model Evaluation

Up until now, we've only been validating on the validation set during training. This section uses our model to predict satellite images that it has never seen before. Loading the best weights:

# load the best weights
m.load_weights(model_path)

Extracting all the testing images and labels individually from test_ds:

# number of testing samples (20% of the whole dataset)
n_testing_steps = int(all_ds[1].splits["train"].num_examples * 0.2)
# get all testing images as NumPy array
images = np.array([ d["image"] for d in test_ds.take(n_testing_steps) ])
print("images.shape:", images.shape)
# get all testing labels as NumPy array
labels = np.array([ d["label"] for d in test_ds.take(n_testing_steps) ])
print("labels.shape:", labels.shape)

Output:

images.shape: (5400, 64, 64, 3)
labels.shape: (5400,)

As expected, we have 5,400 images and labels. Let's use the model to predict these images and then compare the predictions with the true labels:

# feed the images to get predictions
predictions = m.predict(images)
# perform argmax to get class index
predictions = np.argmax(predictions, axis=1)
print("predictions.shape:", predictions.shape)

Output:

predictions.shape: (5400,)
Let's compute the accuracy and the macro-averaged F1 score on the test set:

from sklearn.metrics import f1_score

accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(labels, predictions)
print("Accuracy:", accuracy.result().numpy())
print("F1 Score:", f1_score(labels, predictions, average="macro"))

Output:

Accuracy: 0.9677778
F1 Score: 0.9655686619720163

That's good accuracy! Let's draw the confusion matrix for all the classes:

# compute the confusion matrix
cmn = tf.math.confusion_matrix(labels, predictions).numpy()
# normalize the matrix so each row (true class) sums to 1
cmn = cmn.astype('float') / cmn.sum(axis=1)[:, np.newaxis]
# make a plot for the confusion matrix
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(cmn, annot=True, fmt='.2f',
            xticklabels=[f"pred_{c}" for c in class_names],
            yticklabels=[f"true_{c}" for c in class_names],
            # cmap="Blues"
            cmap="rocket_r"
            )
plt.ylabel('Actual')
plt.xlabel('Predicted')
# save the resulting confusion matrix
plt.savefig("confusion-matrix.png")
# plt.show()

Output:

Confusion Matrix

As you can see, the model is accurate in most of the classes, especially on forest images, where it achieved 100%. However, it's down to 91% for pasture, and the model sometimes predicts pasture as permanent crop, and also as herbaceous vegetation. Most of the confusion is between crop, pasture, and herbaceous vegetation, as they all look similar and, most of the time, green from the satellite.

Let's show some examples that the model predicted:

def show_predicted_samples():
  plt.figure(figsize=(14, 14))
  for n in range(64):
      ax = plt.subplot(8, 8, n + 1)
      # show the image
      plt.imshow(images[n])
      # and put the corresponding label as title above the image
      if predictions[n] == labels[n]:
        # correct prediction
        ax.set_title(class_names[predictions[n]], color="green")
      else:
        # wrong prediction, show predicted/true labels
        ax.set_title(f"{class_names[predictions[n]]}/T:{class_names[labels[n]]}", color="red")
      plt.axis('off')
  # save the figure once, after all the images are drawn
  plt.savefig("predicted-sample-images.png")

# showing a batch of images along with predicted labels
show_predicted_samples()

Output:

Example samples inferred from the model

Of all 64 images, only one (the red label in the above image) failed to predict the actual class: it was predicted as pasture where it should be permanent crop.

Final Thoughts

Alright! That's it for the tutorial. If you want further improvement, I highly advise you to explore TensorFlow Hub, where you'll find state-of-the-art pre-trained CNN models and feature extractors.

I also suggest you try out different optimizers and increase the number of epochs to see if you can improve the results. You can use TensorBoard to track the accuracy of each change you make. Make sure you include the hyperparameters you changed in the model name.
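For example, here is a sketch of wiring up the TensorBoard callback with a run name that encodes the settings you changed (the naming scheme is just an illustration):

# encode the hyperparameters you're experimenting with in the run name
run_name = f"effnetv2-adam-bs{batch_size}-epochs5"
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=os.path.join("logs", run_name))
# pass it along with the checkpoint callback:
#   m.fit(..., callbacks=[model_checkpoint, tensorboard])
# then launch TensorBoard with: tensorboard --logdir logs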

If you want more in-depth information, I encourage you to check the EuroSAT paper, where the authors achieved 98.57% accuracy with the 13-band version of the dataset (1.93GB). You can use this version of the dataset by passing "eurosat/all" instead of the standard "eurosat" to the tfds.load() method, as sketched earlier in this tutorial.

You can get the complete code of this tutorial here.

Learn also: Skin Cancer Detection using TensorFlow in Python

Happy learning ♥

Open In Colab
