No secondary constructors

This commit is contained in:
Roland Grinis 2021-04-30 14:44:42 +01:00
parent e5e62bc544
commit 6be5caa93f
11 changed files with 103 additions and 109 deletions

View File

@ -8,7 +8,7 @@ import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
public open class BufferedTensor<T>( public open class BufferedTensor<T>(
override val shape: IntArray, override val shape: IntArray,
internal val buffer: MutableBuffer<T>, internal val mutableBuffer: MutableBuffer<T>,
internal val bufferStart: Int internal val bufferStart: Int
) : TensorStructure<T> { ) : TensorStructure<T> {
public val linearStructure: TensorLinearStructure public val linearStructure: TensorLinearStructure
@ -17,10 +17,10 @@ public open class BufferedTensor<T>(
public val numElements: Int public val numElements: Int
get() = linearStructure.size get() = linearStructure.size
override fun get(index: IntArray): T = buffer[bufferStart + linearStructure.offset(index)] override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
buffer[bufferStart + linearStructure.offset(index)] = value mutableBuffer[bufferStart + linearStructure.offset(index)] = value
} }
override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map {
@ -37,33 +37,28 @@ public class IntTensor internal constructor(
shape: IntArray, shape: IntArray,
buffer: IntArray, buffer: IntArray,
offset: Int = 0 offset: Int = 0
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset) { ) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
internal constructor(bufferedTensor: BufferedTensor<Int>) :
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
}
public class DoubleTensor internal constructor( public class DoubleTensor internal constructor(
shape: IntArray, shape: IntArray,
buffer: DoubleArray, buffer: DoubleArray,
offset: Int = 0 offset: Int = 0
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) { ) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) {
internal constructor(bufferedTensor: BufferedTensor<Double>) :
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
override fun toString(): String = toPrettyString() override fun toString(): String = toPrettyString()
} }
internal inline fun BufferedTensor<Int>.asTensor(): IntTensor = IntTensor(this) internal fun BufferedTensor<Int>.asTensor(): IntTensor =
internal inline fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTensor(this) IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal inline fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> = internal fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor( BufferedTensor(
this.shape, this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0 TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
) )
internal inline fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides.toIntArray() contentEquals TensorLinearStructure(this.shape).strides) is MutableBufferND<T> -> if (this.strides.strides.toIntArray() contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()

View File

@ -21,7 +21,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] + newOther.buffer.array()[i] newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
@ -29,8 +29,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
newOther.buffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
@ -39,7 +39,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] - newOther.buffer.array()[i] newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
@ -47,8 +47,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
newOther.buffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
@ -57,8 +57,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newThis.bufferStart + i] * newThis.mutableBuffer.array()[newThis.bufferStart + i] *
newOther.buffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
@ -66,8 +66,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
newOther.buffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
@ -76,8 +76,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newOther.bufferStart + i] / newThis.mutableBuffer.array()[newOther.bufferStart + i] /
newOther.buffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
@ -85,8 +85,8 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) { for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
newOther.buffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
} }

View File

@ -106,7 +106,7 @@ public object DoubleLinearOpsTensorAlgebra :
val matrixSize = matrix.shape.reduce { acc, i -> acc * i } val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
val curMatrix = DoubleTensor( val curMatrix = DoubleTensor(
matrix.shape, matrix.shape,
matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize).toDoubleArray() matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize).toDoubleArray()
) )
svdHelper(curMatrix, USV, m, n, epsilon) svdHelper(curMatrix, USV, m, n, epsilon)
} }

View File

