forked from kscience/kmath
Fix argmax for tensors
This commit is contained in:
parent
b2b063196d
commit
773ff10dd1
@ -163,7 +163,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
|
|||||||
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
|
for ((xBatch, yBatch) in iterBatch(xTrain, yTrain)) {
|
||||||
train(xBatch, yBatch)
|
train(xBatch, yBatch)
|
||||||
}
|
}
|
||||||
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true))}")
|
println("Accuracy:${accuracy(yTrain, predict(xTrain).argMax(1, true).asDouble())}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,7 +230,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
|
|||||||
val prediction = model.predict(xTest)
|
val prediction = model.predict(xTest)
|
||||||
|
|
||||||
// process raw prediction via argMax
|
// process raw prediction via argMax
|
||||||
val predictionLabels = prediction.argMax(1, true)
|
val predictionLabels = prediction.argMax(1, true).asDouble()
|
||||||
|
|
||||||
// find out accuracy
|
// find out accuracy
|
||||||
val acc = accuracy(yTest, predictionLabels)
|
val acc = accuracy(yTest, predictionLabels)
|
||||||
|
@ -28,6 +28,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
|||||||
*/
|
*/
|
||||||
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||||
|
|
||||||
|
public fun INDArray.wrapInt(): Nd4jArrayStructure<Int>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unwraps to or acquires [INDArray] from [StructureND].
|
* Unwraps to or acquires [INDArray] from [StructureND].
|
||||||
*/
|
*/
|
||||||
@ -90,8 +92,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
|||||||
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
|
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
|
||||||
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
||||||
|
|
||||||
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
|
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
|
||||||
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
ndBase.get().argmax(ndArray, keepDim, dim).wrapInt()
|
||||||
|
|
||||||
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
||||||
|
|
||||||
@ -144,6 +146,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
|||||||
*/
|
*/
|
||||||
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
|
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
|
||||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
||||||
|
public override fun INDArray.wrapInt(): Nd4jArrayStructure<Int> = asIntStructure()
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
public override val StructureND<Double>.ndArray: INDArray
|
public override val StructureND<Double>.ndArray: INDArray
|
||||||
|
@ -55,6 +55,8 @@ class JNoa {
|
|||||||
|
|
||||||
public static native long viewTensor(long tensorHandle, int[] shape);
|
public static native long viewTensor(long tensorHandle, int[] shape);
|
||||||
|
|
||||||
|
public static native long viewAsTensor(long tensorHandle, long asTensorHandle);
|
||||||
|
|
||||||
public static native String tensorToString(long tensorHandle);
|
public static native String tensorToString(long tensorHandle);
|
||||||
|
|
||||||
public static native int getDim(long tensorHandle);
|
public static native int getDim(long tensorHandle);
|
||||||
@ -75,6 +77,10 @@ class JNoa {
|
|||||||
|
|
||||||
public static native int getItemInt(long tensorHandle);
|
public static native int getItemInt(long tensorHandle);
|
||||||
|
|
||||||
|
public static native long getIndex(long tensorHandle, int index);
|
||||||
|
|
||||||
|
public static native long getIndexTensor(long tensorHandle, long indexTensorHandle);
|
||||||
|
|
||||||
public static native double getDouble(long tensorHandle, int[] index);
|
public static native double getDouble(long tensorHandle, int[] index);
|
||||||
|
|
||||||
public static native float getFloat(long tensorHandle, int[] index);
|
public static native float getFloat(long tensorHandle, int[] index);
|
||||||
@ -175,23 +181,15 @@ class JNoa {
|
|||||||
|
|
||||||
public static native long absTensor(long tensorHandle);
|
public static native long absTensor(long tensorHandle);
|
||||||
|
|
||||||
public static native void absTensorAssign(long tensorHandle);
|
|
||||||
|
|
||||||
public static native long transposeTensor(long tensorHandle, int i, int j);
|
public static native long transposeTensor(long tensorHandle, int i, int j);
|
||||||
|
|
||||||
public static native void transposeTensorAssign(long tensorHandle, int i, int j);
|
|
||||||
|
|
||||||
public static native long expTensor(long tensorHandle);
|
public static native long expTensor(long tensorHandle);
|
||||||
|
|
||||||
public static native void expTensorAssign(long tensorHandle);
|
|
||||||
|
|
||||||
public static native long logTensor(long tensorHandle);
|
public static native long logTensor(long tensorHandle);
|
||||||
|
|
||||||
public static native void logTensorAssign(long tensorHandle);
|
|
||||||
|
|
||||||
public static native long sumTensor(long tensorHandle);
|
public static native long sumTensor(long tensorHandle);
|
||||||
|
|
||||||
public static native void sumTensorAssign(long tensorHandle);
|
public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim);
|
||||||
|
|
||||||
public static native long matmul(long lhs, long rhs);
|
public static native long matmul(long lhs, long rhs);
|
||||||
|
|
||||||
|
@ -57,6 +57,55 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
|
|||||||
|
|
||||||
override operator fun Tensor<T>.unaryMinus(): TensorType =
|
override operator fun Tensor<T>.unaryMinus(): TensorType =
|
||||||
wrap(JNoa.unaryMinus(this.cast().tensorHandle))
|
wrap(JNoa.unaryMinus(this.cast().tensorHandle))
|
||||||
|
|
||||||
|
override infix fun Tensor<T>.dot(other: Tensor<T>): TensorType {
|
||||||
|
return wrap(JNoa.matmul(this.cast().tensorHandle, other.cast().tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
public infix fun Tensor<T>.dotAssign(other: Tensor<T>): Unit {
|
||||||
|
JNoa.matmulAssign(this.cast().tensorHandle, other.cast().tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public infix fun Tensor<T>.dotRightAssign(other: Tensor<T>): Unit {
|
||||||
|
JNoa.matmulRightAssign(this.cast().tensorHandle, other.cast().tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.get(i: Int): TensorType =
|
||||||
|
wrap(JNoa.getIndex(this.cast().tensorHandle, i))
|
||||||
|
|
||||||
|
public operator fun Tensor<T>.get(indexTensor: NoaLongTensor): TensorType =
|
||||||
|
wrap(JNoa.getIndexTensor(this.cast().tensorHandle, indexTensor.tensorHandle))
|
||||||
|
|
||||||
|
override fun diagonalEmbedding(
|
||||||
|
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
|
||||||
|
): TensorType =
|
||||||
|
wrap(JNoa.diagEmbed(diagonalEntries.cast().tensorHandle, offset, dim1, dim2))
|
||||||
|
|
||||||
|
override fun Tensor<T>.transpose(i: Int, j: Int): TensorType {
|
||||||
|
return wrap(JNoa.transposeTensor(this.cast().tensorHandle, i, j))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.view(shape: IntArray): TensorType {
|
||||||
|
return wrap(JNoa.viewTensor(this.cast().tensorHandle, shape))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.viewAs(other: Tensor<T>): TensorType {
|
||||||
|
return wrap(JNoa.viewAsTensor(this.cast().tensorHandle, other.cast().tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(this.cast().tensorHandle))
|
||||||
|
|
||||||
|
public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(this.cast().tensorHandle))
|
||||||
|
override fun Tensor<T>.sum(): T = sumAll().item()
|
||||||
|
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
|
||||||
|
wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
|
||||||
|
|
||||||
|
public fun Tensor<T>.copy(): TensorType =
|
||||||
|
wrap(JNoa.copyTensor(this.cast().tensorHandle))
|
||||||
|
|
||||||
|
public fun Tensor<T>.copyToDevice(device: Device): TensorType =
|
||||||
|
wrap(JNoa.copyToDevice(this.cast().tensorHandle, device.toInt()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
||||||
|
@ -143,6 +143,14 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_copyToInt
|
|||||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewTensor
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewTensor
|
||||||
(JNIEnv *, jclass, jlong, jintArray);
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: viewAsTensor
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_viewAsTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: tensorToString
|
* Method: tensorToString
|
||||||
@ -223,6 +231,22 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getItemLong
|
|||||||
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getItemInt
|
JNIEXPORT jint JNICALL Java_space_kscience_kmath_noa_JNoa_getItemInt
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: getIndex
|
||||||
|
* Signature: (JI)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndex
|
||||||
|
(JNIEnv *, jclass, jlong, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: getIndexTensor
|
||||||
|
* Signature: (JJ)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getIndexTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: getDouble
|
* Method: getDouble
|
||||||
@ -623,14 +647,6 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus
|
|||||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
|
||||||
* Method: absTensorAssign
|
|
||||||
* Signature: (J)V
|
|
||||||
*/
|
|
||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
|
|
||||||
(JNIEnv *, jclass, jlong);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: transposeTensor
|
* Method: transposeTensor
|
||||||
@ -639,14 +655,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_absTensorAssign
|
|||||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor
|
||||||
(JNIEnv *, jclass, jlong, jint, jint);
|
(JNIEnv *, jclass, jlong, jint, jint);
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
|
||||||
* Method: transposeTensorAssign
|
|
||||||
* Signature: (JII)V
|
|
||||||
*/
|
|
||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
|
|
||||||
(JNIEnv *, jclass, jlong, jint, jint);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: expTensor
|
* Method: expTensor
|
||||||
@ -655,14 +663,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensorAssign
|
|||||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_expTensor
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
|
||||||
* Method: expTensorAssign
|
|
||||||
* Signature: (J)V
|
|
||||||
*/
|
|
||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
|
|
||||||
(JNIEnv *, jclass, jlong);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: logTensor
|
* Method: logTensor
|
||||||
@ -671,14 +671,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_expTensorAssign
|
|||||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_logTensor
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
/*
|
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
|
||||||
* Method: logTensorAssign
|
|
||||||
* Signature: (J)V
|
|
||||||
*/
|
|
||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_logTensorAssign
|
|
||||||
(JNIEnv *, jclass, jlong);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: sumTensor
|
* Method: sumTensor
|
||||||
@ -689,11 +681,11 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensor
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: sumTensorAssign
|
* Method: sumDimTensor
|
||||||
* Signature: (J)V
|
* Signature: (JIZ)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_sumTensorAssign
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_sumDimTensor
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong, jint, jboolean);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.tensors.api
|
package space.kscience.kmath.tensors.api
|
||||||
|
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Algebra over a ring on [Tensor].
|
* Algebra over a ring on [Tensor].
|
||||||
@ -313,4 +314,17 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
*/
|
*/
|
||||||
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||||
|
*
|
||||||
|
* If [keepDim] is true, the output tensor is of the same size as
|
||||||
|
* input except in the dimension [dim] where it is of size 1.
|
||||||
|
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||||
|
*
|
||||||
|
* @param dim the dimension to reduce.
|
||||||
|
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||||
|
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||||
|
*/
|
||||||
|
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int>
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -5,8 +5,10 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
|
import space.kscience.kmath.structures.asMutableBuffer
|
||||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
@ -537,11 +539,11 @@ public open class DoubleTensorAlgebra :
|
|||||||
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
||||||
foldFunction(tensor.toDoubleArray())
|
foldFunction(tensor.toDoubleArray())
|
||||||
|
|
||||||
internal fun Tensor<Double>.foldDim(
|
internal fun <R> Tensor<Double>.foldDim(
|
||||||
foldFunction: (DoubleArray) -> Double,
|
foldFunction: (DoubleArray) -> R,
|
||||||
dim: Int,
|
dim: Int,
|
||||||
keepDim: Boolean,
|
keepDim: Boolean,
|
||||||
): DoubleTensor {
|
): BufferedTensor<R> {
|
||||||
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
||||||
val resShape = if (keepDim) {
|
val resShape = if (keepDim) {
|
||||||
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
|
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
|
||||||
@ -549,7 +551,9 @@ public open class DoubleTensorAlgebra :
|
|||||||
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
|
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
|
||||||
}
|
}
|
||||||
val resNumElements = resShape.reduce(Int::times)
|
val resNumElements = resShape.reduce(Int::times)
|
||||||
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
|
val init = foldFunction(DoubleArray(1){0.0})
|
||||||
|
val resTensor = BufferedTensor(resShape,
|
||||||
|
MutableList(resNumElements) { init }.asMutableBuffer(), 0)
|
||||||
for (index in resTensor.linearStructure.indices()) {
|
for (index in resTensor.linearStructure.indices()) {
|
||||||
val prefix = index.take(dim).toIntArray()
|
val prefix = index.take(dim).toIntArray()
|
||||||
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
||||||
@ -557,41 +561,30 @@ public open class DoubleTensorAlgebra :
|
|||||||
tensor[prefix + intArrayOf(i) + suffix]
|
tensor[prefix + intArrayOf(i) + suffix]
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
|
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
|
||||||
|
|
||||||
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x -> x.sum() }, dim, keepDim)
|
foldDim({ x -> x.sum() }, dim, keepDim).toDoubleTensor()
|
||||||
|
|
||||||
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
|
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
|
||||||
|
|
||||||
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
|
foldDim({ x -> x.minOrNull()!! }, dim, keepDim).toDoubleTensor()
|
||||||
|
|
||||||
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
|
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
|
||||||
|
|
||||||
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim).toDoubleTensor()
|
||||||
|
|
||||||
|
|
||||||
/**
|
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor =
|
||||||
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
|
|
||||||
*
|
|
||||||
* If [keepDim] is true, the output tensor is of the same size as
|
|
||||||
* input except in the dimension [dim] where it is of size 1.
|
|
||||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
|
||||||
*
|
|
||||||
* @param dim the dimension to reduce.
|
|
||||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
|
||||||
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
|
|
||||||
*/
|
|
||||||
public fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
|
|
||||||
foldDim({ x ->
|
foldDim({ x ->
|
||||||
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
|
x.withIndex().maxByOrNull { it.value }?.index!!
|
||||||
}, dim, keepDim)
|
}, dim, keepDim).toIntTensor()
|
||||||
|
|
||||||
|
|
||||||
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
|
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
|
||||||
@ -604,7 +597,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
},
|
},
|
||||||
dim,
|
dim,
|
||||||
keepDim
|
keepDim
|
||||||
)
|
).toDoubleTensor()
|
||||||
|
|
||||||
override fun Tensor<Double>.std(): Double = this.fold { arr ->
|
override fun Tensor<Double>.std(): Double = this.fold { arr ->
|
||||||
val mean = arr.sum() / tensor.numElements
|
val mean = arr.sum() / tensor.numElements
|
||||||
@ -619,7 +612,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
},
|
},
|
||||||
dim,
|
dim,
|
||||||
keepDim
|
keepDim
|
||||||
)
|
).toDoubleTensor()
|
||||||
|
|
||||||
override fun Tensor<Double>.variance(): Double = this.fold { arr ->
|
override fun Tensor<Double>.variance(): Double = this.fold { arr ->
|
||||||
val mean = arr.sum() / tensor.numElements
|
val mean = arr.sum() / tensor.numElements
|
||||||
@ -634,7 +627,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
},
|
},
|
||||||
dim,
|
dim,
|
||||||
keepDim
|
keepDim
|
||||||
)
|
).toDoubleTensor()
|
||||||
|
|
||||||
private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
|
private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
|
||||||
val n = x.shape[0]
|
val n = x.shape[0]
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.structures.IntBuffer
|
import space.kscience.kmath.structures.IntBuffer
|
||||||
|
import space.kscience.kmath.tensors.core.internal.array
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default [BufferedTensor] implementation for [Int] values
|
* Default [BufferedTensor] implementation for [Int] values
|
||||||
@ -14,4 +15,7 @@ public class IntTensor internal constructor(
|
|||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray,
|
buffer: IntArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
||||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset){
|
||||||
|
public fun asDouble() : DoubleTensor =
|
||||||
|
DoubleTensor(shape, mutableBuffer.array().map{ it.toDouble()}.toDoubleArray(), bufferStart)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user