forked from kscience/kmath
Fix argmax for tensors
This commit is contained in:
parent: b2b063196d
commit: 773ff10dd1
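The diff below changes argMax across the tensor APIs to return integer tensors (Tensor<Int> / IntTensor) instead of reusing the element type T, adds an IntTensor.asDouble() conversion, and updates call sites accordingly. A minimal usage sketch of the new behavior, assuming the fromArray factory of DoubleTensorAlgebra (fromArray is not part of this diff):

import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() = with(DoubleTensorAlgebra) {
    // 2x3 tensor: rows (1, 5, 2) and (4, 0, 3)
    val t = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 5.0, 2.0, 4.0, 0.0, 3.0))
    // Before this commit argMax returned a DoubleTensor of indices;
    // it now returns an IntTensor, since indices are integral.
    val labels = t.argMax(1, true)
    // Feeding the indices back into double-only APIs needs an explicit step:
    val labelsAsDouble = labels.asDouble()
    println(labelsAsDouble)
}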
@@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
         for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
             train(xBatch, yBatch)
         }
-        println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
+        println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
     }
 }
 
@@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
     val prediction = model.predict(xTest)
 
     // process raw prediction via argMax
-    val predictionLabels = prediction.argMax(1, true)
+    val predictionLabels = prediction.argMax(1, true).asDouble()
 
     // find out accuracy
     val acc = accuracy(yTest, predictionLabels)
@@ -28,6 +28,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
      */
     public fun INDArray.wrap(): Nd4jArrayStructure<T>
+
+    public fun INDArray.wrapInt(): Nd4jArrayStructure<Int>
 
     /**
      * Unwraps to or acquires [INDArray] from [StructureND].
      */
@@ -90,8 +92,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
     public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
     public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
 
-    public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
-        ndBase.get().argmax(ndArray, keepDim, dim).wrap()
+    override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
+        ndBase.get().argmax(ndArray, keepDim, dim).wrapInt()
 
     public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
 
@@ -144,6 +146,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
  */
 public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
     public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
+    public override fun INDArray.wrapInt(): Nd4jArrayStructure<Int> = asIntStructure()
 
     @OptIn(PerformancePitfall::class)
     public override val StructureND<Double>.ndArray: INDArray
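For the ND4J backend, argmax is computed by ND4J itself and the raw INDArray of indices is wrapped through the new wrapInt bridge. A hedged sketch of using it, assuming ND4J's Nd4j.create(double[], int[]) factory and the kmath-nd4j module on the classpath:

import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.nd4j.DoubleNd4jTensorAlgebra

fun main() = with(DoubleNd4jTensorAlgebra) {
    // Wrap a raw INDArray into a kmath structure; wrap()/wrapInt() are the
    // conversion points this hunk extends with an Int-valued variant.
    val nd = Nd4j.create(doubleArrayOf(1.0, 5.0, 2.0, 4.0, 0.0, 3.0), intArrayOf(2, 3))
    val tensor = nd.wrap()
    val indices = tensor.argMax(1, true) // Tensor<Int>, backed by wrapInt()
    println(indices)
}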
@@ -55,6 +55,8 @@ class JNoa {
 
     public static native long viewTensor(long tensorHandle, int[] shape);
 
+    public static native long viewAsTensor(long tensorHandle, long asTensorHandle);
+
     public static native String tensorToString(long tensorHandle);
 
     public static native int getDim(long tensorHandle);
@@ -75,6 +77,10 @@ class JNoa {
 
     public static native int getItemInt(long tensorHandle);
 
+    public static native long getIndex(long tensorHandle, int index);
+
+    public static native long getIndexTensor(long tensorHandle, long indexTensorHandle);
+
     public static native double getDouble(long tensorHandle, int[] index);
 
     public static native float getFloat(long tensorHandle, int[] index);
@@ -175,23 +181,15 @@ class JNoa {
 
     public static native long absTensor(long tensorHandle);
 
-    public static native void absTensorAssign(long tensorHandle);
-
     public static native long transposeTensor(long tensorHandle, int i, int j);
 
-    public static native void transposeTensorAssign(long tensorHandle, int i, int j);
-
     public static native long expTensor(long tensorHandle);
 
-    public static native void expTensorAssign(long tensorHandle);
-
     public static native long logTensor(long tensorHandle);
 
-    public static native void logTensorAssign(long tensorHandle);
-
     public static native long sumTensor(long tensorHandle);
 
-    public static native void sumTensorAssign(long tensorHandle);
+    public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim);
 
     public static native long matmul(long lhs, long rhs);
 
@@ -57,6 +57,55 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
 
     override operator fun Tensor<T>.unaryMinus(): TensorType =
         wrap(JNoa.unaryMinus(this.cast().tensorHandle))
+
+    override infix fun Tensor<T>.dot(other: Tensor<T>): TensorType {
+        return wrap(JNoa.matmul(this.cast().tensorHandle, other.cast().tensorHandle))
+    }
+
+    public infix fun Tensor<T>.dotAssign(other: Tensor<T>): Unit {
+        JNoa.matmulAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+    }
+
+    public infix fun Tensor<T>.dotRightAssign(other: Tensor<T>): Unit {
+        JNoa.matmulRightAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+    }
+
+    override operator fun Tensor<T>.get(i: Int): TensorType =
+        wrap(JNoa.getIndex(this.cast().tensorHandle, i))
+
+    public operator fun Tensor<T>.get(indexTensor: NoaLongTensor): TensorType =
+        wrap(JNoa.getIndexTensor(this.cast().tensorHandle, indexTensor.tensorHandle))
+
+    override fun diagonalEmbedding(
+        diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
+    ): TensorType =
+        wrap(JNoa.diagEmbed(diagonalEntries.cast().tensorHandle, offset, dim1, dim2))
+
+    override fun Tensor<T>.transpose(i: Int, j: Int): TensorType {
+        return wrap(JNoa.transposeTensor(this.cast().tensorHandle, i, j))
+    }
+
+    override fun Tensor<T>.view(shape: IntArray): TensorType {
+        return wrap(JNoa.viewTensor(this.cast().tensorHandle, shape))
+    }
+
+    override fun Tensor<T>.viewAs(other: Tensor<T>): TensorType {
+        return wrap(JNoa.viewAsTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+    }
+
+    public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(this.cast().tensorHandle))
+
+    public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(this.cast().tensorHandle))
+    override fun Tensor<T>.sum(): T = sumAll().item()
+    override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
+        wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
+
+    public fun Tensor<T>.copy(): TensorType =
+        wrap(JNoa.copyTensor(this.cast().tensorHandle))
+
+    public fun Tensor<T>.copyToDevice(device: Device): TensorType =
+        wrap(JNoa.copyToDevice(this.cast().tensorHandle, device.toInt()))
+
 }
 
 public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
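Every operation above follows the same shape: unwrap the receiver to its native handle with cast().tensorHandle, call the JNoa native function, and re-wrap the returned handle. A self-contained illustration of this handle-wrapping pattern with a fake in-memory "native" layer (all names below are illustrative, not part of kmath-noa):

import kotlin.math.abs

// FakeNative stands in for the JNoa JNI class: it hands out opaque long
// handles and performs the actual work behind them.
object FakeNative {
    private val storage = mutableMapOf<Long, DoubleArray>()
    private var nextHandle = 1L
    fun create(data: DoubleArray): Long = nextHandle++.also { storage[it] = data.copyOf() }
    fun absTensor(handle: Long): Long = create(storage.getValue(handle).map { abs(it) }.toDoubleArray())
    fun read(handle: Long): DoubleArray = storage.getValue(handle)
}

// HandleTensor mirrors the NoaTensor idea: it owns nothing but a handle,
// and each operation wraps the handle returned by the "native" call.
class HandleTensor(val handle: Long) {
    fun abs(): HandleTensor = HandleTensor(FakeNative.absTensor(handle))
}

fun main() {
    val t = HandleTensor(FakeNative.create(doubleArrayOf(-1.0, 2.0, -3.0)))
    println(FakeNative.read(t.abs().handle).toList()) // [1.0, 2.0, 3.0]
}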
@@ -143,6 +143,14 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToInt
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewTensor
   (JNIEnv *, jclass, jlong, jintArray);
 
+/*
+ * Class: space_kscience_kmath_noa_JNoa
+ * Method: viewAsTensor
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewAsTensor
+  (JNIEnv *, jclass, jlong, jlong);
+
 /*
  * Class: space_kscience_kmath_noa_JNoa
  * Method: tensorToString
@@ -223,6 +231,22 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getItemLong
 JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getItemInt
   (JNIEnv *, jclass, jlong);
 
+/*
+ * Class: space_kscience_kmath_noa_JNoa
+ * Method: getIndex
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndex
+  (JNIEnv *, jclass, jlong, jint);
+
+/*
+ * Class: space_kscience_kmath_noa_JNoa
+ * Method: getIndexTensor
+ * Signature: (JJ)J
+ */
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndexTensor
+  (JNIEnv *, jclass, jlong, jlong);
+
 /*
  * Class: space_kscience_kmath_noa_JNoa
  * Method: getDouble
@@ -623,14 +647,6 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
   (JNIEnv *, jclass, jlong);
 
-/*
- * Class: space_kscience_kmath_noa_JNoa
- * Method: absTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class: space_kscience_kmath_noa_JNoa
  * Method: transposeTensor
@@ -639,14 +655,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor
   (JNIEnv *, jclass, jlong, jint, jint);
 
-/*
- * Class: space_kscience_kmath_noa_JNoa
- * Method: transposeTensorAssign
- * Signature: (JII)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
-  (JNIEnv *, jclass, jlong, jint, jint);
-
 /*
  * Class: space_kscience_kmath_noa_JNoa
  * Method: expTensor
@@ -655,14 +663,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor
   (JNIEnv *, jclass, jlong);
 
-/*
- * Class: space_kscience_kmath_noa_JNoa
- * Method: expTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class: space_kscience_kmath_noa_JNoa
  * Method: logTensor
@@ -671,14 +671,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
 JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor
   (JNIEnv *, jclass, jlong);
 
-/*
- * Class: space_kscience_kmath_noa_JNoa
- * Method: logTensorAssign
- * Signature: (J)V
- */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_logTensorAssign
-  (JNIEnv *, jclass, jlong);
-
 /*
  * Class: space_kscience_kmath_noa_JNoa
  * Method: sumTensor
@@ -689,11 +681,11 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensor
 
 /*
  * Class: space_kscience_kmath_noa_JNoa
- * Method: sumTensorAssign
- * Signature: (J)V
+ * Method: sumDimTensor
+ * Signature: (JIZ)J
  */
-JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensorAssign
-  (JNIEnv *, jclass, jlong);
+JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumDimTensor
+  (JNIEnv *, jclass, jlong, jint, jboolean);
 
 /*
  * Class: space_kscience_kmath_noa_JNoa
@@ -6,6 +6,7 @@
 package space.kscience.kmath.tensors.api
 
 import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.tensors.core.DoubleTensor
 
 /**
  * Algebra over a ring on [Tensor].
@@ -313,4 +314,17 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
      */
     public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
 
+    /**
+     * Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
+     *
+     * If [keepDim] is true, the output tensor is of the same size as
+     * input except in the dimension [dim] where it is of size 1.
+     * Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
+     *
+     * @param dim the dimension to reduce.
+     * @param keepDim whether the output tensor has [dim] retained or not.
+     * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
+     */
+    public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int>
+
 }
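The keepDim contract documented above mirrors PyTorch's argmax: with keepDim the reduced dimension survives with size 1, without it the dimension is squeezed away. A small worked sketch against DoubleTensorAlgebra (fromArray is assumed, not part of this diff):

import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() = with(DoubleTensorAlgebra) {
    // 2x3 tensor: rows (1, 5, 2) and (4, 0, 3)
    val t = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 5.0, 2.0, 4.0, 0.0, 3.0))
    val kept = t.argMax(1, true)      // shape [2, 1]; indices 1 and 0
    val squeezed = t.argMax(1, false) // shape [2];    indices 1 and 0
    println(kept.shape.toList())      // [2, 1]
    println(squeezed.shape.toList())  // [2]
}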
@@ -5,8 +5,10 @@
 
 package space.kscience.kmath.tensors.core
 
+import space.kscience.kmath.structures.MutableBuffer
 import space.kscience.kmath.nd.as1D
 import space.kscience.kmath.nd.as2D
+import space.kscience.kmath.structures.asMutableBuffer
 import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
 import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
 import space.kscience.kmath.tensors.api.Tensor
@@ -537,11 +539,11 @@ public open class DoubleTensorAlgebra :
     internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
         foldFunction(tensor.toDoubleArray())
 
-    internal fun Tensor<Double>.foldDim(
-        foldFunction: (DoubleArray) -> Double,
+    internal fun <R> Tensor<Double>.foldDim(
+        foldFunction: (DoubleArray) -> R,
         dim: Int,
         keepDim: Boolean,
-    ): DoubleTensor {
+    ): BufferedTensor<R> {
         check(dim < dimension) { "Dimension $dim out of range $dimension" }
         val resShape = if (keepDim) {
             shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
@@ -549,7 +551,9 @@ public open class DoubleTensorAlgebra :
             shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
         }
         val resNumElements = resShape.reduce(Int::times)
-        val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
+        val init = foldFunction(DoubleArray(1){0.0})
+        val resTensor = BufferedTensor(resShape,
+            MutableList(resNumElements) { init }.asMutableBuffer(), 0)
         for (index in resTensor.linearStructure.indices()) {
             val prefix = index.take(dim).toIntArray()
             val suffix = index.takeLast(dimension - dim - 1).toIntArray()
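The point of making foldDim generic in R is that the reduction lambda now dictates the element type of the resulting BufferedTensor<R>, so the same machinery serves sum (R = Double) and argMax (R = Int). A standalone sketch of the idea, simplified to folding rows of a 2-D grid (not the library implementation):

// Standalone sketch: fold each row of a 2-D grid with a lambda returning R.
fun <R> foldRows(rows: List<DoubleArray>, fold: (DoubleArray) -> R): List<R> =
    rows.map(fold)

fun main() {
    val rows = listOf(doubleArrayOf(1.0, 5.0, 2.0), doubleArrayOf(4.0, 0.0, 3.0))
    val sums: List<Double> = foldRows(rows) { it.sum() }           // like sum(dim)
    val argmax: List<Int> = foldRows(rows) { xs ->                 // like argMax(dim)
        xs.withIndex().maxByOrNull { it.value }!!.index
    }
    println(sums)   // [8.0, 7.0]
    println(argmax) // [1, 0]
}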
@@ -557,41 +561,30 @@ public open class DoubleTensorAlgebra :
                 tensor[prefix + intArrayOf(i) + suffix]
             })
         }
 
         return resTensor
     }
 
 
     override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
 
     override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.sum() }, dim, keepDim)
+        foldDim({ x -> x.sum() }, dim, keepDim).toDoubleTensor()
 
     override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
 
     override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
+        foldDim({ x -> x.minOrNull()!! }, dim, keepDim).toDoubleTensor()
 
     override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
 
     override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
-        foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
+        foldDim({ x -> x.maxOrNull()!! }, dim, keepDim).toDoubleTensor()
 
-
-    /**
-     * Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
-     *
-     * If [keepDim] is true, the output tensor is of the same size as
-     * input except in the dimension [dim] where it is of size 1.
-     * Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
-     *
-     * @param dim the dimension to reduce.
-     * @param keepDim whether the output tensor has [dim] retained or not.
-     * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
-     */
-    public fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor =
         foldDim({ x ->
-            x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
-        }, dim, keepDim)
+            x.withIndex().maxByOrNull { it.value }?.index!!
+        }, dim, keepDim).toIntTensor()
 
 
     override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
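With the generic foldDim in place, argMax folds each slice to the Int index of its maximum and materializes the result via toIntTensor(). The per-slice reduction it applies, in isolation:

fun main() {
    val xs = doubleArrayOf(4.0, 0.0, 3.0)
    // Same expression the new argMax body uses for one slice:
    val idx: Int = xs.withIndex().maxByOrNull { it.value }?.index!!
    println(idx) // 0
}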
@@ -604,7 +597,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()
 
     override fun Tensor<Double>.std(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
@@ -619,7 +612,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()
 
     override fun Tensor<Double>.variance(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
@@ -634,7 +627,7 @@ public open class DoubleTensorAlgebra :
             },
             dim,
             keepDim
-        )
+        ).toDoubleTensor()
 
     private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
         val n = x.shape[0]
@@ -6,6 +6,7 @@
 package space.kscience.kmath.tensors.core
 
 import space.kscience.kmath.structures.IntBuffer
+import space.kscience.kmath.tensors.core.internal.array
 
 /**
  * Default [BufferedTensor] implementation for [Int] values
@@ -14,4 +15,7 @@ public class IntTensor internal constructor(
     shape: IntArray,
     buffer: IntArray,
     offset: Int = 0
-) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
+) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset){
+    public fun asDouble() : DoubleTensor =
+        DoubleTensor(shape, mutableBuffer.array().map{ it.toDouble()}.toDoubleArray(), bufferStart)
+}
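asDouble() is the bridge the updated call sites rely on: the integer labels produced by argMax are converted element-wise so they can flow into APIs that still expect a DoubleTensor, such as the accuracy helper in the example at the top of this diff. A hedged usage sketch (fromArray assumed, not part of this diff):

import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() = with(DoubleTensorAlgebra) {
    val logits = fromArray(intArrayOf(2, 3), doubleArrayOf(0.1, 0.7, 0.2, 0.6, 0.3, 0.1))
    val labels = logits.argMax(1, true) // IntTensor of per-row class indices
    val asDouble = labels.asDouble()    // DoubleTensor for double-only code paths
    println(asDouble)
}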