TensorFlow 多类别训练和预测

时间:2019-10-01 14:07:13

标签: tensorflow

以下代码(有效)训练了一个模型,以识别猫并根据所选图片做出预测。 (代码为TensorFlowJS,但问题通常是TensorFlow)
到目前为止,它只能预测一个类别(“猫”),因此汽车或狗的图片也会被判定为有 80% 的概率是猫。

问题:
如何添加其他类(如“狗”)?
它看起来像这样(摘要):model.fit([img1,img2,img3],[label1,label2,label3] ...)吗?

我不明白:
标签和训练集之间是什么关系。

以下是代码(请暂时忽略“预测”部分):

<head>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.7"> </script>
    <script src="https://unpkg.com/@tensorflow-models/mobilenet"></script>
</head>
<body>
    <div class="container mt-5">
        <div class="row">
            <input id ="image-selector" class="form-control border-0" type="file"/>
        </div>
        <div class="row">
            <div class="col">
                <h2>Prediction</h2>
                <ol id="prediction-list"></ol>
            </div>
        </div>
        <div class="row">
            <div class="col-12">
                <h2 class="ml-3">Image</h2>
                <canvas id="canvas" width="400" height="300" style="border:1px solid #000000;"></canvas>
            </div>
        </div>
    </div>
    <div  id="training-images">
        <img width="400" height="300" class="train-image cat" src="training-images/cat.jpg" />
        <img width="400" height="300" class="train-image cat" src="training-images/cat2.jpeg" />
        <img width="400" height="300" class="train-image cat" src="training-images/cat3.jpeg" />
        <img width="400" height="300" class="train-image cat" src="training-images/cat4.jpeg" />

        <img width="400" height="300" class="train-image dog" src="training-images/dog.jpeg" />
        <img width="400" height="300" class="train-image dog" src="training-images/dog2.jpeg" />
        <img width="400" height="300" class="train-image dog" src="training-images/dog3.jpeg" />
        <img width="400" height="300" class="train-image dog" src="training-images/dog4.jpeg" />
    </div>
</body>

