Skip to content

Writing custom data generator #6

@PrashastiSachan

Description

@PrashastiSachan

Hi All,

Thank you for sharing your work. The code works perfectly, and I was able to train the model and run inference.
However, I have one simple question.
I'm not sure this is the right place for this query, but it would be a great help if you could respond to it.

I customized the data generator script, but now the model fails to converge. Any idea what the reason could be?

Here is my code:
`

            from flowUtils import read_flow
            import tensorflow as tf
            from tensorflow.keras.utils import Sequence
            import numpy as np
            import cv2
          class DataGenerator(Sequence):
              """Keras ``Sequence`` that yields ``(imPair, flow)`` batches.

              Each sample is built from two image files (read with ``cv2.imread``)
              and one flow file (read with ``read_flow``). Images are scaled to
              [0, 1]; the two frames are stacked along the channel axis into
              ``imPair`` and the flow is returned alongside it. Both are returned
              as float32 tensors with a leading batch axis.

              NOTE(review): ``cv2.imread`` returns channels in BGR order. If the
              model, its pretrained weights, or the reference data pipeline expect
              RGB, this mismatch could hurt convergence — confirm.

              Args:
                  im1PairsList: paths of the first frame of each sample.
                  im2PairsList: paths of the second frame of each sample.
                  flowList: paths of the flow file of each sample.
                  batch_size: number of samples per batch.
                  crop_size: (height, width) of the random crop applied when
                      ``isTrain`` is true.
                  shuffle: reshuffle the sample order at the end of every epoch.
                  isTrain: when true, apply one shared random crop to both frames
                      and the flow; otherwise return full-size samples.
              """

              def __init__(self, im1PairsList, im2PairsList, flowList, batch_size=6, crop_size=(256, 448), shuffle=True, isTrain=False) -> None:
                  self.im1PairsList = im1PairsList
                  self.im2PairsList = im2PairsList
                  self.flowList = flowList
                  self.batch_size = batch_size
                  # Tuple default: a list default would be one object shared by
                  # every instance.
                  self.crop_size = crop_size
                  self.shuffle = shuffle
                  self.isTrain = isTrain
                  self.on_epoch_end()

              def __len__(self):
                  """Number of full batches per epoch (a trailing partial batch is dropped)."""
                  return len(self.im1PairsList) // self.batch_size

              def __getitem__(self, index):
                  """Return batch number ``index`` as an ``(imPair, flow)`` tuple of tensors."""
                  start = index * self.batch_size
                  indexes = self.indexes[start:start + self.batch_size]
                  img1PairsTemp = [self.im1PairsList[k] for k in indexes]
                  img2PairsTemp = [self.im2PairsList[k] for k in indexes]
                  flowPairsTemp = [self.flowList[k] for k in indexes]
                  return self._data_generation(img1PairsTemp, img2PairsTemp, flowPairsTemp)

              def normalizeImages(self, img):
                  """Scale 8-bit pixel values to [0, 1] as float32."""
                  return np.asarray(img / 255., dtype=np.float32)

              def _read_image(self, path):
                  """Read an image with OpenCV, raising if the file cannot be read."""
                  img = cv2.imread(path)
                  # cv2.imread reports failure by returning None instead of raising.
                  if img is None:
                      raise FileNotFoundError(f"Could not read image: {path}")
                  return img

              def tf_image_crop(self, img_concat):
                  """Randomly crop frame1|frame2|flow together and split them apart.

                  The crop is taken from the channel-concatenated tensor so that
                  both frames and the flow share exactly the same window.
                  """
                  channels = img_concat.shape[-1]  # 3 + 3 + 2 = 8 for RGB/RGB/UV
                  im_cropped = tf.image.random_crop(img_concat, [self.crop_size[0], self.crop_size[1], channels])
                  im1 = im_cropped[:, :, :3]
                  im2 = im_cropped[:, :, 3:6]
                  flo = im_cropped[:, :, 6:]
                  return im1, im2, flo

              def _data_generation(self, img1PairsTemp, img2PairsTemp, flowPairsTemp):
                  """Load, normalize and (when training) crop one batch of samples."""
                  imgPair = []
                  flowPair = []
                  for img1Path, img2Path, flowPath in zip(img1PairsTemp, img2PairsTemp, flowPairsTemp):
                      norm_im1 = self.normalizeImages(self._read_image(img1Path))
                      norm_im2 = self.normalizeImages(self._read_image(img2Path))
                      # Cast so tf.concat with the float32 images cannot fail on a
                      # float64 flow.
                      flo = np.asarray(read_flow(flowPath), dtype=np.float32)

                      if self.isTrain:
                          im_concat = tf.concat([norm_im1, norm_im2, flo], axis=2)
                          im1, im2, flo_crop = self.tf_image_crop(im_concat)
                          imgPair.append(np.asarray(tf.concat([im1, im2], axis=2)))
                          flowPair.append(np.asarray(flo_crop))
                      else:
                          imgPair.append(np.concatenate([norm_im1, norm_im2], axis=2))
                          flowPair.append(flo)

                  imgBatch = np.stack(imgPair, axis=0)
                  flowBatch = np.stack(flowPair, axis=0)
                  return tf.convert_to_tensor(imgBatch, dtype=tf.float32), tf.convert_to_tensor(flowBatch, dtype=tf.float32)

              def on_epoch_end(self):
                  """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
                  self.indexes = np.arange(len(self.im1PairsList))
                  if self.shuffle:
                      np.random.shuffle(self.indexes)

`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions