rework structure + fixes

This commit is contained in:
Andrei Kislitsyn 2021-04-21 23:44:39 +03:00
parent cc11df6174
commit 559e8b24ab
20 changed files with 237 additions and 160 deletions

View File

@ -1,4 +1,9 @@
package space.kscience.kmath.tensors
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
public interface AnalyticTensorAlgebra<T> :

View File

@ -1,4 +1,9 @@
package space.kscience.kmath.tensors
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
public interface LinearOpsTensorAlgebra<T> :

View File

@ -1,4 +1,9 @@
package space.kscience.kmath.tensors
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public interface TensorAlgebra<T> {
@ -40,7 +45,9 @@ public interface TensorAlgebra<T> {
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding(
diagonalEntries: TensorStructure<T>,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
offset: Int = 0,
dim1: Int = -2,
dim2: Int = -1
): TensorStructure<T>
}

View File

@ -1,4 +1,9 @@
package space.kscience.kmath.tensors
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.api
// https://proofwiki.org/wiki/Definition:Division_Algebra
public interface TensorPartialDivisionAlgebra<T> :

View File

@ -1,4 +1,4 @@
package space.kscience.kmath.tensors
package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.MutableStructureND

View File

@ -1,7 +1,8 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.TensorStructure
import space.kscience.kmath.tensors.api.TensorStructure
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
public open class BufferedTensor<T>(

View File

@ -0,0 +1,89 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.tensors.api.TensorStructure
import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.broadcastTensors
import space.kscience.kmath.tensors.core.broadcastTo
public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] + newOther.buffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] +=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
override fun TensorStructure<Double>.minus(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] - newOther.buffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] -=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
override fun TensorStructure<Double>.times(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newThis.bufferStart + i] *
newOther.buffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] *=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newOther.bufferStart + i] /
newOther.buffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] /=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
}

View File

@ -1,7 +1,14 @@
package space.kscience.kmath.tensors.core
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
import space.kscience.kmath.tensors.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.TensorStructure
package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.TensorStructure
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.tensor
import kotlin.math.*
public class DoubleAnalyticTensorAlgebra:

View File

@ -1,9 +1,23 @@
package space.kscience.kmath.tensors.core
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.TensorStructure
import space.kscience.kmath.tensors.api.TensorStructure
import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.checkSquareMatrix
import space.kscience.kmath.tensors.core.choleskyHelper
import space.kscience.kmath.tensors.core.cleanSymHelper
import space.kscience.kmath.tensors.core.luHelper
import space.kscience.kmath.tensors.core.luMatrixDet
import space.kscience.kmath.tensors.core.luMatrixInv
import space.kscience.kmath.tensors.core.luPivotHelper
import space.kscience.kmath.tensors.core.pivInit
import kotlin.math.min
@ -25,12 +39,11 @@ public class DoubleLinearOpsTensorAlgebra :
luTensor: TensorStructure<Double>,
pivotsTensor: TensorStructure<Int>
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
//todo checks
checkSquareMatrix(luTensor.shape)
check(
luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() ||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
) { "Bad shapes ((" } //todo rewrite
) { "Inappropriate shapes of input tensors" }
val n = luTensor.shape.last()
val pTensor = luTensor.zeroesLike()
@ -90,10 +103,10 @@ public class DoubleLinearOpsTensorAlgebra :
for ((matrix, USV) in tensor.matrixSequence()
.zip(resU.matrixSequence().zip(resS.vectorSequence().zip(resV.matrixSequence())))) {
val size = matrix.shape.reduce { acc, i -> acc * i }
val matrixSize = matrix.shape.reduce { acc, i -> acc * i }
val curMatrix = DoubleTensor(
matrix.shape,
matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + size).toDoubleArray()
matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize).toDoubleArray()
)
svdHelper(curMatrix, USV, m, n, epsilon)
}

View File