@ -29,7 +29,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
check(tensor.shape contentEquals intArrayOf(1)) { check(tensor.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}" "Inconsistent value for tensor of shape ${shape.toList()}"
} }
return tensor.buffer.array()[tensor.bufferStart] return tensor.mutableBuffer.array()[tensor.bufferStart]
} }
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor { public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
@ -43,7 +43,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
val lastShape = tensor.shape.drop(1).toIntArray() val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
return DoubleTensor(newShape, tensor.buffer.array(), newStart) return DoubleTensor(newShape, tensor.mutableBuffer.array(), newStart)
} }
public fun full(value: Double, shape: IntArray): DoubleTensor { public fun full(value: Double, shape: IntArray): DoubleTensor {
@ -77,12 +77,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
public fun TensorStructure<Double>.copy(): DoubleTensor { public fun TensorStructure<Double>.copy(): DoubleTensor {
return DoubleTensor(tensor.shape, tensor.buffer.array().copyOf(), tensor.bufferStart) return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
} }
override fun Double.plus(other: TensorStructure<Double>): DoubleTensor { override fun Double.plus(other: TensorStructure<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.buffer.array()[other.tensor.bufferStart + i] + this other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
@ -92,35 +92,35 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor { override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[i] + other.tensor.buffer.array()[i] tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.plusAssign(value: Double) { override fun TensorStructure<Double>.plusAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] += value tensor.mutableBuffer.array()[tensor.bufferStart + i] += value
} }
} }
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
other.tensor.buffer.array()[tensor.bufferStart + i] other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Double.minus(other: TensorStructure<Double>): DoubleTensor { override fun Double.minus(other: TensorStructure<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this - other.tensor.buffer.array()[other.tensor.bufferStart + i] this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.minus(value: Double): DoubleTensor { override fun TensorStructure<Double>.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[tensor.bufferStart + i] - value tensor.mutableBuffer.array()[tensor.bufferStart + i] - value
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
@ -128,28 +128,28 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override fun TensorStructure<Double>.minus(other: TensorStructure<Double>): DoubleTensor { override fun TensorStructure<Double>.minus(other: TensorStructure<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[i] - other.tensor.buffer.array()[i] tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.minusAssign(value: Double) { override fun TensorStructure<Double>.minusAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] -= value tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value
} }
} }
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
other.tensor.buffer.array()[tensor.bufferStart + i] other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Double.times(other: TensorStructure<Double>): DoubleTensor { override fun Double.times(other: TensorStructure<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.buffer.array()[other.tensor.bufferStart + i] * this other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
@ -159,36 +159,36 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override fun TensorStructure<Double>.times(other: TensorStructure<Double>): DoubleTensor { override fun TensorStructure<Double>.times(other: TensorStructure<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[tensor.bufferStart + i] * tensor.mutableBuffer.array()[tensor.bufferStart + i] *
other.tensor.buffer.array()[other.tensor.bufferStart + i] other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.timesAssign(value: Double) { override fun TensorStructure<Double>.timesAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] *= value tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value
} }
} }
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
other.tensor.buffer.array()[tensor.bufferStart + i] other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Double.div(other: TensorStructure<Double>): DoubleTensor { override fun Double.div(other: TensorStructure<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this / other.tensor.buffer.array()[other.tensor.bufferStart + i] this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.div(value: Double): DoubleTensor { override fun TensorStructure<Double>.div(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[tensor.bufferStart + i] / value tensor.mutableBuffer.array()[tensor.bufferStart + i] / value
} }
return DoubleTensor(shape, resBuffer) return DoubleTensor(shape, resBuffer)
} }
@ -196,29 +196,29 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor { override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[other.tensor.bufferStart + i] / tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
other.tensor.buffer.array()[other.tensor.bufferStart + i] other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.divAssign(value: Double) { override fun TensorStructure<Double>.divAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] /= value tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value
} }
} }
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) { override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.buffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
other.tensor.buffer.array()[tensor.bufferStart + i] other.tensor.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun TensorStructure<Double>.unaryMinus(): DoubleTensor { override fun TensorStructure<Double>.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[tensor.bufferStart + i].unaryMinus() tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
@ -241,8 +241,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.linearStructure.offset(newMultiIndex) val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
resTensor.buffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + offset] tensor.mutableBuffer.array()[tensor.bufferStart + offset]
} }
return resTensor return resTensor
} }
@ -250,7 +250,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override fun TensorStructure<Double>.view(shape: IntArray): DoubleTensor { override fun TensorStructure<Double>.view(shape: IntArray): DoubleTensor {
checkView(tensor, shape) checkView(tensor, shape)
return DoubleTensor(shape, tensor.buffer.array(), tensor.bufferStart) return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
} }
override fun TensorStructure<Double>.viewAs(other: TensorStructure<Double>): DoubleTensor { override fun TensorStructure<Double>.viewAs(other: TensorStructure<Double>): DoubleTensor {
@ -259,7 +259,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override infix fun TensorStructure<Double>.dot(other: TensorStructure<Double>): DoubleTensor { override infix fun TensorStructure<Double>.dot(other: TensorStructure<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) { if (tensor.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.buffer.array().sum())) return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
} }
var newThis = tensor.copy() var newThis = tensor.copy()
@ -361,7 +361,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
public fun TensorStructure<Double>.map(transform: (Double) -> Double): DoubleTensor { public fun TensorStructure<Double>.map(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor( return DoubleTensor(
tensor.shape, tensor.shape,
tensor.buffer.array().map { transform(it) }.toDoubleArray(), tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
tensor.bufferStart tensor.bufferStart
) )
} }
@ -382,7 +382,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return false return false
} }
for (i in 0 until n) { for (i in 0 until n) {
if (!eqFunction(tensor.buffer[tensor.bufferStart + i], other.tensor.buffer[other.tensor.bufferStart + i])) { if (!eqFunction(tensor.mutableBuffer[tensor.bufferStart + i], other.tensor.mutableBuffer[other.tensor.bufferStart + i])) {
return false return false
} }
} }

