I have a Conv2D layer:
l0 = tf.keras.layers.Conv2D(1, 3, activation=None, input_shape=(36,36,3))
I would like to find out the exact values in the filter/kernel matrix used (not just the number of them). How can I access kernel matrix values?
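In TF2's tf.keras, a layer object like l0 does not create its kernel until it is built, which normally happens the first time it is called on a tensor. Below is a minimal sketch, assuming TensorFlow 2.x. It builds the layer explicitly with build(), where the None is a placeholder batch dimension, and reads the values back as NumPy arrays with get_weights(). input_shape is left out here because build() supplies the shape.

```python
import tensorflow as tf

# Same configuration as l0: 1 filter, 3x3 kernel, 3 input channels.
l0 = tf.keras.layers.Conv2D(1, 3, activation=None)

# Create the weights explicitly; None stands for the batch dimension.
l0.build((None, 36, 36, 3))

kernel, bias = l0.get_weights()  # NumPy arrays: [kernel, bias]
print(kernel.shape)  # (3, 3, 3, 1): (kernel_h, kernel_w, in_channels, filters)
print(kernel[:, :, 0, 0])  # the 3x3 slice applied to input channel 0
```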
import tensorflow as tf
import numpy as np
I created a numpy array:
x_core = np.array([[1,0,0,1],
[0,0,0,0],
[0,0,0,0],
[1,0,0,1]],dtype=float)
Cast it to a tensor of shape (1,4,4,1):
x = tf.expand_dims(tf.expand_dims(tf.convert_to_tensor(x_core),axis=0),axis=3)
Apply a Conv2D layer with a 2x2 kernel and strides=(2,2) to it. The output will then be a 2 by 2 matrix: its top left value will equal the top left value of the kernel matrix, its top right value will equal the top right value of the kernel matrix, and so on. (The particular zeros and ones in x_core achieve this. The bias is initialized to zero, so it does not change the result.)
y = tf.keras.layers.Conv2D(1, 2, strides=(2,2), activation=None, input_shape=x.shape[1:])(x)
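The stride trick can be checked without TensorFlow. The sketch below uses a hypothetical helper, conv2d_valid. It performs the single-channel, unpadded cross-correlation that Conv2D computes with its default padding='valid', with no kernel flip. An arbitrary stand-in kernel k takes the place of the random one.

```python
import numpy as np

def conv2d_valid(x, k, stride):
    """Single-channel, unpadded cross-correlation (what Conv2D computes)."""
    kh, kw = k.shape
    out_h = (x.shape[0] - kh) // stride + 1
    out_w = (x.shape[1] - kw) // stride + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Each output cell sees one non-overlapping 2x2 patch of x.
            patch = x[i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[i, j] = np.sum(patch * k)
    return out

x_core = np.array([[1, 0, 0, 1],
                   [0, 0, 0, 0],
                   [0, 0, 0, 0],
                   [1, 0, 0, 1]], dtype=float)
k = np.array([[0.1, -0.2],
              [0.3, 0.4]])  # arbitrary stand-in kernel

y = conv2d_valid(x_core, k, stride=2)
print(np.allclose(y, k))  # True: each patch has a single 1 at the matching position
```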
However, y changes if I rerun the code, i.e. the filter is not constant, which suggests that the kernel matrix is drawn from a distribution.
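The changing values are expected. Conv2D's default kernel_initializer is 'glorot_uniform', so each new layer draws a fresh random kernel, while the bias defaults to zeros. Rather than fixing the randomness, the following sketch reads back whichever kernel was drawn and checks that it matches y. It assumes TensorFlow 2.x. The input is built with NumPy indexing instead of tf.expand_dims, which gives the same (1, 4, 4, 1) shape.

```python
import numpy as np
import tensorflow as tf

x_core = np.array([[1, 0, 0, 1],
                   [0, 0, 0, 0],
                   [0, 0, 0, 0],
                   [1, 0, 0, 1]], dtype=np.float32)
x = x_core[np.newaxis, :, :, np.newaxis]  # shape (1, 4, 4, 1)

conv = tf.keras.layers.Conv2D(1, 2, strides=(2, 2), activation=None)
y = conv(x)  # first call builds the layer and draws a random kernel

kernel = conv.get_weights()[0]  # shape (2, 2, 1, 1)
# The 2x2 output should reproduce the 2x2 kernel exactly (bias is still zero).
print(np.allclose(np.asarray(y)[0, :, :, 0], kernel[:, :, 0, 0]))  # True
```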
A similar but different question, How to get CNN kernel values in Tensorflow, has an answer that only works in TensorFlow 1. Problems with it:
gr = tf.get_default_graph()
gives AttributeError: module 'tensorflow' has no attribute 'get_default_graph'
If I replace get_default_graph with Graph (as I believe that is the newer equivalent) and add name="conv1" to my layer definition:
conv_layer_1 = tf.keras.layers.Conv2D(1, 2, strides=(2,2), activation=None, input_shape=x.shape[1:],name="conv1")
and then run
conv1_kernel_val = tf.Graph().get_tensor_by_name('conv1/kernel:0').eval()
as suggested, I get:
KeyError: "The name 'conv1/kernel:0' refers to a Tensor which does not exist. The operation, 'conv1/kernel', does not exist in the graph."
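Some context on the KeyError. tf.Graph() constructs a brand-new, empty graph rather than returning an existing one, so no 'conv1/kernel:0' tensor can be found in it. TF2 also runs eagerly by default, and the TF1 function is only available there as tf.compat.v1.get_default_graph(). The following is a minimal sketch of both points, assuming TensorFlow 2.x. The exact weight names printed differ between Keras versions, so no output is shown for them.

```python
import tensorflow as tf

# tf.Graph() builds a new, empty graph; it is not a handle to an existing one.
g = tf.Graph()
print(g.get_operations())  # []

# In eager TF2, the kernel lives on the layer object, so read it from there.
conv_layer_1 = tf.keras.layers.Conv2D(1, 2, strides=(2, 2), name="conv1")
conv_layer_1.build((None, 4, 4, 1))
for w in conv_layer_1.weights:
    print(w.name, tuple(w.shape))  # kernel (2, 2, 1, 1), then bias (1,)
```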
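import tensorflow as tf
input_shape = (4, 28, 28, 3)  # batch of 4 RGB images, 28x28 pixels
x = tf.random.normal(input_shape)
model = tf.keras.layers.Conv2D(2, 3, activation='relu', input_shape=input_shape[1:])
y = model(x)  # calling the layer builds it, creating the kernel and bias variables
print(model.kernel)  # tf.Variable of shape (3, 3, 3, 2): (kernel_h, kernel_w, in_channels, filters)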
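To get the values themselves rather than the tf.Variable wrapper, you can copy the kernel and bias out as NumPy arrays once the layer has been called. Below is a short sketch along the lines of the answer above, assuming TensorFlow 2.x. The layer is named conv here, and input_shape is dropped because calling the layer already fixes the shape. The kernel is laid out as (kernel_height, kernel_width, in_channels, filters).

```python
import tensorflow as tf

input_shape = (4, 28, 28, 3)
x = tf.random.normal(input_shape)
conv = tf.keras.layers.Conv2D(2, 3, activation='relu')
y = conv(x)  # builds the layer

kernel = conv.kernel.numpy()  # plain NumPy copy of the kernel
bias = conv.bias.numpy()
print(kernel.shape)  # (3, 3, 3, 2)
print(bias.shape)    # (2,)

# The 3x3 matrix that filter 0 applies to input channel 0.
print(kernel[:, :, 0, 0])
```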