<script>
    // Name handed to preprocessImage() as its second argument.
    const modelType = "mobilenet";
    // Empty sequential model; its layers are added inside the window "load" handler.
    const model = tf.sequential();
    // Label names that `ys` is built from (one one-hot row per entry); only 'cat' is listed.
    const label = ['cat'];
    // Shared state: ys = one-hot label tensor, setLabel = de-duplicated label names,
    // input / canvas / context = DOM handles used by the prediction UI.
    var ys, setLabel, input, canvas, context;
    input = document.getElementById("image-selector");
    canvas = document.getElementById("canvas");
    context = canvas.getContext('2d');

    //-------------------------- Training: --------------------------------
    // Once the page has loaded: build the label tensor, assemble and compile the
    // model, convert the training images to tensors and start training.
    window.addEventListener('load', (event) => {
        // Labels: de-duplicate the names, map every entry of `label` to its index in
        // that list, and one-hot encode the indices with depth 10.
        // `ys` therefore has exactly one row per entry of `label`.
        setLabel = Array.from(new Set(label));
        ys = tf.oneHot(tf.tensor1d(label.map((a) => setLabel.findIndex(e => e === a)), 'int32'), 10);
        console.log('ys:::'+ys);

        // Prepare model : one strided convolution, two max-pooling layers, then a
        // flattened dense head. The 10 softmax units match the one-hot depth above.
        model.add(tf.layers.conv2d({
            inputShape: [224, 224 , 3],
            kernelSize: 5,
            filters: 8,
            strides: 2,
            activation: 'relu',
            kernelInitializer: 'VarianceScaling'
        }));
        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
        model.add(tf.layers.flatten({}));
        model.add(tf.layers.dense({units: 64, activation: 'relu'}));
        model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
        model.compile({
            loss: 'meanSquaredError',
            optimizer : 'sgd'
        });

        // Prepare training images: only elements carrying the "cat" class are read.
        // NOTE(review): the loop asks for 40 elements, but the markup on this page
        // declares four "cat" images; indices past the last one yield undefined.
        var images = [];
        for(var i = 0; i < 40; i++) {
            let img = preprocessImage(document.getElementsByClassName("cat")[i], modelType);
            // NOTE(review): tf.reshape is documented as (x, shape); the third argument
            // 'resize' presumably has no effect — confirm.
            images.push(tf.reshape(img, [1, 224, 224, 3],'resize'));
        }
        console.log("processed images : ");
        console.log(images);
        // Training runs asynchronously; the returned promise is not awaited here.
        trainModel(images);
    });

    // Fits the model on each entry of `images` in turn (100 epochs per entry, always
    // against the full `ys` tensor) and logs the model's prediction for that entry.
    // NOTE(review): each entry is presumably a single-image batch while `ys` has one
    // row per label; model.fit needs as many input samples as target samples —
    // confirm that the two line up.
    async function trainModel(images) {
        for(var i = 0; i < images.length; i++) {
            await model.fit(images[i], ys, {epochs: 100, batchSize: 32}).then((loss) => {
            const t = model.predict(images[i]);
            console.log('Prediction:::'+t);
            // NOTE(review): `pred` is never declared, so in a non-strict script this
            // assignment creates a global variable.
            pred = t.argMax(1).dataSync(); // get the class of highest probability
            // Translate the winning class indices back into label names.
            const labelsPred = Array.from(pred).map(e => setLabel[e]);
            console.log('labelsPred:::'+labelsPred);
            }).catch((e) => {
                // Failures are only logged; the loop then continues with the next entry.
                console.log(e.message);
            })
        }
        console.log("Training done!");
    }

    //-------------------------- Predict: --------------------------------
    // When the user picks a file: read it as a data URL, decode it into an image,
    // draw that image on the preview canvas and run the model on it.
    input.addEventListener("change", function() {
        const file = this.files[0];
        // The file dialog can be dismissed without a selection; nothing to read then.
        if (!file) {
            return;
        }

        const reader = new FileReader();
        // "load" fires only after a successful read, so reader.result is a data URL here
        // ("loadend" would also fire after a failed read, with a null result).
        reader.addEventListener("load", function() {
            const src_image = new Image();
            src_image.onload = function() {
                // Resize the canvas to the picture so it is drawn uncropped.
                canvas.height = src_image.height;
                canvas.width = src_image.width;
                context.drawImage(src_image, 0, 0);
                // runPrediction is async; surface its failures instead of dropping them.
                runPrediction(src_image).catch(function(e) {
                    console.log(e.message);
                });
            };
            src_image.src = reader.result;
        });
        reader.addEventListener("error", function() {
            console.log(reader.error ? reader.error.message : "Could not read the selected file.");
        });
        reader.readAsDataURL(file);
    });

    /**
     * Runs the model on one image and shows the most likely class, with its
     * probability, in the #prediction-list element.
     *
     * @param {*} imageData - pixel source forwarded to preprocessImage(); the
     *     file-input handler passes the decoded Image element.
     * @returns {Promise<void>} settles once the list has been updated.
     */
    async function runPrediction(imageData){
        // preprocessImage() already returns a batched tensor, so no extra reshape is needed.
        const inputTensor = preprocessImage(imageData, "mobilenet");
        let output;
        let prediction;
        try {
            output = model.predict(inputTensor);
            prediction = await output.data();
        } finally {
            // tfjs tensors are not garbage collected; release them once the values are copied out.
            inputTensor.dispose();
            if (output) {
                output.dispose();
            }
        }
        console.log('prediction:::'+ prediction);

        // Output index i corresponds to setLabel[i]; fall back to the bare index
        // when the model has more outputs than there are known label names.
        const names = Array.isArray(setLabel) ? setLabel : [];
        const topPredictions = Array.from(prediction)
        .map(function(p,i){
            return {
                probability: p,
                className: names[i] !== undefined ? names[i] : `class ${i}`
            };
        }).sort(function(a,b){
            return b.probability-a.probability;
        }).slice(0,1); // keep only the single best match

        $("#prediction-list").empty();
        topPredictions.forEach(function(p){
            $("#prediction-list").append(`<li>${p.className}:${p.probability.toFixed(6)}</li>`);
        });
    }

    //-------------------------- Helpers: --------------------------------
    /**
     * Converts an image into a model-ready input tensor: reads its pixels, resizes
     * them to 224x224 (nearest neighbour), converts to float, rescales with
     * (x - 127.5) / 127.5 and adds a leading batch dimension.
     *
     * @param {*} image - pixel source accepted by tf.browser.fromPixels.
     * @param {string} modelName - currently unused; kept so existing callers still work.
     * @returns {tf.Tensor} the batched, normalised image. The caller owns it and
     *     should dispose of it when done.
     */
    function preprocessImage(image, modelName)
    {
        // tf.tidy releases every intermediate tensor (raw pixels, resized copy, float
        // copy, the scalar, ...) and keeps only the tensor that is returned.
        return tf.tidy(() => {
            const offset = tf.scalar(127.5);

            return tf.browser.fromPixels(image)
                .resizeNearestNeighbor([224,224])
                .toFloat()
                .sub(offset)
                .div(offset)
                .expandDims();
        });
    }
