user7860075

Reputation: 1

Android Tensorflow IllegalArgumentException Error

I'm working on image recognition with Android Studio and the Android version of TensorFlow. It does not track and recognize continuously; it only recognizes a single image. I already have the graph (.pb) and label (.txt) files and have set up the required settings. But there is a big problem: I keep getting the same dimension error about the image. Here are the error log and my source code.

java.lang.IllegalArgumentException: input must be 4-dimensional[1,1,299,299,3]
    [[Node: ResizeBilinear = ResizeBilinear[T=DT_FLOAT, align_corners=false, _device="/job:localhost/replica:0/task:0/cpu:0"](ExpandDims, ResizeBilinear/size)]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:295)
    at org.tensorflow.Session$Runner.run(Session.java:245)
    at org.tensorflow.contrib.android.TensorFlowInferenceInterface.run(TensorFlowInferenceInterface.java:144)
    at com.example.yuuuuu.tensorTest.TensorFlowImageClassifier.recognizeImage(TensorFlowImageClassifier.java:119)
    at com.example.yuuuuu.tensorTest.MainActivity.runTensor(MainActivity.java:69)
    at com.example.yuuuuu.tensorTest.MainActivity$1.onClick(MainActivity.java:42)
    at android.view.View.performClick(View.java:6205)
    at android.widget.TextView.performClick(TextView.java:11103)
    at android.view.View$PerformClick.run(View.java:23653)
    at android.os.Handler.handleCallback(Handler.java:751)
    at android.os.Handler.dispatchMessage(Handler.java:95)
    at android.os.Looper.loop(Looper.java:154)
    at android.app.ActivityThread.main(ActivityThread.java:6682)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:1520)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1410)

I don't know where the problem is. Looking at the first line, [1,1,299,299,3], I think the two 299s are the image size and one of the 1s is IMAGE_STD, but I don't know what the other 1 and the 3 are. I wrote the code to match the official code in the TensorFlow GitHub repository and only changed a few parts. This is MainActivity:

package com.example.yuuuuu.tensorTest;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import java.util.List;

public class MainActivity extends AppCompatActivity {

    // Model/label assets, node names and preprocessing constants passed to the classifier
    private static final String MODEL_FILE = "file:///android_asset/optimized_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/output_labels.txt";
    private static final String INPUT_NAME = "Cast";
    private static final String OUTPUT_NAME = "final_result";
    private static final int INPUT_SIZE = 299;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;

    private Classifier classifier;
    private TextView textView;
    private ImageView img;
    private Button button;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        textView = (TextView) findViewById(R.id.textView);
        button = (Button) findViewById(R.id.btn);
        img = (ImageView) findViewById(R.id.img);

        // Classify the test image each time the button is tapped
        button.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                runTensor();
            }
        });

        initTensor();
    }

    // Loads the graph and the labels once
    public void initTensor() {
        classifier = TensorFlowImageClassifier.create(
                getAssets(),
                MODEL_FILE,
                LABEL_FILE,
                INPUT_SIZE,
                IMAGE_MEAN,
                IMAGE_STD,
                INPUT_NAME,
                OUTPUT_NAME
        );
    }

    // Decodes the drawable, scales it to INPUT_SIZE x INPUT_SIZE, classifies it and shows the results
    public void runTensor() {
        Bitmap bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test);
        bitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, false);

        img = (ImageView) findViewById(R.id.img);
        img.setImageBitmap(bitmap);

        final List<Classifier.Recognition> results = classifier.recognizeImage(bitmap);
        textView.setText(results.toString());
    }

    @Override
    protected void onDestroy() {
        super.onDestroy();
        classifier.close();
    }
}

This is Classifier, which is the same as the official code:

package com.example.yuuuuu.tensorTest;

import android.graphics.Bitmap;
import android.graphics.RectF;

import java.util.List;

public interface Classifier {

    // One classification result: label index, label text, confidence and optional location
    public class Recognition {
        private final String id;
        private final String title;
        private final Float confidence;
        private RectF location;

        public Recognition(
                final String id, final String title, final Float confidence, final RectF location) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.location = location;
        }

        public String getId() { return id; }
        public String getTitle() { return title; }
        public Float getConfidence() { return confidence; }
        public RectF getLocation() { return location; }
        public void setLocation(RectF location) { this.location = location; }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }

            if (title != null) {
                resultString += title + " ";
            }

            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }

            if (location != null) {
                resultString += location + " ";
            }

            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);
    void enableStatLogging(final boolean debug);
    String getStatString();
    void close();
}

Last is TensorFlowImageClassifier, which is also the same as the official code:

package com.example.yuuuuu.tensorTest;

import static android.os.Trace.beginSection;
import static android.os.Trace.endSection;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Build;
import android.support.annotation.RequiresApi;
import android.util.Log;

import org.tensorflow.Operation;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

public class TensorFlowImageClassifier implements Classifier {
    private static final String TAG = "TensorFlowImageClassifier";

    private static final int MAX_RESULTS = 3;
    private static final float THRESHOLD = 0.1f;

    private String inputName;
    private String outputName;
    private int inputSize;
    private int imageMean;
    private float imageStd;

    private Vector<String> labels = new Vector<String>();
    private int[] intValues;
    private float[] floatValues;
    private float[] outputs;
    private String[] outputNames;

    private boolean logStats = false;
    private TensorFlowInferenceInterface inferenceInterface;
    private TensorFlowImageClassifier() {}

    /*
    assetManager  : used to load the assets
    modelFilename : the .pb file
    labelFilename : the .txt file
    inputSize     : side length of the square input, inputSize * inputSize
    imageMean     : mean of the image values
    imageStd      : standard value (deviation?) of the image values
    inputName     : name of the image input node
    outputName    : name of the output node
     */

    public static Classifier create(
            AssetManager assetManager, String modelFilename, String labelFilename, int inputSize,
            int imageMean, float imageStd, String inputName, String outputName) {
        TensorFlowImageClassifier c = new TensorFlowImageClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read one label per line from the assets folder
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.d(TAG, "reading labels from : " + actualFilename);
        BufferedReader br = null;

        try {
            br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
            String line;
            while ((line = br.readLine()) != null) {
                c.labels.add(line);
            }
            br.close();
        } catch (IOException e) {
            throw new RuntimeException("failed reading labels", e);
        }

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

        // The number of classes is the second dimension of the output node's shape
        final Operation operation = c.inferenceInterface.graphOperation(outputName);
        final int numClasses = (int) operation.output(0).shape().size(1);
        Log.d(TAG, "reading " + c.labels.size() + " labels, size of output layers : " + numClasses);

        c.inputSize = inputSize;
        c.imageMean = imageMean;
        c.imageStd = imageStd;

        // Pre-allocate buffers: one int per pixel, three floats (RGB) per pixel
        c.outputNames = new String[]{outputName};
        c.intValues = new int[inputSize * inputSize];
        c.floatValues = new float[inputSize * inputSize * 3];
        c.outputs = new float[numClasses];

        return c;
    }

    @RequiresApi(api = Build.VERSION_CODES.JELLY_BEAN_MR2)
    @Override
    public List<Recognition> recognizeImage(final Bitmap bitmap) {
        beginSection("recognizeImage");
        beginSection("preprocessBitmap");

        // Unpack the ARGB pixels into normalized RGB floats
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < intValues.length; i++) {
            final int val = intValues[i];
            floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
        }
        endSection();

        // Feed the input as a [1, inputSize, inputSize, 3] tensor
        beginSection("feed");
        inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
        endSection();

        beginSection("run");
        inferenceInterface.run(outputNames, logStats);
        endSection();

        beginSection("fetch");
        inferenceInterface.fetch(outputName, outputs);
        endSection();

        // Keep results above THRESHOLD, ordered by descending confidence
        PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(
                3,
                new Comparator<Recognition>() {
                    @Override
                    public int compare(Recognition lhs, Recognition rhs) {
                        return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                    }
                }
        );

        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(
                        new Recognition("" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
            }
        }

        // Return at most MAX_RESULTS of the best recognitions
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionSize; ++i) {
            recognitions.add(pq.poll());
        }
        endSection();

        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean logStats) { this.logStats = logStats; }
    @Override
    public String getStatString() { return inferenceInterface.getStatString(); }
    @Override
    public void close() { inferenceInterface.close(); }
}
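The preprocessing loop in recognizeImage unpacks each packed ARGB pixel (the int format that Bitmap.getPixels fills in) into three floats, normalizes each with (value - imageMean) / imageStd, and lays them out as height × width × 3, matching the 1, inputSize, inputSize, 3 dimensions passed to feed(). Below is a minimal plain-Java sketch of just that step, with no Android Bitmap involved (a hand-made int array stands in for getPixels). The class and method names are hypothetical. It also illustrates that IMAGE_MEAN and IMAGE_STD only change the values, never the number of dimensions.

```java
/** Hypothetical standalone version of the pixel normalization in recognizeImage. */
public class PixelNormalizer {

    /**
     * Converts packed ARGB pixels into an interleaved RGB float array of
     * length argbPixels.length * 3; alpha is ignored.
     */
    static float[] normalize(int[] argbPixels, int imageMean, float imageStd) {
        float[] out = new float[argbPixels.length * 3];
        for (int i = 0; i < argbPixels.length; i++) {
            final int val = argbPixels[i];
            out[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd; // red
            out[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;  // green
            out[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;         // blue
        }
        return out;
    }

    public static void main(String[] args) {
        int inputSize = 2; // tiny image for illustration; the question uses 299
        int[] pixels = new int[inputSize * inputSize];
        pixels[0] = 0xFFFF8000; // opaque pixel with R=255, G=128, B=0
        float[] values = normalize(pixels, 117, 1f);
        // 2 * 2 pixels * 3 channels = 12 floats, i.e. dims 1 x 2 x 2 x 3
        System.out.println(values.length);
        System.out.println(values[0] + " " + values[1] + " " + values[2]);
    }
}
```

With the question's IMAGE_MEAN = 117 and IMAGE_STD = 1, main prints 12 and then 138.0 11.0 -117.0; the array length depends only on the pixel count.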

If you know how to fix this code, please tell me.

Upvotes: 0

Views: 599

Answers (2)

Dan J

Reputation: 25673

java.lang.IllegalArgumentException: input must be 4-dimensional[1,1,299,299,3]

The error message explains the problem: the input reaching ResizeBilinear is 5-dimensional instead of 4-dimensional, i.e. it should probably look something like [1,299,299,3] (batch, height, width, channels) rather than [1,1,299,299,3].
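The shape arithmetic can be sketched in plain Java with no TensorFlow dependency. This is a hypothetical model, assuming (as the ResizeBilinear(ExpandDims, ...) inputs in the log suggest) that the fed node is followed by an ExpandDims step that prepends a size-1 dimension. The four dimensions given to feed() (1, 299, 299, 3: batch, height, width, RGB channels) then become five, which ResizeBilinear rejects. IMAGE_STD plays no part in the shape. The class and method names are made up for illustration.

```java
import java.util.Arrays;

/** Hypothetical model of the tensor shapes involved; no TensorFlow calls are made. */
public class ShapeDemo {

    /** Dimensions passed to feed(): batch, height, width, channels. */
    static long[] feedShape(int inputSize) {
        return new long[] {1, inputSize, inputSize, 3};
    }

    /** Models ExpandDims on axis 0: prepends a dimension of size 1. */
    static long[] expandDims0(long[] shape) {
        long[] out = new long[shape.length + 1];
        out[0] = 1;
        System.arraycopy(shape, 0, out, 1, shape.length);
        return out;
    }

    /** Models ResizeBilinear's rank check: input must be [batch, height, width, channels]. */
    static void checkResizeBilinearInput(long[] shape) {
        if (shape.length != 4) {
            throw new IllegalArgumentException(
                    "input must be 4-dimensional" + Arrays.toString(shape).replace(" ", ""));
        }
    }

    public static void main(String[] args) {
        long[] fed = feedShape(299);
        checkResizeBilinearInput(fed); // 4-D, so no exception
        long[] expanded = expandDims0(fed);
        System.out.println(Arrays.toString(expanded));
        try {
            checkResizeBilinearInput(expanded);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Running main prints [1, 1, 299, 299, 3] and then the same message as in the log, input must be 4-dimensional[1,1,299,299,3].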

It's hard to tell from your question which code changes you actually made. If you could put your changes into a single Git commit, it might be easier for somebody to identify which change is causing the problem.

You could try viewing your TensorFlow model in TensorBoard to inspect the input and output nodes and check that they match the values you have configured:
https://medium.com/@daj/how-to-inspect-a-pre-trained-tensorflow-model-5fd2ee79ced0

Upvotes: 1

Marcos Vasconcelos

Reputation: 18276

Well, when I was working with native libraries, I noticed that they usually can't read files from the assets folder by themselves; you need to copy the file to an accessible storage path and pass its absolute path to the library.

Your error may be from loading the resources.
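If you want to try the copy-to-storage approach, the core is a plain stream copy. Below is a minimal Java sketch, not a confirmed fix for this particular error. The helper name copyToFile is hypothetical; on Android the InputStream would typically come from getAssets().open(...) and the target directory from getFilesDir(), and the returned absolute path is what you would hand to the library.

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/** Hypothetical helper: copies a stream (e.g. an opened asset) to a regular file. */
public class AssetCopier {

    /** Copies everything from {@code in} into {@code target}, closes both, and returns the absolute path. */
    static String copyToFile(InputStream in, File target) throws IOException {
        try (InputStream src = in; OutputStream out = new FileOutputStream(target)) {
            byte[] buffer = new byte[8192];
            int read;
            while ((read = src.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
        }
        return target.getAbsolutePath();
    }
}
```

Inside an Activity this could be called as copyToFile(getAssets().open("optimized_graph.pb"), new File(getFilesDir(), "optimized_graph.pb")).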

Upvotes: 0
