如何将 TensorFlow 数据集 EMNIST/balanced 的数据类型从 uint8 转换为 float32

时间:2020-03-15 13:03:34

标签: tensorflow

我正在使用 TensorFlow 数据集“emnist/balanced”。默认情况下,特征值的数据类型为 uint8,但 TensorFlow 模型只接受浮点值。

如何将特征和标签的数据类型转换为 float32?

代码在这里:

# EMNIST/balanced training script (model definition omitted in the original question).
import tensorflow as tf
import tensorflow_datasets as tfds

datasets, info = tfds.load(name="emnist/balanced", with_info=True, as_supervised=True)

emnist_train, emnist_test = datasets['train'], datasets['test']


def _to_float32(image, label):
    """Cast a uint8 image to float32 in [0, 1] and pass the label through unchanged.

    tfds delivers the images as uint8, which the model's float-only ops reject
    (that is the TypeError raised by ``model.fit``), so the cast must happen
    before the data reaches the model.
    """
    # NOTE(review): labels are left as integers, which presumably suits a
    # sparse categorical loss in the omitted model code -- confirm; cast them
    # too (tf.cast(label, tf.float32)) only if the chosen loss requires it.
    return tf.cast(image, tf.float32) / 255.0, label


# Convert both splits so evaluation sees the same dtype and scale as training.
emnist_train = emnist_train.map(_to_float32)
emnist_test = emnist_test.map(_to_float32)

# ... model construction and compilation (omitted in the original question) ...
# NOTE(review): Keras expects batched datasets; if the omitted code does not
# already batch them, add e.g. `.batch(32)` to both splits.

history = model.fit(emnist_train, epochs=10)

# validation

test_loss, test_acc = model.evaluate(emnist_test, verbose=2)

print(test_acc)


Error --
      2 
      3 
----> 4 history = model.fit(emnist_train, epochs = 10)
      5 
      6 #validation

TypeError: Value passed to parameter 'features' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64

TypeError:传递给参数 'features' 的值的数据类型 uint8 不在允许的值列表中:float16, bfloat16, float32, float64

1 个答案:

答案 0 :(得分:-1)

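请参考下面在 MNIST 数据集上训练 ANN 的可运行代码。(注:下面实际贴出的是一段 Java 素数生成程序,与本问题无关。针对本问题,可在训练前用 `Dataset.map` 把图像转换为 float32,例如 `emnist_train = emnist_train.map(lambda image, label: (tf.cast(image, tf.float32) / 255.0, label))`,测试集同样处理。)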

public class UniquePrimes {

    // Every produced prime is also pushed here; nothing in this class reads it back.
    private static BlockingQueue<Integer> linkedBlockingQueue = new LinkedBlockingQueue<Integer>();
    // One entry per distinct prime; the value is the name of the last thread that put it.
    static ConcurrentHashMap<Integer, String> primesproduced = new ConcurrentHashMap<Integer, String>();

    /**
     * Reads a thread count from standard input, runs that many producer tasks on a
     * fixed-size pool, then prints how many distinct primes were produced followed by
     * the map entries.
     *
     * <p>Each task takes the next value of a shared counter (1, 2, 3, ...) and produces
     * the smallest prime strictly greater than it, so different tasks may produce the
     * same prime; {@link #primesproduced} keeps a single entry per prime.
     */
    public static void main(String[] args) {

        Scanner reader = new Scanner(System.in);
        System.out.print("Enter number of threads you want to create: ");
        int threadCount = reader.nextInt();
        reader.close();

        ExecutorService executorPool = Executors.newFixedThreadPool(threadCount);

        AtomicInteger currentPrime = new AtomicInteger();
        Runnable producer = () -> {
            String threadName = Thread.currentThread().getName();

            try {
                int p = generateNextPrime(currentPrime.incrementAndGet());
                linkedBlockingQueue.put(p);
                primesproduced.put(p, threadName);
                System.out.println("Thread " + threadName + " produced prime number " + p);
            } catch (InterruptedException e) {
                // Restore the interrupt status instead of silently clearing it.
                Thread.currentThread().interrupt();
                e.printStackTrace();
            }
        };

        List<Runnable> tasks = new ArrayList<Runnable>();
        for (int i = 0; i < threadCount; i++) {
            tasks.add(producer);
        }

        try {
            CompletableFuture<?>[] futures = tasks.stream()
                    .map(task -> CompletableFuture.runAsync(task, executorPool))
                    .toArray(CompletableFuture[]::new);
            CompletableFuture.allOf(futures).join();
        } finally {
            // Shut the pool down even if a task failed and join() threw.
            executorPool.shutdown();
        }

        System.out.println("\nTotal unique primes produced: " + primesproduced.size() + " and they are: ");

        System.out.print(
                primesproduced.entrySet().stream()
                        .filter(entry -> entry.getKey().intValue() > 0)
                        .map(entry -> "[" + entry + "]")
                        .collect(Collectors.joining(",")));
    }

    /**
     * Returns the smallest prime strictly greater than {@code currentPrime}
     * (2 for any argument below 2). Package-private so it can be unit-tested.
     */
    static int generateNextPrime(int currentPrime) {
        int candidate = currentPrime + 1;
        if (candidate < 2) {
            return 2;
        }
        while (!isPrime(candidate)) {
            candidate++;
        }
        return candidate;
    }

    /** Trial division up to sqrt(n); callers pass n >= 2. */
    private static boolean isPrime(int n) {
        // long product avoids int overflow of d * d near Integer.MAX_VALUE.
        for (int d = 2; (long) d * d <= n; d++) {
            if (n % d == 0) {
                return false;
            }
        }
        return true;
    }
}

输出:

TF 版本:2.1.0

训练准确率:91.06

测试准确率:0.8871

相关问题