fix argmax
This commit is contained in:
parent
b65197f577
commit
a994b8a50c
@ -277,7 +277,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> {
|
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -108,8 +108,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
|||||||
override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
|
override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
|
||||||
override fun Tensor<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape)
|
override fun Tensor<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape)
|
||||||
|
|
||||||
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
|
||||||
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure()
|
||||||
|
|
||||||
override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
ndArray.mean(keepDim, dim).wrap()
|
ndArray.mean(keepDim, dim).wrap()
|
||||||
|
@ -324,7 +324,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
|
|||||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||||
* @return the index of maximum value of each row of the input tensor in the given dimension [dim].
|
* @return the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||||
*/
|
*/
|
||||||
public fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
|
public fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int>
|
||||||
|
|
||||||
override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right
|
override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right
|
||||||
|
|
||||||
|
@ -9,7 +9,6 @@ import space.kscience.kmath.misc.PerformancePitfall
|
|||||||
import space.kscience.kmath.nd.Strides
|
import space.kscience.kmath.nd.Strides
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.core.internal.TensorLinearStructure
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor]
|
* Represents [Tensor] over a [MutableBuffer] intended to be used through [DoubleTensor] and [IntTensor]
|
||||||
|
@ -10,6 +10,7 @@ import space.kscience.kmath.nd.StructureND
|
|||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.structures.indices
|
import space.kscience.kmath.structures.indices
|
||||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||||
@ -571,11 +572,11 @@ public open class DoubleTensorAlgebra :
|
|||||||
internal inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
internal inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
|
||||||
foldFunction(tensor.copyArray())
|
foldFunction(tensor.copyArray())
|
||||||
|
|
||||||
internal inline fun StructureND<Double>.foldDim(
|
internal inline fun <reified R: Any> StructureND<Double>.foldDim(
|
||||||
foldFunction: (DoubleArray) -> Double,
|
foldFunction: (DoubleArray) -> R,
|
||||||
dim: Int,
|
dim: Int,
|
||||||
keepDim: Boolean,
|
keepDim: Boolean,
|
||||||
): DoubleTensor {
|
): BufferedTensor<R> {
|
||||||
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
||||||
val resShape = if (keepDim) {
|
val resShape = if (keepDim) {
|
||||||
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
|
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
|
||||||
@ -583,37 +584,39 @@ public open class DoubleTensorAlgebra :
|
|||||||
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
|
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
|
||||||
}
|
}
|
||||||
val resNumElements = resShape.reduce(Int::times)
|
val resNumElements = resShape.reduce(Int::times)
|
||||||
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
|
val init = foldFunction(DoubleArray(1){0.0})
|
||||||
for (index in resTensor.indices.asSequence()) {
|
val resTensor = BufferedTensor(resShape,
|
||||||
|
MutableBuffer.auto(resNumElements) { init }, 0)
|
||||||
|
for (index in resTensor.indices) {
|
||||||
val prefix = index.take(dim).toIntArray()
|
val prefix = index.take(dim).toIntArray()
|
||||||
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
|
||||||
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
|
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
|
||||||
tensor[prefix + intArrayOf(i) + suffix]
|
tensor[prefix + intArrayOf(i) + suffix]
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() }
|
override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() }
|
||||||
|
|
||||||
override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x -> x.sum() }, dim, keepDim)
|
foldDim({ x -> x.sum() }, dim, keepDim).toDoubleTensor()
|
||||||
|
|
||||||
override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! }
|
override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! }
|
||||||
|
|
||||||
override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
|
foldDim({ x -> x.minOrNull()!! }, dim, keepDim).toDoubleTensor()
|
||||||
|
|
||||||
override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
|
override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
|
||||||
|
|
||||||
override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim).toDoubleTensor()
|
||||||
|
|
||||||
override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
|
|
||||||
|
override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor =
|
||||||
foldDim({ x ->
|
foldDim({ x ->
|
||||||
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
|
x.withIndex().maxByOrNull { it.value }?.index!!
|
||||||
}, dim, keepDim)
|
}, dim, keepDim).toIntTensor()
|
||||||
|
|
||||||
|
|
||||||
override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
|
override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
|
||||||
@ -626,7 +629,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
},
|
},
|
||||||
dim,
|
dim,
|
||||||
keepDim
|
keepDim
|
||||||
)
|
).toDoubleTensor()
|
||||||
|
|
||||||
override fun StructureND<Double>.std(): Double = this.fold { arr ->
|
override fun StructureND<Double>.std(): Double = this.fold { arr ->
|
||||||
val mean = arr.sum() / tensor.numElements
|
val mean = arr.sum() / tensor.numElements
|
||||||
@ -641,7 +644,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
},
|
},
|
||||||
dim,
|
dim,
|
||||||
keepDim
|
keepDim
|
||||||
)
|
).toDoubleTensor()
|
||||||
|
|
||||||
override fun StructureND<Double>.variance(): Double = this.fold { arr ->
|
override fun StructureND<Double>.variance(): Double = this.fold { arr ->
|
||||||
val mean = arr.sum() / tensor.numElements
|
val mean = arr.sum() / tensor.numElements
|
||||||
@ -656,7 +659,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
},
|
},
|
||||||
dim,
|
dim,
|
||||||
keepDim
|
keepDim
|
||||||
)
|
).toDoubleTensor()
|
||||||
|
|
||||||
private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
|
private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
|
||||||
val n = x.shape[0]
|
val n = x.shape[0]
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.structures.IntBuffer
|
import space.kscience.kmath.structures.IntBuffer
|
||||||
|
import space.kscience.kmath.tensors.core.internal.array
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default [BufferedTensor] implementation for [Int] values
|
* Default [BufferedTensor] implementation for [Int] values
|
||||||
@ -14,4 +15,7 @@ public class IntTensor internal constructor(
|
|||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray,
|
buffer: IntArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
||||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset){
|
||||||
|
public fun asDouble() : DoubleTensor =
|
||||||
|
DoubleTensor(shape, mutableBuffer.array().map{ it.toDouble()}.toDoubleArray(), bufferStart)
|
||||||
|
}
|
||||||
|
@ -0,0 +1,74 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.Strides
|
||||||
|
import kotlin.math.max
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This [Strides] implementation follows the last dimension first convention
|
||||||
|
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
||||||
|
*
|
||||||
|
* @param shape the shape of the tensor.
|
||||||
|
*/
|
||||||
|
public class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
||||||
|
override val strides: IntArray
|
||||||
|
get() = stridesFromShape(shape)
|
||||||
|
|
||||||
|
override fun index(offset: Int): IntArray =
|
||||||
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
|
override val linearSize: Int
|
||||||
|
get() = shape.reduce(Int::times)
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
|
other as TensorLinearStructure
|
||||||
|
|
||||||
|
if (!shape.contentEquals(other.shape)) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
return shape.contentHashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
|
||||||
|
public fun stridesFromShape(shape: IntArray): IntArray {
|
||||||
|
val nDim = shape.size
|
||||||
|
val res = IntArray(nDim)
|
||||||
|
if (nDim == 0)
|
||||||
|
return res
|
||||||
|
|
||||||
|
var current = nDim - 1
|
||||||
|
res[current] = 1
|
||||||
|
|
||||||
|
while (current > 0) {
|
||||||
|
res[current - 1] = max(1, shape[current]) * res[current]
|
||||||
|
current--
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||||
|
val res = IntArray(nDim)
|
||||||
|
var current = offset
|
||||||
|
var strideIndex = 0
|
||||||
|
|
||||||
|
while (strideIndex < nDim) {
|
||||||
|
res[strideIndex] = (current / strides[strideIndex])
|
||||||
|
current %= strides[strideIndex]
|
||||||
|
strideIndex++
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,71 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.tensors.core.internal
|
|
||||||
|
|
||||||
import space.kscience.kmath.nd.Strides
|
|
||||||
import kotlin.math.max
|
|
||||||
|
|
||||||
|
|
||||||
internal fun stridesFromShape(shape: IntArray): IntArray {
|
|
||||||
val nDim = shape.size
|
|
||||||
val res = IntArray(nDim)
|
|
||||||
if (nDim == 0)
|
|
||||||
return res
|
|
||||||
|
|
||||||
var current = nDim - 1
|
|
||||||
res[current] = 1
|
|
||||||
|
|
||||||
while (current > 0) {
|
|
||||||
res[current - 1] = max(1, shape[current]) * res[current]
|
|
||||||
current--
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
|
||||||
val res = IntArray(nDim)
|
|
||||||
var current = offset
|
|
||||||
var strideIndex = 0
|
|
||||||
|
|
||||||
while (strideIndex < nDim) {
|
|
||||||
res[strideIndex] = (current / strides[strideIndex])
|
|
||||||
current %= strides[strideIndex]
|
|
||||||
strideIndex++
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This [Strides] implementation follows the last dimension first convention
|
|
||||||
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
|
||||||
*
|
|
||||||
* @param shape the shape of the tensor.
|
|
||||||
*/
|
|
||||||
internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
|
||||||
override val strides: IntArray
|
|
||||||
get() = stridesFromShape(shape)
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray =
|
|
||||||
indexFromOffset(offset, strides, shape.size)
|
|
||||||
|
|
||||||
override val linearSize: Int
|
|
||||||
get() = shape.reduce(Int::times)
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
|
||||||
if (this === other) return true
|
|
||||||
if (other == null || this::class != other::class) return false
|
|
||||||
|
|
||||||
other as TensorLinearStructure
|
|
||||||
|
|
||||||
if (!shape.contentEquals(other.shape)) return false
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
|
||||||
return shape.contentHashCode()
|
|
||||||
}
|
|
||||||
}
|
|
@ -12,6 +12,7 @@ import space.kscience.kmath.tensors.api.Tensor
|
|||||||
import space.kscience.kmath.tensors.core.BufferedTensor
|
import space.kscience.kmath.tensors.core.BufferedTensor
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||||
import space.kscience.kmath.tensors.core.IntTensor
|
import space.kscience.kmath.tensors.core.IntTensor
|
||||||
|
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||||
|
|
||||||
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
|
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
|
||||||
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
|
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
|
||||||
|
Loading…
Reference in New Issue
Block a user