From 773ff10dd15cb3d719333a00abdf7267e4a0ca28 Mon Sep 17 00:00:00 2001
From: Roland Grinis
Date: Thu, 8 Jul 2021 21:08:20 +0100
Subject: [PATCH] Fix argmax for tensors

---
 .../kscience/kmath/tensors/NeuralNetwork.kt   |  4 +-
 .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt  |  7 +-
 .../java/space/kscience/kmath/noa/JNoa.java   | 16 ++---
 .../space/kscience/kmath/noa/algebras.kt      | 49 ++++++++++++++
 .../resources/space_kscience_kmath_noa_JNoa.h | 64 ++++++++-----------
 .../kmath/tensors/api/TensorAlgebra.kt        | 14 ++++
 .../kmath/tensors/core/DoubleTensorAlgebra.kt | 43 ++++++-------
 .../kscience/kmath/tensors/core/IntTensor.kt  |  6 +-
 8 files changed, 128 insertions(+), 75 deletions(-)

diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/NeuralNetwork.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/NeuralNetwork.kt
index b262bee02..1e961fc7b 100644
--- a/examples/src/main/kotlin/space/kscience/kmath/tensors/NeuralNetwork.kt
+++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/NeuralNetwork.kt
@@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
         for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
             train(xBatch, yBatch)
         }
-        println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
+        println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
     }
 }
 
@@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
     val prediction = model.predict(xTest)
 
     // process raw prediction via argMax
-    val predictionLabels = prediction.argMax(1, true)
+    val predictionLabels = prediction.argMax(1, true).asDouble()
 
     // find out accuracy
     val acc = accuracy(yTest, predictionLabels)
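A sketch of the type flow the updated example now relies on (names taken from NeuralNetwork.kt above; the shape comments are illustrative and not part of the patch):

    // predict(xTrain)  : DoubleTensor - raw per-class scores, shape [n, classes]
    // .argMax(1, true) : IntTensor    - winning class index per sample, shape [n, 1]
    // .asDouble()      : DoubleTensor - the representation accuracy(...) expects
    val predictionLabels = predict(xTrain).argMax(1, true).asDouble()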
diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt
index 57bc974ac..bc58eefc3 100644
--- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt
+++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt
@@ -28,6 +28,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> {
      */
     public fun INDArray.wrap(): Nd4jArrayStructure<T>
 
+    public fun INDArray.wrapInt(): Nd4jArrayStructure<Int>
+
     /**
      * Unwraps to or acquires [INDArray] from [StructureND].
      */
@@ -90,8 +92,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> {
     public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
     public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
 
-    public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
-        ndBase.get().argmax(ndArray, keepDim, dim).wrap()
+    override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
+        ndBase.get().argmax(ndArray, keepDim, dim).wrapInt()
 
     public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
 
@@ -144,6 +146,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> {
  */
 public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
     public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
+    public override fun INDArray.wrapInt(): Nd4jArrayStructure<Int> = asIntStructure()
 
     @OptIn(PerformancePitfall::class)
     public override val StructureND<Double>.ndArray: INDArray
diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
index 934762a1a..20f8c7510 100644
--- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
+++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
@@ -55,6 +55,8 @@ class JNoa {
 
     public static native long viewTensor(long tensorHandle, int[] shape);
 
+    public static native long viewAsTensor(long tensorHandle, long asTensorHandle);
+
     public static native String tensorToString(long tensorHandle);
 
     public static native int getDim(long tensorHandle);
@@ -75,6 +77,10 @@ class JNoa {
 
     public static native int getItemInt(long tensorHandle);
 
+    public static native long getIndex(long tensorHandle, int index);
+
+    public static native long getIndexTensor(long tensorHandle, long indexTensorHandle);
+
     public static native double getDouble(long tensorHandle, int[] index);
 
     public static native float getFloat(long tensorHandle, int[] index);
@@ -175,23 +181,15 @@ class JNoa {
 
     public static native long absTensor(long tensorHandle);
 
-    public static native void absTensorAssign(long tensorHandle);
-
     public static native long transposeTensor(long tensorHandle, int i, int j);
 
-    public static native void transposeTensorAssign(long tensorHandle, int i, int j);
-
     public static native long expTensor(long tensorHandle);
 
-    public static native void expTensorAssign(long tensorHandle);
-
     public static native long logTensor(long tensorHandle);
 
-    public static native void logTensorAssign(long tensorHandle);
-
     public static native long sumTensor(long tensorHandle);
 
-    public static native void sumTensorAssign(long tensorHandle);
+    public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim);
 
     public static native long matmul(long lhs, long rhs);
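A minimal sketch of the Nd4j-side change (not part of the patch): argMax now overrides the common TensorAlgebra member and yields an integer structure via wrapInt() instead of a double tensor. It assumes ND4J's Nd4j.create(double[], int[]) factory with row-major layout; wrap() and argMax come from the diff above:

    import org.nd4j.linalg.factory.Nd4j
    import space.kscience.kmath.nd4j.DoubleNd4jTensorAlgebra

    fun main(): Unit = with(DoubleNd4jTensorAlgebra) {
        // 2x3 matrix: row maxima sit at columns 1 and 0
        val x = Nd4j.create(doubleArrayOf(1.0, 5.0, 2.0, 7.0, 0.0, 3.0), intArrayOf(2, 3)).wrap()
        val idx = x.argMax(1, false) // Tensor<Int> backed by wrapInt(), holding [1, 0]
        println(idx)
    }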
diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
index fa164df29..a5d69b6bd 100644
--- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
+++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
@@ -57,6 +57,55 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
 
     override operator fun Tensor<T>.unaryMinus(): TensorType =
         wrap(JNoa.unaryMinus(this.cast().tensorHandle))
+
+    override infix fun Tensor<T>.dot(other: Tensor<T>): TensorType {
+        return wrap(JNoa.matmul(this.cast().tensorHandle, other.cast().tensorHandle))
+    }
+
+    public infix fun Tensor<T>.dotAssign(other: Tensor<T>): Unit {
+        JNoa.matmulAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+    }
+
+    public infix fun Tensor<T>.dotRightAssign(other: Tensor<T>): Unit {
+        JNoa.matmulRightAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+    }
+
+    override operator fun Tensor<T>.get(i: Int): TensorType =
+        wrap(JNoa.getIndex(this.cast().tensorHandle, i))
+
+    public operator fun Tensor<T>.get(indexTensor: NoaLongTensor): TensorType =
+        wrap(JNoa.getIndexTensor(this.cast().tensorHandle, indexTensor.tensorHandle))
+
+    override fun diagonalEmbedding(
+        diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
+    ): TensorType =
+        wrap(JNoa.diagEmbed(diagonalEntries.cast().tensorHandle, offset, dim1, dim2))
+
+    override fun Tensor<T>.transpose(i: Int, j: Int): TensorType {
+        return wrap(JNoa.transposeTensor(this.cast().tensorHandle, i, j))
+    }
+
+    override fun Tensor<T>.view(shape: IntArray): TensorType {
+        return wrap(JNoa.viewTensor(this.cast().tensorHandle, shape))
+    }
+
+    override fun Tensor<T>.viewAs(other: Tensor<T>): TensorType {
+        return wrap(JNoa.viewAsTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+    }
+
+    public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(this.cast().tensorHandle))
+
+    public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(this.cast().tensorHandle))
+    override fun Tensor<T>.sum(): T = sumAll().item()
+    override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
+        wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
+
+    public fun Tensor<T>.copy(): TensorType =
+        wrap(JNoa.copyTensor(this.cast().tensorHandle))
+
+    public fun Tensor<T>.copyToDevice(device: Device): TensorType =
+        wrap(JNoa.copyToDevice(this.cast().tensorHandle, device.toInt()))
+
 }
 
 public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
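All of the new members follow the same handle-passing discipline: unwrap the receiver to its native tensorHandle via cast(), call into JNoa, and wrap the returned handle in a fresh scope-tracked tensor. A hypothetical caller sketch (x, index, and device are assumed to come from some concrete NOA algebra in scope; illustrative only):

    val row = x[0]                        // JNoa.getIndex: first slice along dim 0
    val rows = x[index]                   // JNoa.getIndexTensor: gather by a NoaLongTensor
    val total = x.sum()                   // scalar, via sumAll().item()
    val perRow = x.sum(1, keepDim = true) // JNoa.sumDimTensor, reduced dim kept as size 1
    val onGpu = x.copyToDevice(device)    // JNoa.copyToDevice with the device encoded as Int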
diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h
index 0353bc24b..3c7342aa5 100644
--- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h
+++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h
@@ -143,6 +143,14 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToInt
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewTensor
   (JNIEnv *, jclass, jlong, jintArray);
 
+/*
+ * Class:     space_kscience_kmath_noa_JNoa
+ * Method:    viewAsTensor
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewAsTensor
+  (JNIEnv *, jclass, jlong, jlong);
+
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    tensorToString
@@ -223,6 +231,22 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getItemLong
 JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getItemInt
   (JNIEnv *, jclass, jlong);
 
+/*
+ * Class:     space_kscience_kmath_noa_JNoa
+ * Method:    getIndex
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndex
+  (JNIEnv *, jclass, jlong, jint);
+
+/*
+ * Class:     space_kscience_kmath_noa_JNoa
+ * Method:    getIndexTensor
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndexTensor
+  (JNIEnv *, jclass, jlong, jlong);
+
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    getDouble
@@ -623,14 +647,6 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
   (JNIEnv *, jclass, jlong);
 
-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    absTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
 * Method:    transposeTensor
@@ -639,14 +655,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor
   (JNIEnv *, jclass, jlong, jint, jint);
 
-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    transposeTensorAssign
- * Signature: (JII)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
-  (JNIEnv *, jclass, jlong, jint, jint);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    expTensor
@@ -655,14 +663,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor
   (JNIEnv *, jclass, jlong);
 
-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    expTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    logTensor
@@ -671,14 +671,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor
   (JNIEnv *, jclass, jlong);
 
-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    logTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_logTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    sumTensor
@@ -689,11 +681,11 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensor
 
 /*
  * Class:     space_kscience_kmath_noa_JNoa
- * Method:    sumTensorAssign
- * Signature: (J)V
+ * Method:    sumDimTensor
+ * Signature: (JIZ)J
  */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensorAssign
-  (JNIEnv *, jclass, jlong);
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumDimTensor
+  (JNIEnv *, jclass, jlong, jint, jboolean);
 
 /*
  * Class:     space_kscience_kmath_noa_JNoa
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
index 56aafbc3a..52a00c837 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
@@ -6,6 +6,7 @@
 package space.kscience.kmath.tensors.api
 
 import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.tensors.core.DoubleTensor
 
 /**
  * Algebra over a ring on [Tensor].
@@ -313,4 +314,17 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
      */
     public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
 
+    /**
+     * Returns the index of the maximum value of each row of the input tensor in the given dimension [dim].
+     *
+     * If [keepDim] is true, the output tensor is of the same size as
+     * the input except in the dimension [dim], where it is of size 1.
+     * Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
+     *
+     * @param dim the dimension to reduce.
+     * @param keepDim whether the output tensor has [dim] retained or not.
+     * @return the index of the maximum value of each row of the input tensor in the given dimension [dim].
+     */
+    public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int>
+
 }
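A worked shape example for the new argMax contract (values illustrative, not from the patch):

    // x: Tensor<Double> of shape [2, 3]
    //     [[1.0, 5.0, 2.0],
    //      [7.0, 0.0, 3.0]]
    // x.argMax(1, true)  -> Tensor<Int> of shape [2, 1] holding [[1], [0]]
    // x.argMax(1, false) -> Tensor<Int> of shape [2]    holding [1, 0]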
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
index 0d3dc20b6..6f6cebdc9 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
@@ -5,8 +5,10 @@
 
 package space.kscience.kmath.tensors.core
 
+import space.kscience.kmath.structures.MutableBuffer
 import space.kscience.kmath.nd.as1D
 import space.kscience.kmath.nd.as2D
+import space.kscience.kmath.structures.asMutableBuffer
 import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
 import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
 import space.kscience.kmath.tensors.api.Tensor
@@ -537,11 +539,11 @@ public open class DoubleTensorAlgebra :
     internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
         foldFunction(tensor.toDoubleArray())
 
-    internal fun Tensor<Double>.foldDim(
-        foldFunction: (DoubleArray) -> Double,
+    internal fun <R> Tensor<Double>.foldDim(
+        foldFunction: (DoubleArray) -> R,
         dim: Int,
         keepDim: Boolean,
-    ): DoubleTensor {
+    ): BufferedTensor<R> {
         check(dim < dimension) { "Dimension $dim out of range $dimension" }
         val resShape = if (keepDim) {
             shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
@@ -549,7 +551,9 @@ public open class DoubleTensorAlgebra :
             shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
         }
         val resNumElements = resShape.reduce(Int::times)
-        val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
+        val init = foldFunction(DoubleArray(1) { 0.0 })
+        val resTensor = BufferedTensor(resShape,
+            MutableList(resNumElements) { init }.asMutableBuffer(), 0)
         for (index in resTensor.linearStructure.indices()) {
             val prefix = index.take(dim).toIntArray()
             val suffix = index.takeLast(dimension - dim - 1).toIntArray()
@@ -557,41 +561,30 @@ public open class DoubleTensorAlgebra :
                 tensor[prefix + intArrayOf(i) + suffix]
             })
         }
-
         return resTensor
     }
 
+
     override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
 
     override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.sum() }, dim, keepDim)
+        foldDim({ x -> x.sum() }, dim, keepDim).toDoubleTensor()
 
     override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
 
     override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
+        foldDim({ x -> x.minOrNull()!! }, dim, keepDim).toDoubleTensor()
 
     override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
 
     override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
+        foldDim({ x -> x.maxOrNull()!! }, dim, keepDim).toDoubleTensor()
 
-    /**
-     * Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
-     *
-     * If [keepDim] is true, the output tensor is of the same size as
-     * input except in the dimension [dim] where it is of size 1.
-     * Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
-     *
-     * @param dim the dimension to reduce.
-     * @param keepDim whether the output tensor has [dim] retained or not.
-     * @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
-     */
-    public fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor =
         foldDim({ x ->
-            x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
-        }, dim, keepDim)
+            x.withIndex().maxByOrNull { it.value }?.index!!
+        }, dim, keepDim).toIntTensor()
 
     override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
 
@@ -604,7 +597,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()
 
     override fun Tensor<Double>.std(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
@@ -619,7 +612,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()
 
     override fun Tensor<Double>.variance(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
@@ -634,7 +627,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()
 
     private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
         val n = x.shape[0]
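The generified foldDim now drives every dimensional reduction: a Double-valued fold materializes as a DoubleTensor, while argMax folds to Int indices and materializes as an IntTensor. Note the init probe: foldFunction(DoubleArray(1) { 0.0 }) is evaluated once purely to obtain a placeholder of type R for pre-filling the result buffer. A sketch of the two paths (x assumed to be a Tensor<Double> inside DoubleTensorAlgebra; not part of the patch):

    val colMax = x.max(0, false)    // foldDim(...) then .toDoubleTensor()
    val colArg = x.argMax(0, false) // foldDim(...) then .toIntTensor()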
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt
index ae1e6c8c8..e3d7c3d35 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt
@@ -6,6 +6,7 @@
 package space.kscience.kmath.tensors.core
 
 import space.kscience.kmath.structures.IntBuffer
+import space.kscience.kmath.tensors.core.internal.array
 
 /**
  * Default [BufferedTensor] implementation for [Int] values
@@ -14,4 +15,7 @@ public class IntTensor internal constructor(
     shape: IntArray,
     buffer: IntArray,
     offset: Int = 0
-) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
+) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset) {
+    public fun asDouble(): DoubleTensor =
+        DoubleTensor(shape, mutableBuffer.array().map { it.toDouble() }.toDoubleArray(), bufferStart)
+}
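The conversion itself is a plain element-wise copy: shape and offset (bufferStart) carry over unchanged, and no data is shared afterwards. A standalone sketch of the same logic using only the standard library (illustrative, not part of the patch):

    fun intToDoubleArray(src: IntArray): DoubleArray =
        src.map { it.toDouble() }.toDoubleArray()

    fun main() {
        // argMax-style indices become doubles: [1, 0, 2] -> [1.0, 0.0, 2.0]
        println(intToDoubleArray(intArrayOf(1, 0, 2)).joinToString())
    }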