Android上的Tensorflow,输入和输出名称

时间:2019-03-31 20:40:48

标签: android python tensorflow

我已经成功地训练了一个数字分类器,现在正尝试在 Android 中使用它。我以前从未用过 TensorFlow,所以照着一系列教程做了下来,现在到了需要在 Android 应用中使用自己生成的 .pb 文件的阶段。加载模型时需要提供 inputName 和 outputName,但我不知道它们应该填什么。从 Python 脚本来看,我认为 outputName 应该是 final_result,其余的就不清楚了。下面是我在 Android 中的代码:
    // Registers a classifier built from the bundled model. The arguments follow the order of
    // TensorFlowClassifier.create(assetManager, name, modelPath, labelFile, inputSize,
    // inputName, outputName, feedKeepProb).
    mClassifiers.add(
         TensorFlowClassifier.create(
              context.getAssets(), // assetManager: source of the model and label assets
              "?????",  // name: display label returned by name(); any string, NOT a graph node name
               "clasifier.pb", // modelPath: .pb graph asset loaded by TensorFlowInferenceInterface
               "labels.txt", // labelFile: one class label per line, presumably in model output order (TODO confirm)
                100, // inputSize: stored by create() but never read by recognize()
                "????", // inputName: graph node that recognize() feeds the pixels into; list the graph's operations to find it
                "???", // outputName: graph node whose scores recognize() fetches; presumably "final_result" for graphs from retrain.py (TODO confirm)
                true) // feedKeepProb: if true, recognize() also feeds "keep_prob" = 1; presumably must be false when the graph has no such node
            );

import android.content.res.AssetManager;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;



/**
 * {@link Classifier} backed by a TensorFlow graph loaded from the app's assets through
 * {@link TensorFlowInferenceInterface}.
 *
 * <p>{@link #recognize} feeds the pixels into the configured input node and runs the graph up
 * to the configured output node. It then reports the best-scoring label whose score exceeds
 * {@link #THRESHOLD}. Each instance reuses a single output buffer across calls, so one instance
 * must not be used from several threads at once.
 */
public class TensorFlowClassifier implements Classifier {

    /** A prediction is only reported when its score is strictly greater than this value. */
    private static final float THRESHOLD = 0.1f;

    /** Name of the dropout placeholder fed with 1.0 when {@code feedKeepProb} is set. */
    private static final String KEEP_PROB_NAME = "keep_prob";

    /**
     * Label that is never reported by {@link #recognize}.
     * NOTE(review): this skips the class labelled "0" entirely; confirm this is intended.
     */
    private static final String IGNORED_LABEL = "0";

    private TensorFlowInferenceInterface tfHelper;

    private String name;
    private String inputName;
    private String outputName;
    private int inputSize;
    private boolean feedKeepProb;

    /** Class labels; index i names the i-th score produced by the output node. */
    private List<String> labels;
    /** Reused buffer receiving one score per label on every {@link #recognize} call. */
    private float[] output;
    private String[] outputNames;

    /**
     * Reads the label asset, one label per line, in file order.
     *
     * @param am asset manager to open the file from
     * @param fileName asset path of the label file
     * @return the labels, possibly empty if the file has no lines
     * @throws IOException if the asset cannot be opened or read. The error is propagated instead
     *     of being swallowed, so a bad label file is reported where it happens.
     */
    private static List<String> readLabels(AssetManager am, String fileName) throws IOException {
        List<String> labels = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(am.open(fileName), StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        }
        return labels;
    }

    /**
     * Builds a classifier from a model asset and its label file.
     *
     * @param assetManager asset manager the model and labels are loaded from
     * @param name display name returned by {@link #name()}; not a graph node name
     * @param modelPath asset path of the .pb graph
     * @param labelFile asset path of the label file (one label per line)
     * @param inputSize stored on the instance; not read by any code in this class
     * @param inputName graph node that {@link #recognize} feeds the pixels into
     * @param outputName graph node whose scores {@link #recognize} fetches
     * @param feedKeepProb whether to also feed {@code keep_prob = 1} on every call
     * @return the configured classifier
     * @throws IOException if the label file cannot be read or contains no labels
     */
    public static TensorFlowClassifier create(AssetManager assetManager,
                                              String name,
                                              String modelPath,
                                              String labelFile,
                                              int inputSize,
                                              String inputName,
                                              String outputName,
                                              boolean feedKeepProb) throws IOException {
        TensorFlowClassifier c = new TensorFlowClassifier();

        c.name = name;
        c.inputName = inputName;
        c.outputName = outputName;
        c.outputNames = new String[] { outputName };
        c.inputSize = inputSize;
        c.feedKeepProb = feedKeepProb;

        c.labels = readLabels(assetManager, labelFile);
        if (c.labels.isEmpty()) {
            throw new IOException("No labels found in label file: " + labelFile);
        }

        c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);

        // Size the score buffer from the label count rather than a hard-coded 10, so that
        // recognize() never indexes past the end of the label list.
        c.output = new float[c.labels.size()];

        return c;
    }

    @Override
    public String name() {
        return name;
    }

    /**
     * Runs the graph on one image and returns the best label above {@link #THRESHOLD}.
     *
     * @param pixels input values, fed as a tensor of shape (1, width, height, 1)
     * @param width image width
     * @param height image height
     * @return the best-scoring accepted label; an empty {@link Classification} if no score
     *     exceeds the threshold
     */
    @Override
    public Classification recognize(final float[] pixels, final int width, final int height) {
        // NOTE(review): TensorFlow's default layout is NHWC (batch, height, width, channels).
        // This order is equivalent for square images; confirm it before feeding non-square ones.
        tfHelper.feed(inputName, pixels, 1, width, height, 1);

        // Disable dropout at inference time.
        if (feedKeepProb) {
            tfHelper.feed(KEEP_PROB_NAME, new float[] { 1 });
        }

        tfHelper.run(outputNames);
        tfHelper.fetch(outputName, output);

        // Keep the highest score that clears the threshold, skipping the ignored label.
        Classification ans = new Classification();
        for (int i = 0; i < output.length; ++i) {
            String label = labels.get(i);
            float score = output[i];
            if (!label.equals(IGNORED_LABEL) && score > THRESHOLD && score > ans.getConf()) {
                ans.update(score, label);
            }
        }

        return ans;
    }
}

Python 脚本无法直接贴在这里,可以在此处找到:https://github.com/MicrocontrollersAndMore/TensorFlow_Tut_2_Classification_Walk-through/blob/master/retrain.py

1 个答案:

答案 0 :(得分:0)

        TensorFlowInferenceInterface tensorflow = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);

        // Print every operation in the loaded graph, one per line, together with its type.
        // Input nodes are typically of type "Placeholder". The output node is the one named by
        // the training script (presumably "final_result" for retrain.py; TODO confirm).
        Iterator<Operation> operationIterator = tensorflow.graph().operations();
        while (operationIterator.hasNext()) {
            Operation operation = operationIterator.next();
            System.out.println(operation.name() + " (" + operation.type() + ")");
        }

加载模型文件后运行这段代码,即可列出图中所有操作(节点)的名称,从中找到输入和输出节点。希望对您有所帮助!