fix folding for DoubleTensor

This commit is contained in:
Roland Grinis 2021-07-09 07:36:18 +01:00
parent 280c4e97e2
commit bea6ed4d65

View File

@ -8,7 +8,6 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
@ -539,7 +538,7 @@ public open class DoubleTensorAlgebra :
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.toDoubleArray()) foldFunction(tensor.toDoubleArray())
internal fun <R> Tensor<Double>.foldDim( internal inline fun <reified R: Any> Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> R, foldFunction: (DoubleArray) -> R,
dim: Int, dim: Int,
keepDim: Boolean, keepDim: Boolean,
@ -553,7 +552,9 @@ public open class DoubleTensorAlgebra :
val resNumElements = resShape.reduce(Int::times) val resNumElements = resShape.reduce(Int::times)
val init = foldFunction(DoubleArray(1){0.0}) val init = foldFunction(DoubleArray(1){0.0})
val resTensor = BufferedTensor(resShape, val resTensor = BufferedTensor(resShape,
MutableList(resNumElements) { init }.asMutableBuffer(), 0) MutableBuffer.auto(resNumElements) { init },
//MutableList(resNumElements) { init }.asMutableBuffer(),
0)
for (index in resTensor.linearStructure.indices()) { for (index in resTensor.linearStructure.indices()) {
val prefix = index.take(dim).toIntArray() val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray()
@ -564,7 +565,6 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() } override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor = override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =