Fixing 2D and 1D casts

This commit is contained in:
Roland Grinis 2021-03-29 21:58:56 +01:00
parent 22b68e5ca4
commit 92710097f0
6 changed files with 173 additions and 248 deletions

View File

@ -2731,8 +2731,6 @@ public class space/kscience/kmath/tensors/core/BufferedTensor : space/kscience/k
public fun <init> ([ILspace/kscience/kmath/structures/MutableBuffer;I)V public fun <init> ([ILspace/kscience/kmath/structures/MutableBuffer;I)V
public fun elements ()Lkotlin/sequences/Sequence; public fun elements ()Lkotlin/sequences/Sequence;
public fun equals (Ljava/lang/Object;)Z public fun equals (Ljava/lang/Object;)Z
public final fun forEachMatrix (Lkotlin/jvm/functions/Function1;)V
public final fun forEachVector (Lkotlin/jvm/functions/Function1;)V
public fun get ([I)Ljava/lang/Object; public fun get ([I)Ljava/lang/Object;
public final fun getBuffer ()Lspace/kscience/kmath/structures/MutableBuffer; public final fun getBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
public fun getDimension ()I public fun getDimension ()I
@ -2740,37 +2738,7 @@ public class space/kscience/kmath/tensors/core/BufferedTensor : space/kscience/k
public final fun getNumel ()I public final fun getNumel ()I
public fun getShape ()[I public fun getShape ()[I
public fun hashCode ()I public fun hashCode ()I
public final fun matrixSequence ()Lkotlin/sequences/Sequence;
public fun set ([ILjava/lang/Object;)V public fun set ([ILjava/lang/Object;)V
public final fun vectorSequence ()Lkotlin/sequences/Sequence;
}
public final class space/kscience/kmath/tensors/core/BufferedTensor1D : space/kscience/kmath/tensors/core/BufferedTensor, space/kscience/kmath/nd/MutableStructure1D {
public fun copy ()Lspace/kscience/kmath/structures/MutableBuffer;
public fun get (I)Ljava/lang/Object;
public fun get ([I)Ljava/lang/Object;
public fun getDimension ()I
public fun getSize ()I
public fun iterator ()Ljava/util/Iterator;
public fun set (ILjava/lang/Object;)V
public fun set ([ILjava/lang/Object;)V
}
public final class space/kscience/kmath/tensors/core/BufferedTensor2D : space/kscience/kmath/tensors/core/BufferedTensor, space/kscience/kmath/nd/MutableStructure2D {
public fun elements ()Lkotlin/sequences/Sequence;
public fun get (II)Ljava/lang/Object;
public fun get ([I)Ljava/lang/Object;
public fun getColNum ()I
public fun getColumns ()Ljava/util/List;
public fun getRowNum ()I
public fun getRows ()Ljava/util/List;
public fun getShape ()[I
public fun set (IILjava/lang/Object;)V
}
public final class space/kscience/kmath/tensors/core/BufferedTensorKt {
public static final fun as1D (Lspace/kscience/kmath/tensors/core/BufferedTensor;)Lspace/kscience/kmath/tensors/core/BufferedTensor1D;
public static final fun as2D (Lspace/kscience/kmath/tensors/core/BufferedTensor;)Lspace/kscience/kmath/tensors/core/BufferedTensor2D;
} }
public final class space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra : space/kscience/kmath/tensors/core/DoubleTensorAlgebra, space/kscience/kmath/tensors/AnalyticTensorAlgebra { public final class space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra : space/kscience/kmath/tensors/core/DoubleTensorAlgebra, space/kscience/kmath/tensors/AnalyticTensorAlgebra {
@ -2876,7 +2844,6 @@ public class space/kscience/kmath/tensors/core/DoubleTensorAlgebra : space/kscie
public synthetic fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z public synthetic fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Z public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Z
public fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z public fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lkotlin/jvm/functions/Function2;)Z
public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND; public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor; public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
public final fun fromArray ([I[D)Lspace/kscience/kmath/tensors/core/DoubleTensor; public final fun fromArray ([I[D)Lspace/kscience/kmath/tensors/core/DoubleTensor;

View File

@ -35,35 +35,34 @@ public open class BufferedTensor<T>(
override fun hashCode(): Int = 0 override fun hashCode(): Int = 0
public fun vectorSequence(): Sequence<BufferedTensor1D<T>> = sequence { internal fun vectorSequence(): Sequence<BufferedTensor<T>> = sequence {
check(shape.size >= 1) { "todo" }
val n = shape.size val n = shape.size
val vectorOffset = shape[n - 1] val vectorOffset = shape[n - 1]
val vectorShape = intArrayOf(shape.last()) val vectorShape = intArrayOf(shape.last())
for (offset in 0 until numel step vectorOffset) { for (offset in 0 until numel step vectorOffset) {
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D() val vector = BufferedTensor(vectorShape, buffer, offset)
yield(vector) yield(vector)
} }
} }
public fun matrixSequence(): Sequence<BufferedTensor2D<T>> = sequence { internal fun matrixSequence(): Sequence<BufferedTensor<T>> = sequence {
check(shape.size >= 2) { "todo" } check(shape.size >= 2) { "todo" }
val n = shape.size val n = shape.size
val matrixOffset = shape[n - 1] * shape[n - 2] val matrixOffset = shape[n - 1] * shape[n - 2]
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) //todo better way? val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
for (offset in 0 until numel step matrixOffset) { for (offset in 0 until numel step matrixOffset) {
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D() val matrix = BufferedTensor(matrixShape, buffer, offset)
yield(matrix) yield(matrix)
} }
} }
public inline fun forEachVector(vectorAction: (BufferedTensor1D<T>) -> Unit): Unit { internal inline fun forEachVector(vectorAction: (BufferedTensor<T>) -> Unit): Unit {
for (vector in vectorSequence()) { for (vector in vectorSequence()) {
vectorAction(vector) vectorAction(vector)
} }
} }
public inline fun forEachMatrix(matrixAction: (BufferedTensor2D<T>) -> Unit): Unit { internal inline fun forEachMatrix(matrixAction: (BufferedTensor<T>) -> Unit): Unit {
for (matrix in matrixSequence()) { for (matrix in matrixSequence()) {
matrixAction(matrix) matrixAction(matrix)
} }
@ -71,7 +70,6 @@ public open class BufferedTensor<T>(
} }
public class IntTensor internal constructor( public class IntTensor internal constructor(
shape: IntArray, shape: IntArray,
buffer: IntArray, buffer: IntArray,
@ -112,90 +110,7 @@ public class DoubleTensor internal constructor(
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart) this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
} }
internal fun BufferedTensor<Int>.asTensor(): IntTensor = IntTensor(this)
public class BufferedTensor2D<T> internal constructor( internal fun BufferedTensor<Long>.asTensor(): LongTensor = LongTensor(this)
private val tensor: BufferedTensor<T>, internal fun BufferedTensor<Float>.asTensor(): FloatTensor = FloatTensor(this)
) : BufferedTensor<T>(tensor), MutableStructure2D<T> { internal fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTensor(this)
init {
check(shape.size == 2) {
"Shape ${shape.toList()} not compatible with DoubleTensor2D"
}
}
override val shape: IntArray
get() = tensor.shape
override val rowNum: Int
get() = shape[0]
override val colNum: Int
get() = shape[1]
override fun get(i: Int, j: Int): T = tensor[intArrayOf(i, j)]
override fun get(index: IntArray): T = tensor[index]
override fun elements(): Sequence<Pair<IntArray, T>> = tensor.elements()
override fun set(i: Int, j: Int, value: T) {
tensor[intArrayOf(i, j)] = value
}
override val rows: List<BufferedTensor1D<T>>
get() = List(rowNum) { i ->
BufferedTensor1D(
BufferedTensor(
shape = intArrayOf(colNum),
buffer = VirtualMutableBuffer(colNum) { j -> get(i, j) },
bufferStart = 0
)
)
}
override val columns: List<BufferedTensor1D<T>>
get() = List(colNum) { j ->
BufferedTensor1D(
BufferedTensor(
shape = intArrayOf(rowNum),
buffer = VirtualMutableBuffer(rowNum) { i -> get(i, j) },
bufferStart = 0
)
)
}
}
public class BufferedTensor1D<T> internal constructor(
private val tensor: BufferedTensor<T>
) : BufferedTensor<T>(tensor), MutableStructure1D<T> {
init {
check(shape.size == 1) {
"Shape ${shape.toList()} not compatible with DoubleTensor1D"
}
}
override fun get(index: IntArray): T = tensor[index]
override fun set(index: IntArray, value: T) {
tensor[index] = value
}
override val size: Int
get() = tensor.linearStructure.size
override fun get(index: Int): T = tensor[intArrayOf(index)]
override fun set(index: Int, value: T) {
tensor[intArrayOf(index)] = value
}
override fun copy(): MutableBuffer<T> = tensor.buffer.copy()
}
internal fun BufferedTensor<Int>.asIntTensor(): IntTensor = IntTensor(this)
internal fun BufferedTensor<Long>.asLongTensor(): LongTensor = LongTensor(this)
internal fun BufferedTensor<Float>.asFloatTensor(): FloatTensor = FloatTensor(this)
internal fun BufferedTensor<Double>.asDoubleTensor(): DoubleTensor = DoubleTensor(this)
public fun <T> BufferedTensor<T>.as2D(): BufferedTensor2D<T> = BufferedTensor2D(this)
public fun <T> BufferedTensor<T>.as1D(): BufferedTensor1D<T> = BufferedTensor1D(this)

View File

@ -1,5 +1,9 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.MutableStructure1D
import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
import kotlin.math.sqrt import kotlin.math.sqrt
@ -11,6 +15,48 @@ public class DoubleLinearOpsTensorAlgebra :
override fun DoubleTensor.det(): DoubleTensor = detLU() override fun DoubleTensor.det(): DoubleTensor = detLU()
private inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
for (row in 0 until m) pivots[row] = row
for (i in 0 until m) {
var maxVal = -1.0
var maxInd = i
for (k in i until m) {
val absA = kotlin.math.abs(lu[k, i])
if (absA > maxVal) {
maxVal = absA
maxInd = k
}
}
//todo check singularity
if (maxInd != i) {
val j = pivots[i]
pivots[i] = pivots[maxInd]
pivots[maxInd] = j
for (k in 0 until m) {
val tmp = lu[i, k]
lu[i, k] = lu[maxInd, k]
lu[maxInd, k] = tmp
}
pivots[m] += 1
}
for (j in i + 1 until m) {
lu[j, i] /= lu[i, i]
for (k in i + 1 until m) {
lu[j, k] -= lu[j, i] * lu[i, k]
}
}
}
}
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> { override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
checkSquareMatrix(shape) checkSquareMatrix(shape)
@ -27,90 +73,93 @@ public class DoubleLinearOpsTensorAlgebra :
IntArray(pivotsShape.reduce(Int::times)) { 0 } IntArray(pivotsShape.reduce(Int::times)) { 0 }
) )
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence())){ for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
for (row in 0 until m) pivots[row] = row luHelper(lu.as2D(), pivots.as1D(), m)
for (i in 0 until m) {
var maxVal = -1.0
var maxInd = i
for (k in i until m) {
val absA = kotlin.math.abs(lu[k, i])
if (absA > maxVal) {
maxVal = absA
maxInd = k
}
}
//todo check singularity
if (maxInd != i) {
val j = pivots[i]
pivots[i] = pivots[maxInd]
pivots[maxInd] = j
for (k in 0 until m) {
val tmp = lu[i, k]
lu[i, k] = lu[maxInd, k]
lu[maxInd, k] = tmp
}
pivots[m] += 1
}
for (j in i + 1 until m) {
lu[j, i] /= lu[i, i]
for (k in i + 1 until m) {
lu[j, k] -= lu[j, i] * lu[i, k]
}
}
}
}
return Pair(luTensor, pivotsTensor) return Pair(luTensor, pivotsTensor)
} }
override fun luPivot(luTensor: DoubleTensor, pivotsTensor: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { private inline fun pivInit(
p: MutableStructure2D<Double>,
pivot: MutableStructure1D<Int>,
n: Int
) {
for (i in 0 until n) {
p[i, pivot[i]] = 1.0
}
}
private inline fun luPivotHelper(
l: MutableStructure2D<Double>,
u: MutableStructure2D<Double>,
lu: MutableStructure2D<Double>,
n: Int
) {
for (i in 0 until n) {
for (j in 0 until n) {
if (i == j) {
l[i, j] = 1.0
}
if (j < i) {
l[i, j] = lu[i, j]
}
if (j >= i) {
u[i, j] = lu[i, j]
}
}
}
}
override fun luPivot(
luTensor: DoubleTensor,
pivotsTensor: IntTensor
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
//todo checks //todo checks
checkSquareMatrix(luTensor.shape) checkSquareMatrix(luTensor.shape)
check(luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape) { "Bed shapes (("} //todo rewrite check(
luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape
) { "Bed shapes ((" } //todo rewrite
val n = luTensor.shape.last() val n = luTensor.shape.last()
val pTensor = luTensor.zeroesLike() val pTensor = luTensor.zeroesLike()
for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.vectorSequence())){ for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
for (i in 0 until n){ pivInit(p.as2D(), pivot.as1D(), n)
p[i, pivot[i]] = 1.0
}
}
val lTensor = luTensor.zeroesLike() val lTensor = luTensor.zeroesLike()
val uTensor = luTensor.zeroesLike() val uTensor = luTensor.zeroesLike()
for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence()).zip(luTensor.matrixSequence())){ for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence())
.zip(luTensor.matrixSequence())) {
val (l, u) = pairLU val (l, u) = pairLU
for (i in 0 until n) { luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
for (j in 0 until n) {
if (i == j) {
l[i, j] = 1.0
}
if (j < i) {
l[i, j] = lu[i, j]
}
if (j >= i) {
u[i, j] = lu[i, j]
}
}
}
} }
return Triple(pTensor, lTensor, uTensor) return Triple(pTensor, lTensor, uTensor)
} }
private inline fun choleskyHelper(
a: MutableStructure2D<Double>,
l: MutableStructure2D<Double>,
n: Int
) {
for (i in 0 until n) {
for (j in 0 until i) {
var h = a[i, j]
for (k in 0 until j) {
h -= l[i, k] * l[j, k]
}
l[i, j] = h / l[j, j]
}
var h = a[i, i]
for (j in 0 until i) {
h -= l[i, j] * l[i, j]
}
l[i, i] = sqrt(h)
}
}
override fun DoubleTensor.cholesky(): DoubleTensor { override fun DoubleTensor.cholesky(): DoubleTensor {
// todo checks // todo checks
checkSquareMatrix(shape) checkSquareMatrix(shape)
@ -118,22 +167,8 @@ public class DoubleLinearOpsTensorAlgebra :
val n = shape.last() val n = shape.last()
val lTensor = zeroesLike() val lTensor = zeroesLike()
for ((a, l) in this.matrixSequence().zip(lTensor.matrixSequence())) { for ((a, l) in this.matrixSequence().zip(lTensor.matrixSequence()))
for (i in 0 until n) { for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n)
for (j in 0 until i) {
var h = a[i, j]
for (k in 0 until j) {
h -= l[i, k] * l[j, k]
}
l[i, j] = h / l[j, j]
}
var h = a[i, i]
for (j in 0 until i) {
h -= l[i, j] * l[i, j]
}
l[i, i] = sqrt(h)
}
}
return lTensor return lTensor
} }
@ -150,9 +185,11 @@ public class DoubleLinearOpsTensorAlgebra :
TODO("ANDREI") TODO("ANDREI")
} }
private fun luMatrixDet(lu: BufferedTensor2D<Double>, pivots: BufferedTensor1D<Int>): Double { private fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
val lu = luTensor.as2D()
val pivots = pivotsTensor.as1D()
val m = lu.shape[0] val m = lu.shape[0]
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0 val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right } return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
} }
@ -162,34 +199,34 @@ public class DoubleLinearOpsTensorAlgebra :
val detTensorShape = IntArray(n - 1) { i -> shape[i] } val detTensorShape = IntArray(n - 1) { i -> shape[i] }
detTensorShape[n - 2] = 1 detTensorShape[n - 2] = 1
val resBuffer = DoubleArray(detTensorShape.reduce(Int::times)) { 0.0 } val resBuffer = DoubleArray(detTensorShape.reduce(Int::times)) { 0.0 }
val detTensor = DoubleTensor( val detTensor = DoubleTensor(
detTensorShape, detTensorShape,
resBuffer resBuffer
) )
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (luMatrix, pivots) -> luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
resBuffer[index] = luMatrixDet(luMatrix, pivots) resBuffer[index] = luMatrixDet(lu.as2D(), pivots.as1D())
} }
return detTensor return detTensor
} }
private fun luMatrixInv( private fun luMatrixInv(
lu: BufferedTensor2D<Double>, lu: MutableStructure2D<Double>,
pivots: BufferedTensor1D<Int>, pivots: MutableStructure1D<Int>,
invMatrix : BufferedTensor2D<Double> invMatrix: MutableStructure2D<Double>
): Unit { ) {
val m = lu.shape[0] val m = lu.shape[0]
for (j in 0 until m) { for (j in 0 until m) {
for (i in 0 until m) { for (i in 0 until m) {
if (pivots[i] == j){ if (pivots[i] == j) {
invMatrix[i, j] = 1.0 invMatrix[i, j] = 1.0
} }
for (k in 0 until i){ for (k in 0 until i) {
invMatrix[i, j] -= lu[i, k] * invMatrix[k, j] invMatrix[i, j] -= lu[i, k] * invMatrix[k, j]
} }
} }
@ -205,13 +242,12 @@ public class DoubleLinearOpsTensorAlgebra :
public fun DoubleTensor.invLU(): DoubleTensor { public fun DoubleTensor.invLU(): DoubleTensor {
val (luTensor, pivotsTensor) = lu() val (luTensor, pivotsTensor) = lu()
val n = shape.size
val invTensor = luTensor.zeroesLike() val invTensor = luTensor.zeroesLike()
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence()) val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
for ((luP, invMatrix) in seq) { for ((luP, invMatrix) in seq) {
val (lu, pivots) = luP val (lu, pivots) = luP
luMatrixInv(lu, pivots, invMatrix) luMatrixInv(lu.as2D(), pivots.as1D(), invMatrix.as2D())
} }
return invTensor return invTensor

View File

@ -1,8 +1,7 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
import kotlin.math.abs import kotlin.math.abs
@ -224,6 +223,23 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return this.view(other.shape) return this.view(other.shape)
} }
private inline fun dotHelper(
a: MutableStructure2D<Double>,
b: MutableStructure2D<Double>,
res: MutableStructure2D<Double>,
l: Int, m: Int, n: Int
) {
for (i in 0 until l) {
for (j in 0 until n) {
var curr = 0.0
for (k in 0 until m) {
curr += a[i, k] * b[k, j]
}
res[i, j] = curr
}
}
}
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
if (this.shape.size == 1 && other.shape.size == 1) { if (this.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum())) return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum()))
@ -240,7 +256,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
if (other.shape.size == 1) { if (other.shape.size == 1) {
lastDim = true lastDim = true
newOther = other.view(other.shape + intArrayOf(1) ) newOther = other.view(other.shape + intArrayOf(1))
} }
val broadcastTensors = broadcastOuterTensors(newThis, newOther) val broadcastTensors = broadcastOuterTensors(newThis, newOther)
@ -248,7 +264,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
newOther = broadcastTensors[1] newOther = broadcastTensors[1]
val l = newThis.shape[newThis.shape.size - 2] val l = newThis.shape[newThis.shape.size - 2]
val m1= newThis.shape[newThis.shape.size - 1] val m1 = newThis.shape[newThis.shape.size - 1]
val m2 = newOther.shape[newOther.shape.size - 2] val m2 = newOther.shape[newOther.shape.size - 2]
val n = newOther.shape[newOther.shape.size - 1] val n = newOther.shape[newOther.shape.size - 1]
if (m1 != m2) { if (m1 != m2) {
@ -262,21 +278,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) { for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
val (a, b) = ab val (a, b) = ab
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m, n)
for (i in 0 until l) {
for (j in 0 until n) {
var curr = 0.0
for (k in 0 until m) {
curr += a[i, k] * b[k, j]
}
res[i, j] = curr
}
}
} }
if (penultimateDim) { if (penultimateDim) {
return resTensor.view(resTensor.shape.dropLast(2).toIntArray() + return resTensor.view(
intArrayOf(resTensor.shape.last())) resTensor.shape.dropLast(2).toIntArray() +
intArrayOf(resTensor.shape.last())
)
} }
if (lastDim) { if (lastDim) {
return resTensor.view(resTensor.shape.dropLast(1).toIntArray()) return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
@ -307,15 +316,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5) public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean { public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean =
if (!(this.shape contentEquals other.shape)) { this.eq(other, eqFunction)
return false
}
return this.eq(other, eqFunction)
}
public fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean { private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
// todo broadcasting checking checkShapesCompatible(this, other)
val n = this.linearStructure.size val n = this.linearStructure.size
if (n != other.linearStructure.size) { if (n != other.linearStructure.size) {
return false return false

View File

@ -11,14 +11,14 @@ internal inline fun <T, TensorType : TensorStructure<T>,
"Illegal empty shape provided" "Illegal empty shape provided"
} }
internal inline fun < TensorType : TensorStructure<Double>, internal inline fun <TensorType : TensorStructure<Double>,
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>> TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
TorchTensorAlgebraType.checkEmptyDoubleBuffer(buffer: DoubleArray): Unit = TorchTensorAlgebraType.checkEmptyDoubleBuffer(buffer: DoubleArray): Unit =
check(buffer.isNotEmpty()) { check(buffer.isNotEmpty()) {
"Illegal empty buffer provided" "Illegal empty buffer provided"
} }
internal inline fun < TensorType : TensorStructure<Double>, internal inline fun <TensorType : TensorStructure<Double>,
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>> TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
TorchTensorAlgebraType.checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit = TorchTensorAlgebraType.checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit =
check(buffer.size == shape.reduce(Int::times)) { check(buffer.size == shape.reduce(Int::times)) {
@ -56,4 +56,4 @@ internal inline fun <T, TensorType : TensorStructure<T>,
check(shape[n - 1] == shape[n - 2]) { check(shape[n - 1] == shape[n - 2]) {
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices" "Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
} }
} }

View File

@ -1,5 +1,7 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.structures.toDoubleArray
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals