v0.3.0-dev-18 #459
@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
|
||||
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
|
||||
train(xBatch, yBatch)
|
||||
}
|
||||
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
|
||||
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
|
||||
val prediction = model.predict(xTest)
|
||||
|
||||
// process raw prediction via argMax
|
||||
val predictionLabels = prediction.argMax(1, true)
|
||||
val predictionLabels = prediction.argMax(1, true).asDouble()
|
||||
|
||||
// find out accuracy
|
||||
val acc = accuracy(yTest, predictionLabels)
|
||||
|
Loading…
Reference in New Issue
Block a user