fix argmax

This commit is contained in:
Roland Grinis 2021-11-01 17:55:10 +00:00
parent b65197f577
commit a994b8a50c
9 changed files with 102 additions and 92 deletions

View File

@ -277,7 +277,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> { override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
} }

View File

@ -108,8 +108,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap() override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
override fun Tensor<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape) override fun Tensor<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape)
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
ndBase.get().argmax(ndArray, keepDim, dim).wrap() ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure()
override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.mean(keepDim, dim).wrap() ndArray.mean(keepDim, dim).wrap()

View File

@ -324,7 +324,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the index of maximum value of each row of the input tensor in the given dimension [dim]. * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int>
override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right

View File

@ -9,7 +9,6 @@ import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Strides import space.kscience.kmath.nd.Strides
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
/** /**
* Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor] * Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor]

View File

@ -10,6 +10,7 @@ import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.indices import space.kscience.kmath.structures.indices
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
@ -571,11 +572,11 @@ public open class DoubleTensorAlgebra :
internal inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = internal inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.copyArray()) foldFunction(tensor.copyArray())
internal inline fun StructureND<Double>.foldDim( internal inline fun <reified R: Any> StructureND<Double>.foldDim(
foldFunction: (DoubleArray) -> Double, foldFunction: (DoubleArray) -> R,
dim: Int, dim: Int,
keepDim: Boolean, keepDim: Boolean,
): DoubleTensor { ): BufferedTensor<R> {
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
val resShape = if (keepDim) { val resShape = if (keepDim) {
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray() shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
@ -583,37 +584,39 @@ public open class DoubleTensorAlgebra :
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray() shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
} }
val resNumElements = resShape.reduce(Int::times) val resNumElements = resShape.reduce(Int::times)
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0) val init = foldFunction(DoubleArray(1){0.0})
for (index in resTensor.indices.asSequence()) { val resTensor = BufferedTensor(resShape,
MutableBuffer.auto(resNumElements) { init }, 0)
for (index in resTensor.indices) {
val prefix = index.take(dim).toIntArray() val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i -> resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
tensor[prefix + intArrayOf(i) + suffix] tensor[prefix + intArrayOf(i) + suffix]
}) })
} }
return resTensor return resTensor
} }
override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() } override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() }
override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.sum() }, dim, keepDim) foldDim({ x -> x.sum() }, dim, keepDim).toDoubleTensor()
override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! } override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.minOrNull()!! }, dim, keepDim) foldDim({ x -> x.minOrNull()!! }, dim, keepDim).toDoubleTensor()
override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! } override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim) foldDim({ x -> x.maxOrNull()!! }, dim, keepDim).toDoubleTensor()
override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor =
foldDim({ x -> foldDim({ x ->
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble() x.withIndex().maxByOrNull { it.value }?.index!!
}, dim, keepDim) }, dim, keepDim).toIntTensor()
override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements } override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
@ -626,7 +629,7 @@ public open class DoubleTensorAlgebra :
}, },
dim, dim,
keepDim keepDim
) ).toDoubleTensor()
override fun StructureND<Double>.std(): Double = this.fold { arr -> override fun StructureND<Double>.std(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / tensor.numElements
@ -641,7 +644,7 @@ public open class DoubleTensorAlgebra :
}, },
dim, dim,
keepDim keepDim
) ).toDoubleTensor()
override fun StructureND<Double>.variance(): Double = this.fold { arr -> override fun StructureND<Double>.variance(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / tensor.numElements
@ -656,7 +659,7 @@ public open class DoubleTensorAlgebra :
}, },
dim, dim,
keepDim keepDim
) ).toDoubleTensor()
private fun cov(x: DoubleTensor, y: DoubleTensor): Double { private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
val n = x.shape[0] val n = x.shape[0]

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.IntBuffer import space.kscience.kmath.structures.IntBuffer
import space.kscience.kmath.tensors.core.internal.array
/** /**
* Default [BufferedTensor] implementation for [Int] values * Default [BufferedTensor] implementation for [Int] values
@ -14,4 +15,7 @@ public class IntTensor internal constructor(
shape: IntArray, shape: IntArray,
buffer: IntArray, buffer: IntArray,
offset: Int = 0 offset: Int = 0
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset) ) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset){
public fun asDouble() : DoubleTensor =
DoubleTensor(shape, mutableBuffer.array().map{ it.toDouble()}.toDoubleArray(), bufferStart)
}

View File

@ -0,0 +1,74 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.Strides
import kotlin.math.max
/**
 * This [Strides] implementation follows the last dimension first convention:
 * the last axis has stride 1 and every preceding stride is the product of the extents after it.
 * For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
 *
 * @param shape the shape of the tensor.
 */
public class TensorLinearStructure(override val shape: IntArray) : Strides() {