</script>

代码基于TFJS文档和github上的注释:https://github.com/tensorflow/tfjs/issues/1288

更新: 所以我需要让 X(图像)和 Y(标签)长度相同,Y1 是 X1 的标签,依此类推...

我尝试过:

ys:::Tensor (with only 2 classes represented in the training data set) :
    [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]



一幅图片+所有标签->使用“ model.fit(images [i],ys,{epochs:100})...”,我得到:
错误:“输入张量应具有与目标张量相同的样本数。找到了1个输入样本和10个目标样本。”

一幅图片+一个标签->使用“ model.fit(images [i],ys [i],{epochs:100})...”,我得到:
错误:“无法读取null的属性'shape'”,我猜是因为ys是张量,而ys[i]不是。

所有图片+所有标签->带有“ model.fit(images,ys,{epochs:100})...”,我得到:
错误:“在检查模型输入时:传递给模型的张量数组不是模型期望的大小。 预期会看到1张Tensor,但获得了以下Tensor列表:Tensor ...”

猜猜:我需要将所有图像放置在一个与ys具有相同结构的张量中。

已解决:
在借助Rishabh Sahrawat解决了标签问题之后,我不得不在tf.concat(...)的帮助下将所有张量(图像)合并为一个。

[tensorImg1, tensorImg2, tensorImg3, tensorImg4, ...] x tensor[label1, label2, label3, label4, ...]
-> 
tensor[dataImg1, dataImg2, dataImg3, dataImg4, ...] x tensor[label1, label2, label3, label4, ...]

更新的代码:

<head>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.7"> </script>
    <script src="https://unpkg.com/@tensorflow-models/mobilenet"></script>
</head>
<body>
    <div class="container mt-5">
        <div class="row">
            <input id ="image-selector" class="form-control border-0" type="file"/>
        </div>
        <div class="row">
            <div class="col">
                <h2>Prediction</h2>
                <ol id="prediction-list"></ol>
            </div>
        </div>
        <div class="row">
            <div class="col-12">
                <h2 class="ml-3">Image</h2>
                <canvas id="canvas" width="400" height="300" style="border:1px solid #000000;"></canvas>
            </div>
        </div>
    </div>
    <div  id="training-images">
        <img width="400" height="300" class="train-image cat" src="training-images/cat.jpg" />
        <img width="400" height="300" class="train-image cat" src="training-images/cat2.jpeg" />
        <img width="400" height="300" class="train-image cat" src="training-images/cat3.jpeg" />

        <img width="400" height="300" class="train-image dog" src="training-images/dog.jpeg" />
        <img width="400" height="300" class="train-image dog" src="training-images/dog2.jpeg" />
        <img width="400" height="300" class="train-image dog" src="training-images/dog3.jpeg" />
        <img width="400" height="300" class="train-image dog" src="training-images/dog4.jpeg" />
    </div>
</body>

