forked from kscience/kmath
fix bugs in function dot, add tests
This commit is contained in:
commit
c2f11fb6e1
@ -45,7 +45,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
|
||||||
newThis.buffer.array()[newOther.bufferStart + i] *
|
newThis.buffer.array()[newThis.bufferStart + i] *
|
||||||
newOther.buffer.array()[newOther.bufferStart + i]
|
newOther.buffer.array()[newOther.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
@ -182,17 +182,13 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
||||||
var onlyTwoDims = true
|
val onlyTwoDims = tensors.asSequence().onEach {
|
||||||
for (tensor in tensors) {
|
require(it.shape.size >= 2) {
|
||||||
if (tensor.shape.size < 2) {
|
|
||||||
throw RuntimeException("Tensors must have at least 2 dimensions")
|
throw RuntimeException("Tensors must have at least 2 dimensions")
|
||||||
}
|
}
|
||||||
if (tensor.shape.size != 2) {
|
}.any { it.shape.size != 2 }
|
||||||
onlyTwoDims = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (onlyTwoDims) {
|
if (!onlyTwoDims) {
|
||||||
return tensors.asList()
|
return tensors.asList()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,7 +229,7 @@ internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<Do
|
|||||||
newTensor.buffer.array()[newTensor.bufferStart + curLinearIndex]
|
newTensor.buffer.array()[newTensor.bufferStart + curLinearIndex]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res.add(resTensor)
|
res += resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
@ -34,7 +34,8 @@ public open class BufferedTensor<T>(
|
|||||||
|
|
||||||
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
|
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
|
||||||
check(shape.size >= 1) {"todo"}
|
check(shape.size >= 1) {"todo"}
|
||||||
val vectorOffset = linearStructure.strides[0]
|
val n = shape.size
|
||||||
|
val vectorOffset = shape[n - 1]
|
||||||
val vectorShape = intArrayOf(shape.last())
|
val vectorShape = intArrayOf(shape.last())
|
||||||
for (offset in 0 until numel step vectorOffset) {
|
for (offset in 0 until numel step vectorOffset) {
|
||||||
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
|
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
|
||||||
@ -44,8 +45,9 @@ public open class BufferedTensor<T>(
|
|||||||
|
|
||||||
public fun matrixSequence(): Sequence<MutableStructure2D<T>> = sequence {
|
public fun matrixSequence(): Sequence<MutableStructure2D<T>> = sequence {
|
||||||
check(shape.size >= 2) {"todo"}
|
check(shape.size >= 2) {"todo"}
|
||||||
val matrixOffset = linearStructure.strides[1]
|
val n = shape.size
|
||||||
val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way?
|
val matrixOffset = shape[n - 1] * shape[n - 2]
|
||||||
|
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) //todo better way?
|
||||||
for (offset in 0 until numel step matrixOffset) {
|
for (offset in 0 until numel step matrixOffset) {
|
||||||
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
|
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
|
||||||
yield(matrix)
|
yield(matrix)
|
||||||
|
@ -8,7 +8,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
DoubleTensorAlgebra() {
|
DoubleTensorAlgebra() {
|
||||||
|
|
||||||
override fun DoubleTensor.inv(): DoubleTensor {
|
override fun DoubleTensor.inv(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("ANDREI")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.lu(tol: Double): Pair<DoubleTensor, IntTensor> {
|
override fun DoubleTensor.lu(tol: Double): Pair<DoubleTensor, IntTensor> {
|
||||||
@ -135,16 +135,16 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.qr(): DoubleTensor {
|
override fun DoubleTensor.qr(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("ANDREI")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
TODO("Not yet implemented")
|
TODO("ALYA")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
||||||
TODO("Not yet implemented")
|
TODO("ANDREI")
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -241,37 +241,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun DoubleTensor.dotTwoDimensionalTensors(other: DoubleTensor): DoubleTensor {
|
|
||||||
if (this.shape.size > 2 || other.shape.size > 2) {
|
|
||||||
throw RuntimeException("Both tensors must have a maximum of 2 dimensions")
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.shape[1] != other.shape[0]) {
|
|
||||||
throw RuntimeException("Tensors dot operation dimension mismatch: " +
|
|
||||||
"(${this.shape[0]}, ${this.shape[1]}) x (${other.shape[0]}, ${other.shape[1]})")
|
|
||||||
}
|
|
||||||
|
|
||||||
val l = this.shape[0]
|
|
||||||
val m = this.shape[1]
|
|
||||||
val n = other.shape[1]
|
|
||||||
|
|
||||||
val res = DoubleTensor(intArrayOf(l, n), DoubleArray(l * n))
|
|
||||||
|
|
||||||
for (i in 0 until l) {
|
|
||||||
for (j in 0 until n) {
|
|
||||||
var curr = 0.0
|
|
||||||
for (k in 0 until m) {
|
|
||||||
val ik = this.linearStructure.offset(intArrayOf(i, k))
|
|
||||||
val kj = other.linearStructure.offset(intArrayOf(k, j))
|
|
||||||
curr += this.buffer.array()[ik] * other.buffer.array()[kj]
|
|
||||||
}
|
|
||||||
val linearIndex = res.linearStructure.offset(intArrayOf(i, j))
|
|
||||||
res.buffer.array()[linearIndex] = curr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
||||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||||
return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum()))
|
return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum()))
|
||||||
@ -279,10 +248,15 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
var newThis = this.copy()
|
var newThis = this.copy()
|
||||||
var newOther = other.copy()
|
var newOther = other.copy()
|
||||||
|
|
||||||
|
var penultimateDim = false
|
||||||
|
var lastDim = false
|
||||||
if (this.shape.size == 1) {
|
if (this.shape.size == 1) {
|
||||||
|
penultimateDim = true
|
||||||
newThis = this.view(intArrayOf(1) + this.shape)
|
newThis = this.view(intArrayOf(1) + this.shape)
|
||||||
}
|
}
|
||||||
if (other.shape.size == 1) {
|
if (other.shape.size == 1) {
|
||||||
|
lastDim = true
|
||||||
newOther = other.view(other.shape + intArrayOf(1) )
|
newOther = other.view(other.shape + intArrayOf(1) )
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -299,13 +273,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
val m = m1
|
val m = m1
|
||||||
|
|
||||||
var resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
|
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
|
||||||
val resSize = resShape.reduce { acc, i -> acc * i }
|
val resSize = resShape.reduce { acc, i -> acc * i }
|
||||||
val resTensor = DoubleTensor(resShape, DoubleArray(resSize))
|
val resTensor = DoubleTensor(resShape, DoubleArray(resSize))
|
||||||
|
|
||||||
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
||||||
val a = ab.first
|
val (a, b) = ab
|
||||||
val b = ab.second
|
|
||||||
|
|
||||||
for (i in 0 until l) {
|
for (i in 0 until l) {
|
||||||
for (j in 0 until n) {
|
for (j in 0 until n) {
|
||||||
@ -318,6 +291,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (penultimateDim) {
|
||||||
|
return resTensor.view(resTensor.shape.dropLast(2).toIntArray() +
|
||||||
|
intArrayOf(resTensor.shape.last()))
|
||||||
|
}
|
||||||
|
if (lastDim) {
|
||||||
|
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
||||||
|
}
|
||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -338,7 +318,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.det(): DoubleTensor {
|
override fun DoubleTensor.det(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("ANDREI")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.square(): DoubleTensor {
|
override fun DoubleTensor.square(): DoubleTensor {
|
||||||
|
@ -82,15 +82,37 @@ class TestDoubleTensorAlgebra {
|
|||||||
@Test
|
@Test
|
||||||
fun dot() = DoubleTensorAlgebra {
|
fun dot() = DoubleTensorAlgebra {
|
||||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor11 = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
|
||||||
|
|
||||||
val res12 = tensor1.dot(tensor2)
|
val res12 = tensor1.dot(tensor2)
|
||||||
|
|
||||||
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(140.0, 320.0))
|
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(140.0, 320.0))
|
||||||
assertTrue(res12.shape contentEquals intArrayOf(2, 1))
|
assertTrue(res12.shape contentEquals intArrayOf(2))
|
||||||
|
|
||||||
val tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0})
|
val res32 = tensor3.dot(tensor2)
|
||||||
val tensor5 = fromArray(intArrayOf(10, 4, 5), DoubleArray(10 * 4 * 5) {0.0})
|
assertTrue(res32.buffer.array() contentEquals doubleArrayOf(-140.0))
|
||||||
|
assertTrue(res32.shape contentEquals intArrayOf(1, 1))
|
||||||
|
|
||||||
|
val res22 = tensor2.dot(tensor2)
|
||||||
|
assertTrue(res22.buffer.array() contentEquals doubleArrayOf(1400.0))
|
||||||
|
assertTrue(res22.shape contentEquals intArrayOf(1))
|
||||||
|
|
||||||
|
val res11 = tensor1.dot(tensor11)
|
||||||
|
assertTrue(res11.buffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
||||||
|
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
|
||||||
|
|
||||||
|
var tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0})
|
||||||
|
var tensor5 = fromArray(intArrayOf(10, 4, 5), DoubleArray(10 * 4 * 5) {0.0})
|
||||||
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(10, 3, 5))
|
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(10, 3, 5))
|
||||||
|
|
||||||
|
tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0})
|
||||||
|
tensor5 = fromArray(intArrayOf(4, 5), DoubleArray(4 * 5) {0.0})
|
||||||
|
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(10, 3, 5))
|
||||||
|
|
||||||
|
tensor4 = fromArray(intArrayOf(4, 2, 1, 3, 8, 1), DoubleArray(4 * 2 * 1 * 3 * 8 * 1) {0.0})
|
||||||
|
tensor5 = fromArray(intArrayOf(5, 1, 2, 8, 3, 1, 5), DoubleArray(5 * 1 * 2 * 8 * 3 * 1 * 5) {0.0})
|
||||||
|
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(5, 4, 2, 8, 3, 8, 5))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
Reference in New Issue
Block a user