fix sequences + array casting

This commit is contained in:
Andrei Kislitsyn 2021-05-01 13:32:50 +03:00
parent 1b6bd67b90
commit b7cac3a015
2 changed files with 39 additions and 3 deletions

View File

@ -49,6 +49,7 @@ public class DoubleTensor internal constructor(
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
@ -78,4 +79,39 @@ internal val TensorStructure<Int>.tensor: IntTensor
}
public fun TensorStructure<Double>.toDoubleTensor(): DoubleTensor = this.tensor
public fun TensorStructure<Int>.toIntTensor(): IntTensor = this.tensor
public fun TensorStructure<Int>.toIntTensor(): IntTensor = this.tensor
public fun Array<DoubleArray>.toDoubleTensor(): DoubleTensor {
val n = size
check(n > 0) { "An empty array cannot be casted to tensor" }
val m = first().size
check(m > 0) { "Inner arrays must have at least 1 argument" }
check(all { size == m }) { "Inner arrays must be the same size" }
val shape = intArrayOf(n, m)
val buffer = this.flatMap { arr -> arr.map { it } }.toDoubleArray()
return DoubleTensor(shape, buffer, 0)
}
public fun Array<IntArray>.toIntTensor(): IntTensor {
val n = size
check(n > 0) { "An empty array cannot be casted to tensor" }
val m = first().size
check(m > 0) { "Inner arrays must have at least 1 argument" }
check(all { size == m }) { "Inner arrays must be the same size" }
val shape = intArrayOf(n, m)
val buffer = this.flatMap { arr -> arr.map { it } }.toIntArray()
return IntTensor(shape, buffer, 0)
}
public fun DoubleTensor.toDoubleArray(): DoubleArray {
return tensor.mutableBuffer.array().drop(bufferStart).take(numElements).toDoubleArray()
}
public fun IntTensor.toIntArray(): IntArray {
return tensor.mutableBuffer.array().drop(bufferStart).take(numElements).toIntArray()
}

View File

@ -18,7 +18,7 @@ internal fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTensor<T>>
val vectorOffset = shape[n - 1]
val vectorShape = intArrayOf(shape.last())
for (offset in 0 until numElements step vectorOffset) {
val vector = BufferedTensor(vectorShape, mutableBuffer, offset)
val vector = BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset)
yield(vector)
}
}
@ -29,7 +29,7 @@ internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>>
val matrixOffset = shape[n - 1] * shape[n - 2]
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
for (offset in 0 until numElements step matrixOffset) {
val matrix = BufferedTensor(matrixShape, mutableBuffer, offset)
val matrix = BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset)
yield(matrix)
}
}