example fixed

This commit is contained in:
Roland Grinis 2021-11-01 18:03:46 +00:00
parent a994b8a50c
commit 62e7073ed2

View File

@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
train(xBatch, yBatch)
}
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
}
}
@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
val prediction = model.predict(xTest)
// process raw prediction via argMax
val predictionLabels = prediction.argMax(1, true)
val predictionLabels = prediction.argMax(1, true).asDouble()
// find out accuracy
val acc = accuracy(yTest, predictionLabels)