View File

@ -18,8 +18,8 @@ internal inline fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: Doub
} }
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex) val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex] tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex]
} }
} }
@ -113,7 +113,7 @@ internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<Do
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf() var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex
val newTensor = DoubleTensor(curMultiIndex + matrixShape, tensor.buffer.array()) val newTensor = DoubleTensor(curMultiIndex + matrixShape, tensor.mutableBuffer.array())
for (i in curMultiIndex.indices) { for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) { if (curMultiIndex[i] != 1) {
@ -133,8 +133,8 @@ internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<Do
matrix.linearStructure.index(i) matrix.linearStructure.index(i)
) )
resTensor.buffer.array()[resTensor.bufferStart + newLinearIndex] = resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] =
newTensor.buffer.array()[newTensor.bufferStart + curLinearIndex] newTensor.mutableBuffer.array()[newTensor.bufferStart + curLinearIndex]
} }
} }
res += resTensor res += resTensor

View File

@ -18,7 +18,7 @@ internal inline fun <T> BufferedTensor<T>.vectorSequence(): Sequence<BufferedTen
val vectorOffset = shape[n - 1] val vectorOffset = shape[n - 1]
val vectorShape = intArrayOf(shape.last()) val vectorShape = intArrayOf(shape.last())
for (offset in 0 until numElements step vectorOffset) { for (offset in 0 until numElements step vectorOffset) {
val vector = BufferedTensor(vectorShape, buffer, offset) val vector = BufferedTensor(vectorShape, mutableBuffer, offset)
yield(vector) yield(vector)
} }
} }
@ -29,7 +29,7 @@ internal inline fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTen
val matrixOffset = shape[n - 1] * shape[n - 2] val matrixOffset = shape[n - 1] * shape[n - 2]
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
for (offset in 0 until numElements step matrixOffset) { for (offset in 0 until numElements step matrixOffset) {
val matrix = BufferedTensor(matrixShape, buffer, offset) val matrix = BufferedTensor(matrixShape, mutableBuffer, offset)
yield(matrix) yield(matrix)
} }
} }
@ -322,16 +322,16 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
} }
val s = res.map { it.first }.toDoubleArray() val s = res.map { it.first }.toDoubleArray()
val uBuffer = res.map { it.second }.flatMap { it.buffer.array().toList() }.toDoubleArray() val uBuffer = res.map { it.second }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
val vBuffer = res.map { it.third }.flatMap { it.buffer.array().toList() }.toDoubleArray() val vBuffer = res.map { it.third }.flatMap { it.mutableBuffer.array().toList() }.toDoubleArray()
for (i in uBuffer.indices) { for (i in uBuffer.indices) {
matrixU.buffer.array()[matrixU.bufferStart + i] = uBuffer[i] matrixU.mutableBuffer.array()[matrixU.bufferStart + i] = uBuffer[i]
} }
for (i in s.indices) { for (i in s.indices) {
matrixS.buffer.array()[matrixS.bufferStart + i] = s[i] matrixS.mutableBuffer.array()[matrixS.bufferStart + i] = s[i]
} }
for (i in vBuffer.indices) { for (i in vBuffer.indices) {
matrixV.buffer.array()[matrixV.bufferStart + i] = vBuffer[i] matrixV.mutableBuffer.array()[matrixV.bufferStart + i] = vBuffer[i]
} }
} }

View File

