fix folding for DoubleTensor
This commit is contained in:
parent
280c4e97e2
commit
bea6ed4d65
@ -8,7 +8,6 @@ package space.kscience.kmath.tensors.core
|
|||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.structures.asMutableBuffer
|
|
||||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
@ -539,7 +538,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
internal fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
||||||
foldFunction(tensor.toDoubleArray())
|
foldFunction(tensor.toDoubleArray())
|
||||||
|
|
||||||
internal fun <R> Tensor<Double>.foldDim(
|
internal inline fun <reified R: Any> Tensor<Double>.foldDim(
|
||||||
foldFunction: (DoubleArray) -> R,
|
foldFunction: (DoubleArray) -> R,
|
||||||
dim: Int,
|
dim: Int,
|
||||||
keepDim: Boolean,
|
keepDim: Boolean,
|
||||||
@ -553,7 +552,9 @@ public open class DoubleTensorAlgebra :
|
|||||||
val resNumElements = resShape.reduce(Int::times)
|
val resNumElements = resShape.reduce(Int::times)
|
||||||
val init = foldFunction(DoubleArray(1){0.0})
|
val init = foldFunction(DoubleArray(1){0.0})
|
||||||
val resTensor = BufferedTensor(resShape,
|
val resTensor = BufferedTensor(resShape,
|
||||||
MutableList(resNumElements) { init }.asMutableBuffer(), 0)
|
MutableBuffer.auto(resNumElements) { init },
|
||||||
|
//MutableList(resNumElements) { init }.asMutableBuffer(),
|
||||||
|
0)
|
||||||
for (index in resTensor.linearStructure.indices()) {
|
for (index in resTensor.linearStructure.indices()) {
|
||||||
val prefix = index.take(dim).toIntArray()
|
val prefix = index.take(dim).toIntArray()
|
||||||
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
||||||
@ -564,7 +565,6 @@ public open class DoubleTensorAlgebra :
|
|||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
|
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
|
||||||
|
|
||||||
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
|
Loading…
Reference in New Issue
Block a user