<script>
    // Name handed to preprocessImage() as its second argument.
    const modelType = "mobilenet";
    // Empty sequential model; prepareModel() adds its layers later.
    const model = tf.sequential();
    // Class names; the position in this array is the class index.
    var labels = ['cat', 'dog'];
    // Filled in by the window "load" handler: one-hot targets and label names.
    var ys;
    var setLabel;
    // DOM handles used by the prediction UI.
    var input = document.getElementById("image-selector");
    var canvas = document.getElementById("canvas");
    var context = canvas.getContext('2d');

    //-------------------------- Training: --------------------------------
    // Once the page has loaded: build the model, turn every training <img> into a
    // tensor plus a class index, one-hot encode the indices and start training.
    window.addEventListener('load', () => {
        prepareModel();

        const images = [];
        const trainLabels = [];
        const trainingElements = document.getElementsByClassName('train-image');
        for (let idx = 0; idx < trainingElements.length; idx++) {
            const element = trainingElements[idx];
            images.push(preprocessImage(element, modelType));
            // Class index 0 means "cat"; every other training image is given index 1.
            trainLabels.push(element.classList.contains("cat") ? 0 : 1);
        }

        console.log(labels)
        setLabel = Array.from(labels);
        ys = tf.oneHot(trainLabels, 2);
        console.log('ys:::'+ys);
        console.log(images);
        // Training is asynchronous; the returned promise is not awaited here.
        trainModel(images);
    });

    /**
     * Trains the shared `model` on all training images at once and then logs the
     * model's prediction for every training image as a sanity check.
     *
     * The per-image tensors are joined along axis 0 into a single batch so that the
     * number of input samples matches the number of rows in `ys`.
     *
     * @param {tf.Tensor[]} images - one tensor per training image, in the same order
     *     as the rows of `ys`; each is expected to already carry a leading batch axis.
     * @returns {Promise<void>} settles when training and the sanity check are done;
     *     failures are logged rather than re-thrown.
     */
    async function trainModel(images) {
        // tf.concat cannot join an empty list, and there would be nothing to fit anyway.
        if (images.length === 0) {
            console.log("No training images found.");
            return;
        }

        let xs;
        try {
            // A single fit over the whole batch; re-fitting the same batch once per
            // image only multiplied the training time and leaked a batch tensor each pass.
            xs = tf.concat(images, 0);
            await model.fit(xs, ys, {epochs: 100});

            images.forEach((image) => {
                // tf.tidy frees the prediction and argMax tensors after logging.
                tf.tidy(() => {
                    const t = model.predict(image);
                    console.log('Prediction:::'+t);
                    // Axis 1 is the class axis; axis 0 (the default) is the batch axis.
                    const pred = t.argMax(1).dataSync();
                    const labelsPred = Array.from(pred).map(e => setLabel[e]);
                    console.log('labelsPred:::'+labelsPred);
                });
            });
            console.log("Training done!");
        } catch (e) {
            console.log(e.message);
        } finally {
            if (xs) {
                xs.dispose();
            }
        }
    }

    //-------------------------- Predict: --------------------------------
    // When the user picks a file: read it as a data URL, decode it into an image,
    // draw that image on the preview canvas and run the model on it.
    input.addEventListener("change", function() {
        const file = this.files[0];
        // The file dialog can be dismissed without a selection; nothing to read then.
        if (!file) {
            return;
        }

        const reader = new FileReader();
        // "load" fires only after a successful read, so reader.result is a data URL here
        // ("loadend" would also fire after a failed read, with a null result).
        reader.addEventListener("load", function() {
            const src_image = new Image();
            src_image.onload = function() {
                // Resize the canvas to the picture so it is drawn uncropped.
                canvas.height = src_image.height;
                canvas.width = src_image.width;
                context.drawImage(src_image, 0, 0);
                // runPrediction is async; surface its failures instead of dropping them.
                runPrediction(src_image).catch(function(e) {
                    console.log(e.message);
                });
            };
            src_image.src = reader.result;
        });
        reader.addEventListener("error", function() {
            console.log(reader.error ? reader.error.message : "Could not read the selected file.");
        });
        reader.readAsDataURL(file);
    });

    /**
     * Runs the trained model on one image and shows the most likely class, with its
     * probability, in the #prediction-list element.
     *
     * @param {*} imageData - pixel source forwarded to preprocessImage(); the
     *     file-input handler passes the decoded Image element.
     * @returns {Promise<void>} settles once the list has been updated.
     */
    async function runPrediction(imageData){
        // preprocessImage() already returns a batched tensor, so no extra reshape is needed.
        const inputTensor = preprocessImage(imageData, "mobilenet");
        let output;
        let prediction;
        try {
            output = model.predict(inputTensor);
            prediction = await output.data();
        } finally {
            // tfjs tensors are not garbage collected; release them once the values are copied out.
            inputTensor.dispose();
            if (output) {
                output.dispose();
            }
        }
        console.log('prediction:::'+ prediction);

        // Output index i corresponds to setLabel[i] (0 = cat, 1 = dog); fall back to
        // the bare index if no name is known for it.
        const names = Array.isArray(setLabel) ? setLabel : [];
        const topPredictions = Array.from(prediction)
        .map(function(p,i){
            return {
                probability: p,
                className: names[i] !== undefined ? names[i] : `class ${i}`
            };
        }).sort(function(a,b){
            return b.probability-a.probability;
        }).slice(0,1); // keep only the single best match

        $("#prediction-list").empty();
        topPredictions.forEach(function(p){
            $("#prediction-list").append(`<li>${p.className}:${p.probability.toFixed(6)}</li>`);
        });
    }

    //-------------------------- Helpers: --------------------------------

    /**
     * Adds the layers of the two-class image classifier to the shared `model`,
     * compiles it and prints a summary.
     *
     * Architecture: one strided 5x5 convolution (8 filters) on 224x224x3 input,
     * two 2x2 max-pooling layers, flatten, a 64-unit ReLU layer and a 2-unit
     * softmax output (one unit per class).
     */
    function prepareModel(){
        model.add(tf.layers.conv2d({
            inputShape: [224, 224 , 3],
            kernelSize: 5,
            filters: 8,
            strides: 2,
            activation: 'relu',
            kernelInitializer: 'VarianceScaling'
        }));
        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
        model.add(tf.layers.flatten({}));
        model.add(tf.layers.dense({units: 64, activation: 'relu'}));
        model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
        model.compile({
            // Cross-entropy is the loss meant for softmax outputs trained against
            // one-hot targets; mean squared error gives much weaker gradients here.
            loss: 'categoricalCrossentropy',
            optimizer : 'sgd',
            // Report accuracy during fit() so training progress can be judged.
            metrics: ['accuracy']
        });
        model.summary()
    }

    /**
     * Converts an image into a model-ready input tensor: reads its pixels, resizes
     * them to 224x224 (nearest neighbour), converts to float, rescales with
     * (x - 127.5) / 127.5 and adds a leading batch dimension.
     *
     * @param {*} image - pixel source accepted by tf.browser.fromPixels.
     * @param {string} modelName - currently unused; kept so existing callers still work.
     * @returns {tf.Tensor} the batched, normalised image. The caller owns it and
     *     should dispose of it when done.
     */
    function preprocessImage(image, modelName)
    {
        // tf.tidy releases every intermediate tensor (raw pixels, resized copy, float
        // copy, the scalar, ...) and keeps only the tensor that is returned.
        return tf.tidy(() => {
            const offset = tf.scalar(127.5);

            return tf.browser.fromPixels(image)
                .resizeNearestNeighbor([224,224])
                .toFloat()
                .sub(offset)
                .div(offset)
                .expandDims();
        });
    }