@ -30,7 +30,7 @@ internal class TestBroadcasting {
val res = broadcastTo(tensor2, tensor1.shape) val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape contentEquals intArrayOf(2, 3)) assertTrue(res.shape contentEquals intArrayOf(2, 3))
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
} }
@Test @Test
@ -45,9 +45,9 @@ internal class TestBroadcasting {
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0)) assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
} }
@Test @Test
@ -62,9 +62,9 @@ internal class TestBroadcasting {
assertTrue(res[1].shape contentEquals intArrayOf(1, 1, 3)) assertTrue(res[1].shape contentEquals intArrayOf(1, 1, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 1, 1)) assertTrue(res[2].shape contentEquals intArrayOf(1, 1, 1))
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[0].mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0)) assertTrue(res[1].mutableBuffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0))
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0)) assertTrue(res[2].mutableBuffer.array() contentEquals doubleArrayOf(500.0))
} }
@Test @Test
@ -91,16 +91,16 @@ internal class TestBroadcasting {
val tensor32 = tensor3 - tensor2 val tensor32 = tensor3 - tensor2
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3)) assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue(tensor21.mutableBuffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3)) assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3))
assertTrue( assertTrue(
tensor31.buffer.array() tensor31.mutableBuffer.array()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0) contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
) )
assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3)) assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3))
assertTrue(tensor32.buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) assertTrue(tensor32.mutableBuffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
} }
} }

View File

@ -30,7 +30,7 @@ internal class TestDoubleAnalyticTensorAlgebra {
fun testExp() = DoubleAnalyticTensorAlgebra.invoke { fun testExp() = DoubleAnalyticTensorAlgebra.invoke {
tensor.exp().let { tensor.exp().let {
assertTrue { shape contentEquals it.shape } assertTrue { shape contentEquals it.shape }
assertTrue { buffer.fmap(::exp).epsEqual(it.buffer.array())} assertTrue { buffer.fmap(::exp).epsEqual(it.mutableBuffer.array())}
} }
} }
} }

View File

@ -151,8 +151,8 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val res = svd1d(tensor2) val res = svd1d(tensor2)
assertTrue(res.shape contentEquals intArrayOf(2)) assertTrue(res.shape contentEquals intArrayOf(2))
assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01 } assertTrue { abs(abs(res.mutableBuffer.array()[res.bufferStart]) - 0.386) < 0.01 }
assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01 } assertTrue { abs(abs(res.mutableBuffer.array()[res.bufferStart + 1]) - 0.922) < 0.01 }
} }
@Test @Test

View File