@ -1,8 +1,24 @@
package space.kscience.kmath.tensors.core
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.TensorStructure
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.api.TensorStructure
import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.broadcastOuterTensors
import space.kscience.kmath.tensors.core.checkBufferShapeConsistency
import space.kscience.kmath.tensors.core.checkEmptyDoubleBuffer
import space.kscience.kmath.tensors.core.checkEmptyShape
import space.kscience.kmath.tensors.core.checkShapesCompatible
import space.kscience.kmath.tensors.core.checkTranspose
import space.kscience.kmath.tensors.core.checkView
import space.kscience.kmath.tensors.core.dotHelper
import space.kscience.kmath.tensors.core.getRandomNormals
import space.kscience.kmath.tensors.core.minusIndexFrom
import kotlin.math.abs
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
@ -263,7 +279,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
if (m1 != m2) {
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)")
}
val m = m1
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
val resSize = resShape.reduce { acc, i -> acc * i }
@ -271,7 +286,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
val (a, b) = ab
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m, n)
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
}
if (penultimateDim) {
@ -347,7 +362,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return tensor.eq(other) { x, y -> abs(x - y) < delta }
}
public fun TensorStructure<Double>.eq(other: TensorStructure<Double>): Boolean = tensor.eq(other, 1e-5)
public infix fun TensorStructure<Double>.eq(other: TensorStructure<Double>): Boolean = tensor.eq(other, 1e-5)
private fun TensorStructure<Double>.eq(
other: TensorStructure<Double>,

View File

@ -1,4 +1,9 @@
package space.kscience.kmath.tensors.core
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core.algebras
import kotlin.math.max

View File

@ -1,89 +1,31 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.TensorStructure
import space.kscience.kmath.tensors.core.algebras.BroadcastDoubleTensorAlgebra
import kotlin.math.max
public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] + newOther.buffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] +=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
override fun TensorStructure<Double>.minus(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[i] - newOther.buffer.array()[i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] -=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
override fun TensorStructure<Double>.times(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newThis.bufferStart + i] *
newOther.buffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] *=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newOther.bufferStart + i] /
newOther.buffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
}
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.size) {
tensor.buffer.array()[tensor.bufferStart + i] /=
newOther.buffer.array()[tensor.bufferStart + i]
}
}
}
public inline fun <R> BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
BroadcastDoubleTensorAlgebra().block()
internal inline fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) {
for (linearIndex in 0 until linearSize) {
val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
}
}
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
var totalDim = 0
@ -129,24 +71,7 @@ internal inline fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): Doubl
}
}
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
}
multiIndexBroadCasting(tensor, resTensor, n)
return resTensor
}
@ -157,25 +82,7 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
val res = ArrayList<DoubleTensor>(0)
for (tensor in tensors) {
val resTensor = DoubleTensor(totalShape, DoubleArray(n))
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.linearStructure.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
}
multiIndexBroadCasting(tensor, resTensor, n)
res.add(resTensor)
}
@ -221,10 +128,14 @@ internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<Do
}
for (i in 0 until matrixSize) {
val curLinearIndex = newTensor.linearStructure.offset(curMultiIndex +
matrix.linearStructure.index(i))
val newLinearIndex = resTensor.linearStructure.offset(totalMultiIndex +
matrix.linearStructure.index(i))
val curLinearIndex = newTensor.linearStructure.offset(
curMultiIndex +
matrix.linearStructure.index(i)
)
val newLinearIndex = resTensor.linearStructure.offset(
totalMultiIndex +
matrix.linearStructure.index(i)
)
resTensor.buffer.array()[resTensor.bufferStart + newLinearIndex] =
newTensor.buffer.array()[newTensor.bufferStart + curLinearIndex]

View File

@ -1,6 +1,8 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.TensorStructure
import space.kscience.kmath.tensors.api.TensorStructure
import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
internal inline fun checkEmptyShape(shape: IntArray): Unit =

View File

@ -4,6 +4,8 @@ import space.kscience.kmath.nd.MutableStructure1D
import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra
import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra
import kotlin.math.abs
import kotlin.math.min
import kotlin.math.sign
@ -31,13 +33,13 @@ internal inline fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTen
}
}
internal inline fun <T> BufferedTensor<T>.forEachVector(vectorAction: (BufferedTensor<T>) -> Unit): Unit {
internal inline fun <T> BufferedTensor<T>.forEachVector(vectorAction: (BufferedTensor<T>) -> Unit) {
for (vector in vectorSequence()) {
vectorAction(vector)
}
}
internal inline fun <T> BufferedTensor<T>.forEachMatrix(matrixAction: (BufferedTensor<T>) -> Unit): Unit {
internal inline fun <T> BufferedTensor<T>.forEachMatrix(matrixAction: (BufferedTensor<T>) -> Unit) {
for (matrix in matrixSequence()) {
matrixAction(matrix)
}
@ -284,7 +286,7 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
matrix: DoubleTensor,
USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>,
m: Int, n: Int, epsilon: Double
): Unit {
) {
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
val (matrixU, SV) = USV
val (matrixS, matrixV) = SV
@ -332,7 +334,7 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
}
}
internal inline fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int): Unit {
internal inline fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int) {
for (i in 0 until n)
for (j in 0 until n) {
if (i == j) {

View File

@ -110,7 +110,6 @@ internal inline fun DoubleTensor.toPrettyString(): String = buildString {
charOffset -=1
}
offset += vectorSize
// todo refactor
if (this@toPrettyString.numElements == offset) {
break
}

View File

@ -1,5 +1,6 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
import kotlin.test.Test
import kotlin.test.assertTrue

View File

@ -1,5 +1,6 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra
import kotlin.math.abs
import kotlin.math.exp
import kotlin.test.Test

View File

@ -1,5 +1,6 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra
import kotlin.math.abs
import kotlin.test.Test
import kotlin.test.assertEquals
@ -125,8 +126,6 @@ class TestDoubleLinearOpsTensorAlgebra {
val (lu, pivots) = tensor.lu()
// todo check lu
val (p, l, u) = luPivot(lu, pivots)
assertTrue { p.shape contentEquals shape }

View File

@ -3,7 +3,7 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.toDoubleArray
import kotlin.test.Ignore
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

View File

@ -1,7 +1,9 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class TestDoubleTensorAlgebra {
@ -133,9 +135,9 @@ class TestDoubleTensorAlgebra {
assertTrue(diagonal1.buffer.array() contentEquals
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0))
val diagonal1_offset = diagonalEmbedding(tensor1, 1, 1, 0)
assertTrue(diagonal1_offset.shape contentEquals intArrayOf(4, 4))
assertTrue(diagonal1_offset.buffer.array() contentEquals
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
assertTrue(diagonal1Offset.shape contentEquals intArrayOf(4, 4))
assertTrue(diagonal1Offset.buffer.array() contentEquals
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0))
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
@ -149,7 +151,15 @@ class TestDoubleTensorAlgebra {
}
@Test
fun testContentEqual() = DoubleTensorAlgebra {
//TODO()
fun testEq() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor3 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 5.0))
val tensor4 = fromArray(intArrayOf(6, 1), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(tensor1 eq tensor1)
assertTrue(tensor1 eq tensor2)
assertFalse(tensor1.eq(tensor3))
}
}