</script>

1 个答案:

答案 0 :(得分:0)

  

如何添加其他类(例如“狗”)?

您可以把新的类别加入训练数据集,从而让模型也能预测另一个类别。假设您添加了 Dog 类,那么现在您的数据集就同时包含 Cat 和 Dog 的图片。

  

应该是这样(摘要):model.fit([img1,img2,img3],[label1,label2,label3] ...)

是的,图像为 x = [img1, img2, img3],相应的标签为 y = [label1, label2, label3]。在 x 中,img1、img2 或任何其他图像都可以是猫图像或狗图像。为简单起见,您可以提供以 numpy 数组表示的图像。此处(原文为超链接)展示了输入训练数据的样子。

  

标签和训练集之间的关系是什么?

标签是训练集的一部分。如果要执行监督分类,则必须将标签与输入要素(图像)一起输入。

更新以获取更新的问题

    [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]

在这种情况下,您的形状不匹配。这里的形状为(10,10),但是模型希望输入的标签为形状(10,)

如果您只有两个类别,则无需在 Y(标签)中用 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0] 代表一个类别、用 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0] 代表另一个类别。其余的零是做什么用的?保持简单,按下面的方式定义即可。

如果是 cat 图像,则将其标记为 0;如果是 dog 图像,则将其标记为 1(反过来也可以)。然后以 [0,1,0] 的形式传入:第一个 0 是 img1 的标签,1 是 img2 的标签,最后的 0 是 img3 的标签。