Czorio

Reputation: 115

Apply a function to each element of a 3D, 2 channel Keras tensor

I have a model that processes two 3D input images of the same size, A and B. Its output is fed into a more classical function, in an attempt to improve the performance of that function. To train the model properly, I need to apply the function to the result of each run. The function itself takes 2 values, which correspond to the values in A and B at the same coordinate p. The result is stored at point p in a 3D image, C, of the same size as A and B. Classical implementations would loop over all the coordinates and apply the function to each pair. Unfortunately, this approach does not work for training a Keras model, as the output of the function has to feed back into the weights of the previous layers.

Input -(A, B)-> Model -(A', B')-> Function(A'[p], B'[p]) -(C[p])-> Result

I have attempted to write a custom Keras layer for this. This layer accepts a 4D tensor (channel, z, y, x) and should return a tensor with shape (1, z, y, x).

Currently this looks like:

# imports

def function(x: [float, float]) -> float:
    # a -> x[0], b -> x[1]
    # Calculate
    return c

class CustomLayer(Layer):
    # ... __init__ and build
    def call(self, inputs, **kwargs):
        # All samples, channel n ([::][n])
        # We stack the tensors in such a way because map_fn() maps the top most axis to the function.
        # This way a tensor of shape (n_voxels, 2) is created and the values are delivered in pairs to the function
        map_input = K.stack([K.flatten(inputs[::][0]), K.flatten(inputs[::][1])], axis=1)
        result = K.map_fn(lambda x: function(x), map_input)
        result = K.reshape(result, K.constant([-1, 1, inputs.shape[2], inputs.shape[3], inputs.shape[4]], dtype=tf.int32))
        return result

Unfortunately, this method has severely slowed down the training. Whereas the model without the custom layer at the end took around 45 minutes to train per epoch, the model with the custom layer takes about 120 hours per epoch.

The model and the function can each perform the required task on their own, but with a significant error. I wanted to see if I could combine the two for better results.


The real-life use in my case is the decomposition of materials from Dual Energy CT. You can calculate the fraction of materials for each voxel by assuming 3 known materials mat1, mat2, mat3 and an unknown sample sample.

With some calculations you can then decompose this unknown sample into fractions of each known material f1, f2, f3 (f1 + f2 + f3 == 1.0). We are only interested in f3, so the calculations for the other fractions have been omitted.

Actual code:

"""
DECT Decomposition Layer
"""

import numpy as np
import tensorflow as tf

from keras import backend as K
from keras.layers import Layer
from keras.constraints import MinMaxNorm


def _intersect(line1_start, line1_end, line2_start, line2_end):
    """
    Find the intersection point between 2 lines
    """
    # https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection

    # Tensorflow's tensors need a little more lines to unpack
    x1 = line1_start[0]
    y1 = line1_start[1]
    x2 = line1_end[0]
    y2 = line1_end[1]
    x3 = line2_start[0]
    y3 = line2_start[1]
    x4 = line2_end[0]
    y4 = line2_end[1]

    px = (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)
    px /= (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    py = (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)
    py /= (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)

    return K.stack([px, py])


def _decomp(sample, mat1, mat2, mat3):
    """
    Decomposition of a sample into 1 fraction of material 3
    """
    # Calculate the sample lines' ends
    sample3 = sample + (mat2 - mat3)

    # Calculate the intersection points between the sample lines and triangle sides
    intersect3 = _intersect(sample, sample3, mat1, mat2)

    # Find out how far along the sample line the intersection is
    f3 = tf.norm(sample - intersect3) / tf.norm(sample - sample3)

    return f3


class DectDecompoLayer(Layer):
    def __init__(self, mat1, mat2, mat3, **kwargs):
        self.mat1 = K.constant(mat1)
        self.mat2 = K.constant(mat2)
        self.mat3 = K.constant(mat3)

        super(DectDecompoLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        super(DectDecompoLayer, self).build(input_shape)

    def call(self, inputs, **kwargs):
        map_input = K.stack([K.flatten(inputs[::][0]), K.flatten(inputs[::][1])], axis=1)
        result = K.map_fn(lambda x: _decomp(x, self.mat1, self.mat2, self.mat3), map_input)
        result = K.reshape(result, K.constant([-1, 1, inputs.shape[2], inputs.shape[3], inputs.shape[4]], dtype=tf.int32))
        return result

    def compute_output_shape(self, input_shape):
        return input_shape[0], 1, input_shape[2], input_shape[3], input_shape[4]

Upvotes: 1

Views: 1182

Answers (1)

Daniel Möller

Reputation: 86630

Ok. The first thing that is very strange is that you are taking "samples", not "channels".

The command inputs[::] returns exactly inputs, and inputs[::][0] is equal to inputs[0].

So, you are training only two samples, no matter how big your batch size is.
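
To see the difference concretely, here is a small sketch using NumPy as a stand-in for the Keras tensor (the batch size and volume size are made up for illustration; the indexing rules shown are the same for tensors):

```python
import numpy as np

# Hypothetical batch: 4 samples, 2 channels, 3x3x3 voxels.
inputs = np.arange(4 * 2 * 3 * 3 * 3, dtype=np.float32).reshape(4, 2, 3, 3, 3)

# inputs[::] is a full slice along the first (batch) axis, so it changes nothing.
same = inputs[::]

# Indexing that result with [0] therefore picks sample 0, not channel 0.
first_sample = inputs[::][0]   # shape (2, 3, 3, 3)

# Channel 0 of every sample needs an explicit slice on axis 1.
channel_a = inputs[:, 0]       # shape (4, 3, 3, 3)
channel_b = inputs[:, 1]       # shape (4, 3, 3, 3)
```

So `inputs[:, 0]` and `inputs[:, 1]` are the slices that correspond to A and B.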

That said, all you need is something like:

  • Assuming inputs with shape (batch, 2, size, size, size)
  • Assuming matN with shape (1, 2, size, size, size), exactly, or (batch, 2, size, size, size)

def call(self, inputs, **kwargs): #shape (batch, 2, size, size, size)
    sample3 = inputs + (self.mat2 - self.mat3) 

    #all shapes (batch, size, size, size)
    x1 = inputs[:,0]
    y1 = inputs[:,1]
    x2 = sample3[:,0]
    y2 = sample3[:,1]

    #all shapes (1, size, size, size) or (batch, size, size, size)
    x3 = self.mat1[:,0]
    y3 = self.mat1[:,1]
    x4 = self.mat2[:,0]
    y4 = self.mat2[:,1]


    #all shapes (batch, size, size, size)
    px = (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)
    px /= (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    py = (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)
    py /= (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)

    #proceed with the rest
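
As a rough check of the idea, the remaining steps can be sketched in NumPy and compared against a per-voxel version that mirrors `_decomp` and `_intersect` from the question. This is only a sketch under assumptions: the function names `decomp_scalar` and `decomp_vectorized` are made up here, the materials are passed as plain 2-vectors and reshaped to (1, 2, 1, 1, 1) so they broadcast, and NumPy stands in for the Keras backend.

```python
import numpy as np

def decomp_scalar(sample, mat1, mat2, mat3):
    """Per-voxel reference, mirroring _decomp/_intersect from the question."""
    sample3 = sample + (mat2 - mat3)
    x1, y1 = sample
    x2, y2 = sample3
    x3, y3 = mat1
    x4, y4 = mat2
    den = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    px = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / den
    py = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / den
    intersect = np.array([px, py])
    return np.linalg.norm(sample - intersect) / np.linalg.norm(sample - sample3)

def decomp_vectorized(inputs, mat1, mat2, mat3):
    """Whole-volume version; inputs has shape (batch, 2, z, y, x)."""
    # Reshape the 2-vectors so they broadcast over batch and spatial axes.
    m1, m2, m3 = (m.reshape(1, 2, 1, 1, 1) for m in (mat1, mat2, mat3))
    sample3 = inputs + (m2 - m3)
    x1, y1 = inputs[:, 0], inputs[:, 1]
    x2, y2 = sample3[:, 0], sample3[:, 1]
    x3, y3 = m1[:, 0], m1[:, 1]
    x4, y4 = m2[:, 0], m2[:, 1]
    den = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    px = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / den
    py = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / den
    intersect = np.stack([px, py], axis=1)  # (batch, 2, z, y, x)
    # Euclidean length over the channel axis, kept as a size-1 axis.
    num = np.linalg.norm(inputs - intersect, axis=1, keepdims=True)
    length = np.linalg.norm(inputs - sample3, axis=1, keepdims=True)
    return num / length  # (batch, 1, z, y, x)

# Made-up materials and a small random volume, purely for the comparison.
mat1 = np.array([0.0, 0.0])
mat2 = np.array([1.0, 0.0])
mat3 = np.array([0.0, 1.0])
rng = np.random.default_rng(0)
inputs = rng.uniform(0.05, 0.95, size=(2, 2, 3, 3, 3))

fast = decomp_vectorized(inputs, mat1, mat2, mat3)

# Classical loop over every voxel, for reference only.
slow = np.empty_like(fast)
for idx in np.ndindex(2, 3, 3, 3):
    b, z, y, x = idx
    slow[b, 0, z, y, x] = decomp_scalar(inputs[b, :, z, y, x], mat1, mat2, mat3)
```

Both versions should agree voxel for voxel. Inside the layer, the two norms could be written with the backend as `K.sqrt(K.sum(K.square(...), axis=1, keepdims=True))`, which keeps everything as whole-tensor operations and avoids `map_fn` entirely.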

Warning, you may have division by zero for parallel lines.

I recommend some kind of K.switch like:

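denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
# K.switch returns a new tensor, so the result has to be assigned.
# Note: K.sign(0) is 0, so an exactly-zero denominator is left unchanged here.
denominator = K.switch(
    K.less(K.abs(denominator), K.epsilon()),
    denominator + K.sign(denominator) * K.epsilon(),
    denominator)
px /= denominator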
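
One possible variation on this guard, sketched in NumPy rather than with the Keras backend: near-zero entries are replaced by plus or minus epsilon, and an exactly-zero entry falls back to a positive sign. The helper name `safe_denominator` is made up for this example, and `EPS` stands in for `K.epsilon()`, whose default value is 1e-7.

```python
import numpy as np

EPS = 1e-7  # stand-in for K.epsilon()

def safe_denominator(den, eps=EPS):
    """Replace near-zero entries by +/-eps, keeping the sign where one exists."""
    # np.sign(0.0) is 0.0, so use an explicit fallback sign of +1 for zeros.
    sign = np.where(den < 0, -1.0, 1.0)
    return np.where(np.abs(den) < eps, sign * eps, den)

den = np.array([0.0, 1e-9, -1e-9, 0.5, -2.0])
guarded = safe_denominator(den)

# Division is now finite everywhere, and ordinary values pass through untouched.
ratios = 1.0 / guarded
```

The same pattern should carry over to the layer by building the condition and the replacement as tensors and passing them to `K.switch`.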

Upvotes: 1
