From 62e7073ed2637d888e5f43075af8677d483dc55c Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Mon, 1 Nov 2021 18:03:46 +0000 Subject: [PATCH] example fixed --- .../main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt index cdfc06922..1a2a94534 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt @@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List) { for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) { train(xBatch, yBatch) } - println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}") + println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}") } } @@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra { val prediction = model.predict(xTest) // process raw prediction via argMax - val predictionLabels = prediction.argMax(1, true) + val predictionLabels = prediction.argMax(1, true).asDouble() // find out accuracy val acc = accuracy(yTest, predictionLabels)