dhruv Gangwani
dhruv Gangwani

Reputation: 1

the data augmentation doesn't work properly

according to me, the given code should create 5700 images, which is 10 times the number of original images, instead, when I check the shape of it, it only gives 1140, which is twice the number of images. What am I missing in this code?

# Custom augmentation functions
def random_translate(image, max_translate):
    shift = np.random.uniform(-max_translate, max_translate, 3)
    translated_image = scipy.ndimage.shift(image, shift, mode='nearest')
    return translated_image

def random_rotate(image, max_angle):
    angles = np.random.uniform(-max_angle, max_angle, 3)
    rotated_image = scipy.ndimage.rotate(image, angles[0], axes=(1, 2), reshape=False)
    rotated_image = scipy.ndimage.rotate(rotated_image, angles[1], axes=(0, 2), reshape=False)
    rotated_image = scipy.ndimage.rotate(rotated_image, angles[2], axes=(0, 1), reshape=False)
    return rotated_image

def random_flip(image):
    if random.random() > 0.5:
        image = np.flip(image, axis=0)
    if random.random() > 0.5:
        image = np.flip(image, axis=1)
    if random.random() > 0.5:
        image = np.flip(image, axis=2)
    return image

def random_noise(image, noise_level=0.01):
    noise = np.random.normal(0, noise_level, image.shape)
    noisy_image = image + noise
    return noisy_image

def random_brightness(image, max_delta=0.2):
    delta = np.random.uniform(-max_delta, max_delta)
    bright_image = np.clip(image + delta, 0, 1)  # Clip to maintain valid pixel range
    return bright_image

def random_contrast(image, lower=0.8, upper=1.2):
    factor = np.random.uniform(lower, upper)
    mean = np.mean(image, axis=(0, 1, 2), keepdims=True)
    contrast_image = np.clip((image - mean) * factor + mean, 0, 1)
    return contrast_image

def random_scale(image, min_scale=0.9, max_scale=1.1):
    scale = np.random.uniform(min_scale, max_scale)
    height, width = image.shape[:2]
    scaled_image = scipy.ndimage.zoom(image, (scale, scale, 1), order=1)
    if scale < 1.0:
        pad_height = (height - scaled_image.shape[0]) // 2
        pad_width = (width - scaled_image.shape[1]) // 2
        scaled_image = np.pad(scaled_image, ((pad_height, pad_height), (pad_width, pad_width), (0, 0)), mode='constant')
    else:
        start_height = (scaled_image.shape[0] - height) // 2
        start_width = (scaled_image.shape[1] - width) // 2
        scaled_image = scaled_image[start_height:start_height + height, start_width:start_width + width]
    return scaled_image

def random_shear(image, max_shear=0.2):
    shear = np.random.uniform(-max_shear, max_shear)
    afine_tf = tf.keras.preprocessing.image.random_shear(shear)
    shear_image = tf.keras.preprocessing.image.apply_affine_transform(image, shear=afine_tf)
    return shear_image

def elastic_transform(image, alpha=1000, sigma=30):
    random_state = np.random.RandomState(None)
    shape = image.shape
    dx = scipy.ndimage.gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dy = scipy.ndimage.gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dz = scipy.ndimage.gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij')
    indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)), np.reshape(z + dz, (-1, 1))
    distored_image = scipy.ndimage.map_coordinates(image, indices, order=1, mode='reflect')
    return distored_image.reshape(image.shape)


def custom_data_generator(images, labels, batch_size,
                          max_translate=10, max_angle=15, noise_level=0.01,
                          max_delta=0.2, lower_contrast=0.8, upper_contrast=1.2,
                          min_scale=0.9, max_scale=1.1, max_shear=0.2, alpha=1000, sigma=30):
    num_images = images.shape[0]
    while True:
        batch_indices = np.random.choice(num_images, batch_size)
        batch_images = []
        batch_labels = []
        for idx in batch_indices:
            image = images[idx]
            label = labels[idx]

            # Apply all augmentation functions
            image = random_translate(image, max_translate)
            image = random_rotate(image, max_angle)
            image = random_flip(image)
            image = random_noise(image, noise_level)
            image = random_brightness(image, max_delta)
            image = random_contrast(image, lower_contrast, upper_contrast)
            image = random_scale(image, min_scale, max_scale)
            image = tf.keras.preprocessing.image.random_shear(image, max_shear)
            image = elastic_transform(image, alpha, sigma)

            batch_images.append(image)
            batch_labels.append(label)

        batch_images = np.array(batch_images)
        batch_labels = np.array(batch_labels)
        yield batch_images, batch_labels


# Directory to save the images
save_dir = 'mri_augmented_images'
os.makedirs(save_dir, exist_ok=True)

for i in range(mri_resized.shape[0]):
    for j in range(mri_resized.shape[3]):  # Loop over channels
        img_array = mri_resized[i, :, :, j]
        img = Image.fromarray((img_array * 255).astype('uint8'))  # Scale to [0, 255] and convert to uint8
        img.save(os.path.join(save_dir, f'image_{i}_channel_{j}.png'))

# Using the custom data generator to augment images and retain labels
target_num_images = 5700
batch_size = 3  # Define your batch size
num_batches_needed = (target_num_images + batch_size - 1) // batch_size
data_gen = custom_data_generator(mri_resized, labels, batch_size=batch_size)

# Generate and save augmented images with labels
for batch_images, batch_labels in data_gen:  # Directly iterate over the generator
    for j, img_array in enumerate(batch_images):
        img = Image.fromarray((img_array[:, :, 0] * 255).astype('uint8'))  # Save only the first channel
        img_index = i * batch_size + j
        img.save(os.path.join(save_dir, f'aug_image_{img_index}_label_{batch_labels[j]}.png'))

print(f"Generated and saved {target_num_images} aug

mented images.")

Upvotes: 0

Views: 46

Answers (0)

Related Questions