    /** Strides derived from [shape]; recomputed on every access. */
    override val strides: IntArray
        get() = stridesFromShape(shape)

    /** Converts a linear [offset] into a multi-dimensional index with one entry per axis of [shape]. */
    override fun index(offset: Int): IntArray =
        indexFromOffset(offset, strides, shape.size)

    /**
     * Product of all extents in [shape].
     *
     * An empty shape yields `1` (the empty product) instead of throwing, which keeps this
     * property consistent with [stridesFromShape], which also accepts an empty shape.
     */
    override val linearSize: Int
        get() = shape.fold(1, Int::times)

    /** Two structures are equal when they are of the same class and their shapes have equal content. */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other == null || this::class != other::class) return false

        other as TensorLinearStructure

        return shape.contentEquals(other.shape)
    }

    override fun hashCode(): Int = shape.contentHashCode()

    public companion object {

        /**
         * Computes the strides for [shape].
         *
         * The last stride is `1`; each preceding stride is the next stride multiplied by the
         * next extent, where an extent of `0` is treated as `1` so that no stride becomes zero.
         *
         * @param shape extents of each axis.
         * @return an array of the same size as [shape]; empty when [shape] is empty.
         */
        public fun stridesFromShape(shape: IntArray): IntArray {
            val strides = IntArray(shape.size)
            if (strides.isEmpty()) return strides

            strides[strides.lastIndex] = 1
            for (axis in strides.lastIndex downTo 1) {
                strides[axis - 1] = max(1, shape[axis]) * strides[axis]
            }
            return strides
        }

        /**
         * Decomposes a linear [offset] into per-axis indices using [strides].
         *
         * @param offset linear position to decompose.
         * @param strides stride of each axis; must hold at least [nDim] non-zero entries.
         * @param nDim number of axes in the result.
         * @return an array of [nDim] indices, most significant axis first.
         */
        public fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
            var remainder = offset
            // The initializer runs for axis 0, 1, ... in order, so `remainder` is peeled front to back.
            return IntArray(nDim) { axis ->
                val stride = strides[axis]
                val coordinate = remainder / stride
                remainder %= stride
                coordinate
            }
        }
    }
}

View File

@ -1,71 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.Strides
import kotlin.math.max
/**
 * Computes the strides for [shape] with the last axis varying fastest.
 *
 * The last stride is `1`; each preceding stride is the next stride multiplied by the
 * next extent, where an extent of `0` is treated as `1`.
 *
 * @param shape extents of each axis.
 * @return an array of the same size as [shape]; empty when [shape] is empty.
 */
internal fun stridesFromShape(shape: IntArray): IntArray {
    val strides = IntArray(shape.size)
    if (strides.isEmpty()) return strides

    strides[strides.lastIndex] = 1
    for (axis in strides.lastIndex downTo 1) {
        strides[axis - 1] = strides[axis] * max(1, shape[axis])
    }
    return strides
}
/**
 * Decomposes a linear [offset] into per-axis indices using [strides].
 *
 * @param offset linear position to decompose.
 * @param strides stride of each axis; must hold at least [nDim] non-zero entries.
 * @param nDim number of axes in the result.
 * @return an array of [nDim] indices, most significant axis first.
 */
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
    var remainder = offset
    // The initializer runs for axis 0, 1, ... in order, so `remainder` is peeled front to back.
    return IntArray(nDim) { axis ->
        val stride = strides[axis]
        val coordinate = remainder / stride
        remainder %= stride
        coordinate
    }
}
/**
 * This [Strides] implementation follows the last dimension first convention
 * For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
 *
 * @param shape the shape of the tensor.
 */
internal class TensorLinearStructure(override val shape: IntArray) : Strides() {

    /** Strides derived from [shape]; recomputed on every access. */
    override val strides: IntArray
        get() = stridesFromShape(shape)

    /** Converts a linear [offset] into a multi-dimensional index with one entry per axis of [shape]. */
    override fun index(offset: Int): IntArray =
        indexFromOffset(offset, strides, shape.size)

    /**
     * Product of all extents in [shape].
     *
     * An empty shape yields `1` (the empty product) instead of throwing, which keeps this
     * property consistent with [stridesFromShape], which also accepts an empty shape.
     */
    override val linearSize: Int
        get() = shape.fold(1, Int::times)

    /** Two structures are equal when they are of the same class and their shapes have equal content. */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other == null || this::class != other::class) return false

        other as TensorLinearStructure

        return shape.contentEquals(other.shape)
    }

    override fun hashCode(): Int = shape.contentHashCode()
}

View File

@ -12,6 +12,7 @@ import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.BufferedTensor import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.IntTensor import space.kscience.kmath.tensors.core.IntTensor
import space.kscience.kmath.tensors.core.TensorLinearStructure
internal fun BufferedTensor<Int>.asTensor(): IntTensor = internal fun BufferedTensor<Int>.asTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)