@ -6,7 +6,6 @@ import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.structures.toDoubleArray
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
import kotlin.test.Test import kotlin.test.Test
@ -27,7 +26,7 @@ internal class TestDoubleTensor {
val tensor = fromArray(intArrayOf(2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4)) val tensor = fromArray(intArrayOf(2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4))
assertEquals(tensor[intArrayOf(0, 1)], 5.8) assertEquals(tensor[intArrayOf(0, 1)], 5.8)
assertTrue( assertTrue(
tensor.elements().map { it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray() tensor.elements().map { it.second }.toList().toDoubleArray() contentEquals tensor.mutableBuffer.toDoubleArray()
) )
} }
@ -71,7 +70,7 @@ internal class TestDoubleTensor {
val tensorArrayPublic = ndArray.toTypedTensor() // public API, data copied twice val tensorArrayPublic = ndArray.toTypedTensor() // public API, data copied twice
val sharedTensorArray = tensorArrayPublic.toTypedTensor() // no data copied by matching type val sharedTensorArray = tensorArrayPublic.toTypedTensor() // no data copied by matching type
assertTrue(tensorArray.buffer.array() contentEquals sharedTensorArray.buffer.array()) assertTrue(tensorArray.mutableBuffer.array() contentEquals sharedTensorArray.mutableBuffer.array())
tensorArray[intArrayOf(0)] = 55.9 tensorArray[intArrayOf(0)] = 55.9
assertEquals(tensorArrayPublic[intArrayOf(0)], 1.0) assertEquals(tensorArrayPublic[intArrayOf(0)], 1.0)

View File

@ -13,21 +13,21 @@ internal class TestDoubleTensorAlgebra {
fun doublePlus() = DoubleTensorAlgebra.invoke { fun doublePlus() = DoubleTensorAlgebra.invoke {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0)) val tensor = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor val res = 10.0 + tensor
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0)) assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(11.0, 12.0))
} }
@Test @Test
fun doubleDiv() = DoubleTensorAlgebra.invoke { fun doubleDiv() = DoubleTensorAlgebra.invoke {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0)) val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0))
val res = 2.0/tensor val res = 2.0/tensor
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 0.5)) assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 0.5))
} }
@Test @Test
fun divDouble() = DoubleTensorAlgebra.invoke { fun divDouble() = DoubleTensorAlgebra.invoke {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(10.0, 5.0)) val tensor = fromArray(intArrayOf(2), doubleArrayOf(10.0, 5.0))
val res = tensor / 2.5 val res = tensor / 2.5
assertTrue(res.buffer.array() contentEquals doubleArrayOf(4.0, 2.0)) assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(4.0, 2.0))
} }
@Test @Test
@ -35,7 +35,7 @@ internal class TestDoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0)) val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0) val res = tensor.transpose(0, 0)
assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0)) assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals intArrayOf(1)) assertTrue(res.shape contentEquals intArrayOf(1))
} }
@ -44,7 +44,7 @@ internal class TestDoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0) val res = tensor.transpose(1, 0)
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals intArrayOf(2, 3)) assertTrue(res.shape contentEquals intArrayOf(2, 3))
} }
@ -59,9 +59,9 @@ internal class TestDoubleTensorAlgebra {
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1)) assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2)) assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res01.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res02.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
} }
@Test @Test
@ -92,8 +92,8 @@ internal class TestDoubleTensorAlgebra {
assignResult += tensorC assignResult += tensorC
assignResult += -39.4 assignResult += -39.4
assertTrue(expected.buffer.array() contentEquals result.buffer.array()) assertTrue(expected.mutableBuffer.array() contentEquals result.mutableBuffer.array())
assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array()) assertTrue(expected.mutableBuffer.array() contentEquals assignResult.mutableBuffer.array())
} }
@Test @Test
@ -104,19 +104,19 @@ internal class TestDoubleTensorAlgebra {
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0)) val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
val res12 = tensor1.dot(tensor2) val res12 = tensor1.dot(tensor2)
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(140.0, 320.0)) assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
assertTrue(res12.shape contentEquals intArrayOf(2)) assertTrue(res12.shape contentEquals intArrayOf(2))
val res32 = tensor3.dot(tensor2) val res32 = tensor3.dot(tensor2)
assertTrue(res32.buffer.array() contentEquals doubleArrayOf(-140.0)) assertTrue(res32.mutableBuffer.array() contentEquals doubleArrayOf(-140.0))
assertTrue(res32.shape contentEquals intArrayOf(1, 1)) assertTrue(res32.shape contentEquals intArrayOf(1, 1))
val res22 = tensor2.dot(tensor2) val res22 = tensor2.dot(tensor2)
assertTrue(res22.buffer.array() contentEquals doubleArrayOf(1400.0)) assertTrue(res22.mutableBuffer.array() contentEquals doubleArrayOf(1400.0))
assertTrue(res22.shape contentEquals intArrayOf(1)) assertTrue(res22.shape contentEquals intArrayOf(1))
val res11 = tensor1.dot(tensor11) val res11 = tensor1.dot(tensor11)
assertTrue(res11.buffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0)) assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertTrue(res11.shape contentEquals intArrayOf(2, 2)) assertTrue(res11.shape contentEquals intArrayOf(2, 2))
var tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0}) var tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0})
@ -147,17 +147,17 @@ internal class TestDoubleTensorAlgebra {
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0) val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
assertTrue(diagonal1.shape contentEquals intArrayOf(3, 3)) assertTrue(diagonal1.shape contentEquals intArrayOf(3, 3))
assertTrue(diagonal1.buffer.array() contentEquals assertTrue(diagonal1.mutableBuffer.array() contentEquals
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)) doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0))
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0) val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
assertTrue(diagonal1Offset.shape contentEquals intArrayOf(4, 4)) assertTrue(diagonal1Offset.shape contentEquals intArrayOf(4, 4))
assertTrue(diagonal1Offset.buffer.array() contentEquals assertTrue(diagonal1Offset.mutableBuffer.array() contentEquals
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)) doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0))
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2) val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
assertTrue(diagonal2.shape contentEquals intArrayOf(4, 2, 4)) assertTrue(diagonal2.shape contentEquals intArrayOf(4, 2, 4))
assertTrue(diagonal2.buffer.array() contentEquals assertTrue(diagonal2.mutableBuffer.array() contentEquals
doubleArrayOf( doubleArrayOf(
0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0,
0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0,