example fixed

This commit is contained in:
Roland Grinis 2021-11-01 18:03:46 +00:00
parent a994b8a50c
commit 62e7073ed2

View File

@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) { for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
train(xBatch, yBatch) train(xBatch, yBatch)
} }
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}") println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
} }
} }
@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
val prediction = model.predict(xTest) val prediction = model.predict(xTest)
// process raw prediction via argMax // process raw prediction via argMax
val predictionLabels = prediction.argMax(1, true) val predictionLabels = prediction.argMax(1, true).asDouble()
// find out accuracy // find out accuracy
val acc = accuracy(yTest, predictionLabels) val acc = accuracy(yTest, predictionLabels)