StalkerMuse

caffe: Check failed: bottom[0]->num_axes() == first_spatial_axis + num_spatial_axes_ (3 vs. 4) bottom num_axes may not change

I modified the FCN net and designed a new net in which I use two ImageData layers as input, and I want the net to produce a picture as output. Here are the train_val.prototxt and the deploy.prototxt.

The original pictures and the labels are both grayscale images of size 224x224. I've trained a caffemodel and use infer.py to run a segmentation with it, but I get this error:

 Check failed: bottom[0]->num_axes() == first_spatial_axis + num_spatial_axes_ (3 vs. 4) bottom num_axes may not change.

Here is infer.py:

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
caffe_root = '/home/zhaimo/'
import sys
sys.path.insert(0, caffe_root + 'caffe-master/python')

import caffe

# Load the grayscale test image; for a single-channel PNG this gives an HxW float32 array
im = Image.open('/home/zhaimo/fcn-master/data/vessel/test/13.png')
in_ = np.array(im, dtype=np.float32)
# RGB-only preprocessing (RGB->BGR, mean subtraction, HxWxC -> CxHxW), left disabled for grayscale input
#in_ = in_[:,:,::-1]
#in_ -= np.array((104.00698793,116.66876762,122.67891434))
#in_ = in_.transpose((2,0,1))

net = caffe.Net('/home/zhaimo/fcn-master/mo/deploy.prototxt', '/home/zhaimo/fcn-master/mo/snapshot/train/_iter_200000.caffemodel', caffe.TEST)
net.blobs['data'].reshape(1, *in_.shape)
net.blobs['data'].data[...] = in_
net.forward()
out = net.blobs['score'].data[0].argmax(axis=0)

# Draw the predicted label map and save it as a picture
plt.imshow(out)
plt.axis('off')
plt.savefig('/home/zhaimo/fcn-master/mo/result/13.png')

How can I solve this problem?

Answers (1)

Shai

Your net expects a 4D input of shape 1x1xHxW: a batch holding a single image with one channel and height and width HxW, i.e. an input with two leading singleton dimensions. What you feed it is a single 2D image: in_ has shape HxW only, so reshape(1, *in_.shape) turns the data blob into 1xHxW, which has only 3 axes. The missing channel dimension is exactly what the error reports: the convolution layer was set up with a 4-axis bottom and now receives a 3-axis one (3 vs. 4).
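The numbers in the error message come from the shape check in Caffe's convolution (and deconvolution) layers, where the spatial axes start right after the channel axis. The sketch below re-implements only that arithmetic to show where "3 vs. 4" comes from; it is not Caffe's code, the helper name conv_bottom_axes_ok is hypothetical, and the defaults assume a 2D convolution with the channel axis at index 1 (Caffe's default).

```python
def conv_bottom_axes_ok(bottom_shape, channel_axis=1, num_spatial_axes=2):
    """Hypothetical helper mimicking the arithmetic behind Caffe's check
    'bottom[0]->num_axes() == first_spatial_axis + num_spatial_axes_'.

    Returns (check_passes, actual_num_axes, expected_num_axes).
    """
    first_spatial_axis = channel_axis + 1  # spatial axes follow the channel axis
    expected = first_spatial_axis + num_spatial_axes
    actual = len(bottom_shape)
    return actual == expected, actual, expected


H, W = 224, 224
in_shape = (H, W)  # shape of np.array(...) for a single-channel 224x224 image

print(conv_bottom_axes_ok((1,) + in_shape))    # blob after reshape(1, *in_.shape)
print(conv_bottom_axes_ok((1, 1) + in_shape))  # blob after reshape(1, 1, *in_.shape)
```

The first call reports (False, 3, 4), matching the failed check, while the second reports (True, 4, 4).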
To solve the problem, add the singleton channel dimension explicitly:

 net.blobs['data'].reshape(1, 1, *in_.shape)
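After this reshape, the existing line net.blobs['data'].data[...] = in_ still works, because NumPy broadcasts the HxW array into the 1x1xHxW blob. The sketch below reproduces that with plain NumPy only: a zero array stands in for the Caffe data blob and a random array stands in for the test image (both are assumptions for illustration). It also shows np.newaxis indexing as an equivalent way to build the 4D array up front.

```python
import numpy as np

H, W = 224, 224
in_ = np.random.rand(H, W).astype(np.float32)  # stand-in for the grayscale test image

# Stand-in for net.blobs['data'].data after reshape(1, 1, *in_.shape)
blob = np.zeros((1, 1) + in_.shape, dtype=np.float32)
blob[...] = in_  # broadcasts (H, W) -> (1, 1, H, W)

# Equivalent: add the batch and channel axes explicitly
batch = in_[np.newaxis, np.newaxis, :, :]

print(blob.shape, batch.shape)
print(np.array_equal(blob, batch))
```

Both arrays have shape (1, 1, 224, 224) and hold the same values.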

As for the KeyError you got: your net has no blob named 'score'. It has 'upscore1' and 'prob', so read the output from one of those:

 out = net.blobs['prob'].data[0].argmax(axis=0)
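To end up with a picture, the label map still has to be written to disk. A minimal sketch, assuming net.blobs['prob'].data[0] has shape CxHxW (the deploy.prototxt is not shown, so the channel count C is an assumption): with two or more classes the argmax over axis 0 gives the class of each pixel, but with a single output channel argmax(axis=0) always returns 0, so the sketch thresholds that channel instead. The helper prob_to_mask and its 0.5 default threshold are hypothetical.

```python
import numpy as np


def prob_to_mask(prob, threshold=0.5):
    """Hypothetical helper: turn one image's (C, H, W) probability map into an (H, W) uint8 label map.

    With C >= 2 each pixel gets its most probable class; with C == 1 an argmax
    would always be 0, so the single channel is thresholded instead.
    """
    if prob.ndim != 3:
        raise ValueError("expected shape (C, H, W), got %r" % (prob.shape,))
    if prob.shape[0] == 1:
        return (prob[0] > threshold).astype(np.uint8)
    return prob.argmax(axis=0).astype(np.uint8)


# Two-class example: class 1 wins only at pixel (0, 1)
two_class = np.array([[[0.9, 0.2], [0.7, 0.6]],
                      [[0.1, 0.8], [0.3, 0.4]]], dtype=np.float32)
print(prob_to_mask(two_class))

# Single-channel example: argmax(axis=0) is all zeros, thresholding is not
one_class = np.array([[[0.9, 0.2], [0.7, 0.6]]], dtype=np.float32)
print(one_class.argmax(axis=0))
print(prob_to_mask(one_class))
```

For a two-class segmentation the mask holds only 0 and 1, which looks black when saved directly; scaling it first, e.g. Image.fromarray(prob_to_mask(net.blobs['prob'].data[0]) * 255).save('/home/zhaimo/fcn-master/mo/result/13.png'), writes a visible grayscale PNG, since PIL's Image.fromarray turns a 2D uint8 array into a grayscale image.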
