1
\$\begingroup\$

I am a PhD student working on a machine learning project with binary classification and RESNET architecture in TensorFlow. I believe I have done everything correctly but I am looking for some validation that the code is correct as I have no one to check my work. I am working in Google Colab with two programs: 1 to split the dataset and 1 to run the model. The first code is splitdataset.ipynb and the second code is classifyimages.ipynb. I know their are easier and other ways to do this but this is how I implemented it. Some of the more important things to check over would be how I implemented the brightness augmentations and how I split the dataset but ideally a whole code validation would be nice.

splitdataset.ipynb

import cv2import osimport shutilimport random%cd /content/drive/MyDrive/static_CTC_classification!rm -r data_set!mkdir data_set!mkdir data_set/training!mkdir data_set/validation!mkdir data_set/testing!mkdir data_set/training/DU145!mkdir data_set/training/PC3!mkdir data_set/validation/DU145!mkdir data_set/validation/PC3!mkdir data_set/testing/DU145!mkdir data_set/testing/PC3du145 = []pc3 = []for image in os.listdir("full_ds/DU145"):  du145.append("full_ds/DU145/" + image)for image in os.listdir("full_ds/PC3"):  pc3.append("full_ds/PC3/" + image)images = du145 + pc3random.shuffle(images)num_images = len(images)train_num = int(0.8 * num_images)val_num = int(0.1 * num_images)print("train_num: ",train_num)print("val_num: ",val_num)train = images[0:train_num]val = images[train_num:train_num+val_num]test = images[train_num+val_num:]for image in train:  if(image[8] == "D"):    im = image[14:]    shutil.copyfile(image, "split_ds/training/DU145/"+im)  if(image[8] == "P"):    im = image[12:]    shutil.copyfile(image, "split_ds/training/PC3/"+im)for image in val:  if(image[8] == "D"):    im = image[14:]    shutil.copyfile(image, "split_ds/validation/DU145/"+im)  if(image[8] == "P"):    im = image[12:]    shutil.copyfile(image, "split_ds/validation/PC3/"+im)for image in test:  if(image[8] == "D"):    im = image[14:]    shutil.copyfile(image, "split_ds/testing/DU145/"+im)  if(image[8] == "P"):    im = image[12:]    shutil.copyfile(image, "split_ds/testing/PC3/"+im)

classifyimages.ipynb

import tensorflow as tffrom tensorflow.keras.callbacks import ModelCheckpointimport tensorflow_datasets as tfdsimport numpy as npimport matplotlib.pyplot as pltimage_size = (224,224)batch_size = 32train_ds = tf.keras.preprocessing.image_dataset_from_directory(    "/content/drive/MyDrive/static_CTC_classification/split_ds/training",    seed=1337,    color_mode='rgb',    image_size=image_size,    batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(    "/content/drive/MyDrive/static_CTC_classification/split_ds/validation",    seed=1337,    color_mode='rgb',    image_size=image_size,    batch_size=batch_size)test_ds = tf.keras.preprocessing.image_dataset_from_directory(    "/content/drive/MyDrive/static_CTC_classification/split_ds/testing",    seed=1337,    color_mode='rgb',    image_size=image_size,    batch_size=batch_size)# from website: https://towardsdatascience.com/writing-a-custom-data-augmentation-layer-in-keras-2b53e048a98class RandomColorDistortion(tf.keras.layers.Layer):  def __init__(self, contrast_range=[0.5, 1.5],                 brightness_delta=[-0.2, 0.2], **kwargs):      super(RandomColorDistortion, self).__init__(**kwargs)      self.contrast_range = contrast_range      self.brightness_delta = brightness_delta  def call(self, images, training=None):          if not training:              return images                    contrast = np.random.uniform(              self.contrast_range[0], self.contrast_range[1])          brightness = np.random.uniform(              self.brightness_delta[0], self.brightness_delta[1])                    images = tf.image.adjust_contrast(images, contrast)          images = tf.image.adjust_brightness(images, brightness)          images = tf.clip_by_value(images, 0, 1)          return imagesaugment_and_normalize = tf.keras.Sequential([  RandomColorDistortion(contrast_range=[0.5,1.5], brightness_delta=[-0.15, 0.15]),                         tf.keras.layers.RandomFlip("horizontal"),  tf.keras.layers.RandomRotation(0.1),  tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)], name="augment_and_normalize")def make_model(input_shape, num_classes):    input = tf.keras.Input(shape=input_shape)    # Entry block    x = augment_and_normalize(input)    feature_extractor = tf.keras.applications.resnet.ResNet50(input_shape=(224, 224, 3),                                               include_top=False,                                               weights='imagenet')(x)    x = tf.keras.layers.GlobalAveragePooling2D()(feature_extractor)    x = tf.keras.layers.Flatten()(x)    x = tf.keras.layers.Dense(1024, activation="relu")(x)    x = tf.keras.layers.Dropout(0.5)(x)    x = tf.keras.layers.Dense(512, activation="relu")(x)    x = tf.keras.layers.Dropout(0.5)(x)    output = tf.keras.layers.Dense(1, activation="sigmoid", name="classification")(x)    return tf.keras.Model(input, output)model = make_model(input_shape=image_size + (3,), num_classes=2)tf.keras.utils.plot_model(model)epochs = 250model.compile(    optimizer=tf.keras.optimizers.Adam(),    loss="binary_crossentropy",    metrics=["accuracy"],)history = model.fit(x=train_ds, epochs=epochs, validation_data=val_ds,    callbacks = [ModelCheckpoint(filepath="weights.{epoch:02d}.ckpt", monitor='val_accuracy',    verbose=0, save_best_only=True,save_weights_only=True, mode='auto', save_freq='epoch',options=None)])!cp /content/weights.05.ckpt.data-00000-of-00001 /content/drive/MyDrive/static_CTC_classification/!cp /content/weights.05.ckpt.index /content/drive/MyDrive/static_CTC_classification/# Then after saving these weights to my google drive I can load them later with the following commandmodel.load_weights("/content/drive/MyDrive/static_CTC_classification/weights.05.ckpt")
200_success's user avatar
200_success
146k22 gold badges191 silver badges481 bronze badges
askedApr 1, 2022 at 4:05
karl-gardner's user avatar
\$\endgroup\$

0

You mustlog in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.