Fix argmax for tensors

Roland Grinis 2021-07-08 21:08:20 +01:00
parent b2b063196d
commit 773ff10dd1
8 changed files with 128 additions and 75 deletions

View File

@@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
         for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
             train(xBatch, yBatch)
         }
-        println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
+        println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
     }
 }
@@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
     val prediction = model.predict(xTest)
     // process raw prediction via argMax
-    val predictionLabels = prediction.argMax(1, true)
+    val predictionLabels = prediction.argMax(1, true).asDouble()
     // find out accuracy
     val acc = accuracy(yTest, predictionLabels)

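With argMax now returning integer indices, call sites that feed the labels into double-valued helpers such as the example's accuracy convert explicitly via asDouble(). A minimal sketch of the new call pattern (fromArray and the sample values here are illustrative assumptions, not part of the commit):

    import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra

    fun main() = BroadcastDoubleTensorAlgebra {
        // Two "predictions" over three classes each.
        val logits = fromArray(
            intArrayOf(2, 3),
            doubleArrayOf(0.1, 0.7, 0.2, 0.9, 0.05, 0.05)
        )
        // argMax now yields an IntTensor of class indices: [1, 0].
        val labels = logits.argMax(1, false)
        // Convert back to a DoubleTensor where double-valued helpers expect one.
        println(labels.asDouble())
    }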
View File

@@ -28,6 +28,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
      */
     public fun INDArray.wrap(): Nd4jArrayStructure<T>

+    public fun INDArray.wrapInt(): Nd4jArrayStructure<Int>
+
     /**
      * Unwraps to or acquires [INDArray] from [StructureND].
      */
@@ -90,8 +92,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
     public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
     public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)

-    public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
-        ndBase.get().argmax(ndArray, keepDim, dim).wrap()
+    override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
+        ndBase.get().argmax(ndArray, keepDim, dim).wrapInt()

     public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
@@ -144,6 +146,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
  */
 public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
     public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
+    public override fun INDArray.wrapInt(): Nd4jArrayStructure<Int> = asIntStructure()

     @OptIn(PerformancePitfall::class)
     public override val StructureND<Double>.ndArray: INDArray

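On the Nd4j backend the same call now routes through the new wrapInt, so the indices surface as an Int structure. A hedged sketch (the input values are invented, and Nd4j.create overloads vary between versions):

    import org.nd4j.linalg.factory.Nd4j
    import space.kscience.kmath.nd4j.DoubleNd4jTensorAlgebra

    fun main() = with(DoubleNd4jTensorAlgebra) {
        val a = Nd4j.create(doubleArrayOf(0.0, 5.0, 2.0, 1.0), intArrayOf(2, 2)).wrap()
        // Tensor<Int> of row-wise argmax indices along dim 1, kept as a size-1 dim: [[1], [0]].
        println(a.argMax(1, true))
    }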
View File

@@ -55,6 +55,8 @@ class JNoa {
     public static native long viewTensor(long tensorHandle, int[] shape);

+    public static native long viewAsTensor(long tensorHandle, long asTensorHandle);
+
     public static native String tensorToString(long tensorHandle);

     public static native int getDim(long tensorHandle);
@@ -75,6 +77,10 @@ class JNoa {
     public static native int getItemInt(long tensorHandle);

+    public static native long getIndex(long tensorHandle, int index);
+
+    public static native long getIndexTensor(long tensorHandle, long indexTensorHandle);
+
     public static native double getDouble(long tensorHandle, int[] index);

     public static native float getFloat(long tensorHandle, int[] index);
@@ -175,23 +181,15 @@ class JNoa {
     public static native long absTensor(long tensorHandle);

-    public static native void absTensorAssign(long tensorHandle);
-
     public static native long transposeTensor(long tensorHandle, int i, int j);

-    public static native void transposeTensorAssign(long tensorHandle, int i, int j);
-
     public static native long expTensor(long tensorHandle);

-    public static native void expTensorAssign(long tensorHandle);
-
     public static native long logTensor(long tensorHandle);

-    public static native void logTensorAssign(long tensorHandle);
-
     public static native long sumTensor(long tensorHandle);

-    public static native void sumTensorAssign(long tensorHandle);
+    public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim);

     public static native long matmul(long lhs, long rhs);

View File

@@ -57,6 +57,55 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
     override operator fun Tensor<T>.unaryMinus(): TensorType =
         wrap(JNoa.unaryMinus(this.cast().tensorHandle))

+    override infix fun Tensor<T>.dot(other: Tensor<T>): TensorType {
+        return wrap(JNoa.matmul(this.cast().tensorHandle, other.cast().tensorHandle))
+    }
+
+    public infix fun Tensor<T>.dotAssign(other: Tensor<T>): Unit {
+        JNoa.matmulAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+    }
+
+    public infix fun Tensor<T>.dotRightAssign(other: Tensor<T>): Unit {
+        JNoa.matmulRightAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+    }
+
+    override operator fun Tensor<T>.get(i: Int): TensorType =
+        wrap(JNoa.getIndex(this.cast().tensorHandle, i))
+
+    public operator fun Tensor<T>.get(indexTensor: NoaLongTensor): TensorType =
+        wrap(JNoa.getIndexTensor(this.cast().tensorHandle, indexTensor.tensorHandle))
+
+    override fun diagonalEmbedding(
+        diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
+    ): TensorType =
+        wrap(JNoa.diagEmbed(diagonalEntries.cast().tensorHandle, offset, dim1, dim2))
+
+    override fun Tensor<T>.transpose(i: Int, j: Int): TensorType {
+        return wrap(JNoa.transposeTensor(this.cast().tensorHandle, i, j))
+    }
+
+    override fun Tensor<T>.view(shape: IntArray): TensorType {
+        return wrap(JNoa.viewTensor(this.cast().tensorHandle, shape))
+    }
+
+    override fun Tensor<T>.viewAs(other: Tensor<T>): TensorType {
+        return wrap(JNoa.viewAsTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+    }
+
+    public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(this.cast().tensorHandle))
+
+    public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(this.cast().tensorHandle))
+    override fun Tensor<T>.sum(): T = sumAll().item()
+    override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
+        wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
+
+    public fun Tensor<T>.copy(): TensorType =
+        wrap(JNoa.copyTensor(this.cast().tensorHandle))
+
+    public fun Tensor<T>.copyToDevice(device: Device): TensorType =
+        wrap(JNoa.copyToDevice(this.cast().tensorHandle, device.toInt()))
 }

 public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>

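A hypothetical tour of the new surface; the NoaDouble context and copyFromArray factory below are illustrative assumptions, not confirmed API of this module:

    // Hypothetical context and factory names, for illustration only.
    NoaDouble {
        val x = copyFromArray(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2))
        val y = x.transpose(0, 1)
        val z = x dot y                 // matmul via JNoa.matmul
        val row = z[0]                  // indexing via the new JNoa.getIndex
        val total = z.sum()             // scalar sum via sumTensor + item()
        val colSums = z.sum(0, false)   // dim-wise sum via the new sumDimTensor
    }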
View File

@@ -143,6 +143,14 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToInt
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewTensor
   (JNIEnv *, jclass, jlong, jintArray);

+/*
+ * Class:     space_kscience_kmath_noa_JNoa
+ * Method:    viewAsTensor
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewAsTensor
+  (JNIEnv *, jclass, jlong, jlong);
+
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    tensorToString
@@ -223,6 +231,22 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getItemLong
 JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getItemInt
   (JNIEnv *, jclass, jlong);

+/*
+ * Class:     space_kscience_kmath_noa_JNoa
+ * Method:    getIndex
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndex
+  (JNIEnv *, jclass, jlong, jint);
+
+/*
+ * Class:     space_kscience_kmath_noa_JNoa
+ * Method:    getIndexTensor
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndexTensor
+  (JNIEnv *, jclass, jlong, jlong);
+
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    getDouble
@@ -623,14 +647,6 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
   (JNIEnv *, jclass, jlong);

-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    absTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    transposeTensor
@@ -639,14 +655,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor
   (JNIEnv *, jclass, jlong, jint, jint);

-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    transposeTensorAssign
- * Signature: (JII)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
-  (JNIEnv *, jclass, jlong, jint, jint);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    expTensor
@@ -655,14 +663,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor
   (JNIEnv *, jclass, jlong);

-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    expTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    logTensor
@@ -671,14 +671,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor
   (JNIEnv *, jclass, jlong);

-/*
- * Class:     space_kscience_kmath_noa_JNoa
- * Method:    logTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_logTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    sumTensor
@@ -689,11 +681,11 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensor
 /*
  * Class:     space_kscience_kmath_noa_JNoa
- * Method:    sumTensorAssign
- * Signature: (J)V
+ * Method:    sumDimTensor
+ * Signature: (JIZ)J
  */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensorAssign
-  (JNIEnv *, jclass, jlong);
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumDimTensor
+  (JNIEnv *, jclass, jlong, jint, jboolean);

 /*
  * Class:     space_kscience_kmath_noa_JNoa

View File

@@ -6,6 +6,7 @@
 package space.kscience.kmath.tensors.api

 import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.tensors.core.DoubleTensor

 /**
  * Algebra over a ring on [Tensor].
@@ -313,4 +314,17 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
      */
     public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>

+    /**
+     * Returns the index of the maximum value of each row of the input tensor in the given dimension [dim].
+     *
+     * If [keepDim] is true, the output tensor is of the same size as
+     * the input except in the dimension [dim], where it is of size 1.
+     * Otherwise, [dim] is squeezed, resulting in the output tensor having one fewer dimension.
+     *
+     * @param dim the dimension to reduce.
+     * @param keepDim whether the output tensor has [dim] retained or not.
+     * @return the index of the maximum value of each row of the input tensor in the given dimension [dim].
+     */
+    public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int>
 }
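A short worked example of this contract, using the double implementation (a minimal sketch; fromArray and the values are assumptions for illustration):

    import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra

    fun main() = BroadcastDoubleTensorAlgebra {
        // 2x3 tensor whose row maxima sit at columns 2 and 0.
        val x = fromArray(
            intArrayOf(2, 3),
            doubleArrayOf(1.0, 2.0, 9.0, 7.0, 5.0, 3.0)
        )
        val idx = x.argMax(1, false)
        // idx: Tensor<Int> of shape [2] holding [2, 0];
        // with keepDim = true it would instead have shape [2, 1].
    }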

View File

@@ -5,8 +5,10 @@
 package space.kscience.kmath.tensors.core

+import space.kscience.kmath.structures.MutableBuffer
 import space.kscience.kmath.nd.as1D
 import space.kscience.kmath.nd.as2D
+import space.kscience.kmath.structures.asMutableBuffer
 import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
 import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
 import space.kscience.kmath.tensors.api.Tensor
@@ -537,11 +539,11 @@ public open class DoubleTensorAlgebra :
     internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
         foldFunction(tensor.toDoubleArray())

-    internal fun Tensor<Double>.foldDim(
-        foldFunction: (DoubleArray) -> Double,
+    internal fun <R> Tensor<Double>.foldDim(
+        foldFunction: (DoubleArray) -> R,
         dim: Int,
         keepDim: Boolean,
-    ): DoubleTensor {
+    ): BufferedTensor<R> {
         check(dim < dimension) { "Dimension $dim out of range $dimension" }
         val resShape = if (keepDim) {
             shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
@@ -549,7 +551,9 @@ public open class DoubleTensorAlgebra :
             shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
         }
         val resNumElements = resShape.reduce(Int::times)
-        val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
+        val init = foldFunction(DoubleArray(1) { 0.0 })
+        val resTensor = BufferedTensor(resShape,
+            MutableList(resNumElements) { init }.asMutableBuffer(), 0)
         for (index in resTensor.linearStructure.indices()) {
             val prefix = index.take(dim).toIntArray()
             val suffix = index.takeLast(dimension - dim - 1).toIntArray()
@@ -557,41 +561,30 @@ public open class DoubleTensorAlgebra :
                 tensor[prefix + intArrayOf(i) + suffix]
             })
         }

         return resTensor
     }

     override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }

     override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.sum() }, dim, keepDim)
+        foldDim({ x -> x.sum() }, dim, keepDim).toDoubleTensor()

     override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }

     override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
+        foldDim({ x -> x.minOrNull()!! }, dim, keepDim).toDoubleTensor()

     override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }

     override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
+        foldDim({ x -> x.maxOrNull()!! }, dim, keepDim).toDoubleTensor()

-    /**
-     * Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
-     *
-     * If [keepDim] is true, the output tensor is of the same size as
-     * input except in the dimension [dim] where it is of size 1.
-     * Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
-     *
-     * @param dim the dimension to reduce.
-     * @param keepDim whether the output tensor has [dim] retained or not.
-     * @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
-     */
-    public fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor =
         foldDim({ x ->
-            x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
-        }, dim, keepDim)
+            x.withIndex().maxByOrNull { it.value }?.index!!
+        }, dim, keepDim).toIntTensor()

     override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
@@ -604,7 +597,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()

     override fun Tensor<Double>.std(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
@@ -619,7 +612,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()

     override fun Tensor<Double>.variance(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
@@ -634,7 +627,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()

     private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
         val n = x.shape[0]

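The generified foldDim is the single workhorse here: it reduces each slice along dim with an arbitrary (DoubleArray) -> R function and materialises a BufferedTensor<R>, so sum folds to Double while argMax folds to Int and finishes with toIntTensor(). A standalone illustration of that pattern (plain Kotlin, not the library's internals):

    // Reduce each row of a 2D array with an arbitrary (DoubleArray) -> R function.
    fun <R> foldRows(data: Array<DoubleArray>, fold: (DoubleArray) -> R): List<R> =
        data.map(fold)

    fun main() {
        val m = arrayOf(
            doubleArrayOf(1.0, 9.0, 3.0),
            doubleArrayOf(7.0, 2.0, 5.0),
        )
        println(foldRows(m) { it.sum() })          // Double fold: [13.0, 14.0]
        println(foldRows(m) { row ->               // Int fold, i.e. argMax: [1, 0]
            row.indices.maxByOrNull { row[it] }!!
        })
    }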
View File

@@ -6,6 +6,7 @@
 package space.kscience.kmath.tensors.core

 import space.kscience.kmath.structures.IntBuffer
+import space.kscience.kmath.tensors.core.internal.array

 /**
  * Default [BufferedTensor] implementation for [Int] values
@@ -14,4 +15,7 @@ public class IntTensor internal constructor(
     shape: IntArray,
     buffer: IntArray,
     offset: Int = 0
-) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
+) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset) {
+    public fun asDouble(): DoubleTensor =
+        DoubleTensor(shape, mutableBuffer.array().map { it.toDouble() }.toDoubleArray(), bufferStart)
+}
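asDouble() is what lets the integer indices from argMax flow back into double-valued pipelines, as in the updated example at the top of this commit. A minimal round-trip sketch (fromArray and the values are assumptions for illustration):

    import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
    import space.kscience.kmath.tensors.core.DoubleTensor

    fun main() = BroadcastDoubleTensorAlgebra {
        val scores = fromArray(intArrayOf(2, 2), doubleArrayOf(0.2, 0.8, 0.6, 0.4))
        // IntTensor indices [[1], [0]] converted back to a DoubleTensor [[1.0], [0.0]].
        val labels: DoubleTensor = scores.argMax(1, true).asDouble()
    }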