v0.3.0-dev-18 #459

Merged
altavir merged 64 commits from dev into master 2022-02-13 17:50:34 +03:00
Showing only changes of commit 62e7073ed2 - Show all commits

View File

@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) { for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
train(xBatch, yBatch) train(xBatch, yBatch)
} }
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}") println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
} }
} }
@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
val prediction = model.predict(xTest) val prediction = model.predict(xTest)
// process raw prediction via argMax // process raw prediction via argMax
val predictionLabels = prediction.argMax(1, true) val predictionLabels = prediction.argMax(1, true).asDouble()
// find out accuracy // find out accuracy
val acc = accuracy(yTest, predictionLabels) val acc = accuracy(yTest, predictionLabels)