forked from kscience/kmath
fix folding for DoubleTensor
This commit is contained in:
parent
280c4e97e2
commit
bea6ed4d65
@ -8,7 +8,6 @@ package space.kscience.kmath.tensors.core
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.structures.asMutableBuffer
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
@ -539,7 +538,7 @@ public open class DoubleTensorAlgebra :
|
||||
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
||||
foldFunction(tensor.toDoubleArray())
|
||||
|
||||
internal fun <R> Tensor<Double>.foldDim(
|
||||
internal inline fun <reified R: Any> Tensor<Double>.foldDim(
|
||||
foldFunction: (DoubleArray) -> R,
|
||||
dim: Int,
|
||||
keepDim: Boolean,
|
||||
@ -553,7 +552,9 @@ public open class DoubleTensorAlgebra :
|
||||
val resNumElements = resShape.reduce(Int::times)
|
||||
val init = foldFunction(DoubleArray(1){0.0})
|
||||
val resTensor = BufferedTensor(resShape,
|
||||
MutableList(resNumElements) { init }.asMutableBuffer(), 0)
|
||||
MutableBuffer.auto(resNumElements) { init },
|
||||
//MutableList(resNumElements) { init }.asMutableBuffer(),
|
||||
0)
|
||||
for (index in resTensor.linearStructure.indices()) {
|
||||
val prefix = index.take(dim).toIntArray()
|
||||
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
||||
@ -564,7 +565,6 @@ public open class DoubleTensorAlgebra :
|
||||
return resTensor
|
||||
}
|
||||
|
||||
|
||||
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
|
||||
|
||||
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
|
Loading…
Reference in New Issue
Block a user