refactor #307

Merged
AndreiKingsley merged 1 commits from andrew into feature/tensor-algebra 2021-05-01 18:34:30 +03:00
6 changed files with 32 additions and 33 deletions
Showing only changes of commit bfba653904 - Show all commits

View File

@ -21,7 +21,7 @@ fun main() {
// work in context with linear operations // work in context with linear operations
DoubleLinearOpsTensorAlgebra.invoke { DoubleLinearOpsTensorAlgebra.invoke {
// take coefficient vector from normal distribution // take coefficient vector from normal distribution
val alpha = randNormal( val alpha = randomNormal(
intArrayOf(5), intArrayOf(5),
randSeed randSeed
) + fromArray( ) + fromArray(
@ -32,14 +32,14 @@ fun main() {
println("Real alpha:\n$alpha") println("Real alpha:\n$alpha")
// also take sample of size 20 from normal distribution for x // also take sample of size 20 from normal distribution for x
val x = randNormal( val x = randomNormal(
intArrayOf(20, 5), intArrayOf(20, 5),
randSeed randSeed
) )
// calculate y and add gaussian noise (N(0, 0.05)) // calculate y and add gaussian noise (N(0, 0.05))
val y = x dot alpha val y = x dot alpha
y += y.randNormalLike(randSeed) * 0.05 y += y.randomNormalLike(randSeed) * 0.05
// now restore the coefficient vector with OSL estimator with SVD // now restore the coefficient vector with OSL estimator with SVD
val (u, singValues, v) = x.svd() val (u, singValues, v) = x.svd()

View File

@ -31,7 +31,7 @@ public object DoubleLinearOpsTensorAlgebra :
public fun TensorStructure<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> = public fun TensorStructure<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) computeLU(tensor, epsilon)
?: throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon") ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
public fun TensorStructure<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9) public fun TensorStructure<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
@ -47,8 +47,10 @@ public object DoubleLinearOpsTensorAlgebra :
val n = luTensor.shape.last() val n = luTensor.shape.last()
val pTensor = luTensor.zeroesLike() val pTensor = luTensor.zeroesLike()
for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.tensor.vectorSequence())) pTensor
pivInit(p.as2D(), pivot.as1D(), n) .matrixSequence()
.zip(pivotsTensor.tensor.vectorSequence())
.forEach { (p, pivot) -> pivInit(p.as2D(), pivot.as1D(), n) }
val lTensor = luTensor.zeroesLike() val lTensor = luTensor.zeroesLike()
val uTensor = luTensor.zeroesLike() val uTensor = luTensor.zeroesLike()

View File

@ -284,7 +284,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
val m1 = newThis.shape[newThis.shape.size - 1] val m1 = newThis.shape[newThis.shape.size - 1]
val m2 = newOther.shape[newOther.shape.size - 2] val m2 = newOther.shape[newOther.shape.size - 2]
val n = newOther.shape[newOther.shape.size - 1] val n = newOther.shape[newOther.shape.size - 1]
if (m1 != m2) { check (m1 == m2) {
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)") throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)")
} }
@ -315,11 +315,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
val d1 = minusIndexFrom(n + 1, dim1) val d1 = minusIndexFrom(n + 1, dim1)
val d2 = minusIndexFrom(n + 1, dim2) val d2 = minusIndexFrom(n + 1, dim2)
if (d1 == d2) { check(d1 != d2) {
throw RuntimeException("Diagonal dimensions cannot be identical $d1, $d2") "Diagonal dimensions cannot be identical $d1, $d2"
} }
if (d1 > n || d2 > n) { check(d1 <= n && d2 <= n) {
throw RuntimeException("Dimension out of range") "Dimension out of range"
} }
var lessDim = d1 var lessDim = d1
@ -366,8 +366,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
) )
} }
public fun TensorStructure<Double>.eq(other: TensorStructure<Double>, delta: Double): Boolean { public fun TensorStructure<Double>.eq(other: TensorStructure<Double>, epsilon: Double): Boolean {
return tensor.eq(other) { x, y -> abs(x - y) < delta } return tensor.eq(other) { x, y -> abs(x - y) < epsilon }
} }
public infix fun TensorStructure<Double>.eq(other: TensorStructure<Double>): Boolean = tensor.eq(other, 1e-5) public infix fun TensorStructure<Double>.eq(other: TensorStructure<Double>): Boolean = tensor.eq(other, 1e-5)
@ -393,10 +393,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return true return true
} }
public fun randNormal(shape: IntArray, seed: Long = 0): DoubleTensor = public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed)) DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
public fun TensorStructure<Double>.randNormalLike(seed: Long = 0): DoubleTensor = public fun TensorStructure<Double>.randomNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
// stack tensors by axis 0 // stack tensors by axis 0

View File

@ -42,8 +42,8 @@ internal fun broadcastShapes(vararg shapes: IntArray): IntArray {
for (i in shape.indices) { for (i in shape.indices) {
val curDim = shape[i] val curDim = shape[i]
val offset = totalDim - shape.size val offset = totalDim - shape.size
if (curDim != 1 && totalShape[i + offset] != curDim) { check(curDim == 1 || totalShape[i + offset] == curDim) {
throw RuntimeException("Shapes are not compatible and cannot be broadcast") "Shapes are not compatible and cannot be broadcast"
} }
} }
} }
@ -52,8 +52,8 @@ internal fun broadcastShapes(vararg shapes: IntArray): IntArray {
} }
internal fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor { internal fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
if (tensor.shape.size > newShape.size) { require(tensor.shape.size <= newShape.size) {
throw RuntimeException("Tensor is not compatible with the new shape") "Tensor is not compatible with the new shape"
} }
val n = newShape.reduce { acc, i -> acc * i } val n = newShape.reduce { acc, i -> acc * i }
@ -62,8 +62,8 @@ internal fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor
for (i in tensor.shape.indices) { for (i in tensor.shape.indices) {
val curDim = tensor.shape[i] val curDim = tensor.shape[i]
val offset = newShape.size - tensor.shape.size val offset = newShape.size - tensor.shape.size
if (curDim != 1 && newShape[i + offset] != curDim) { check(curDim == 1 || newShape[i + offset] == curDim) {
throw RuntimeException("Tensor is not compatible with the new shape and cannot be broadcast") "Tensor is not compatible with the new shape and cannot be broadcast"
} }
} }
@ -75,19 +75,17 @@ internal fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor>
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray()) val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i } val n = totalShape.reduce { acc, i -> acc * i }
return buildList { return tensors.map { tensor ->
for (tensor in tensors) {
val resTensor = DoubleTensor(totalShape, DoubleArray(n)) val resTensor = DoubleTensor(totalShape, DoubleArray(n))
multiIndexBroadCasting(tensor, resTensor, n) multiIndexBroadCasting(tensor, resTensor, n)
add(resTensor) resTensor
}
} }
} }
internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> { internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val onlyTwoDims = tensors.asSequence().onEach { val onlyTwoDims = tensors.asSequence().onEach {
require(it.shape.size >= 2) { require(it.shape.size >= 2) {
throw RuntimeException("Tensors must have at least 2 dimensions") "Tensors must have at least 2 dimensions"
} }
}.any { it.shape.size != 2 } }.any { it.shape.size != 2 }

View File

@ -69,8 +69,7 @@ internal fun DoubleTensor.toPrettyString(): String = buildString {
val shape = this@toPrettyString.shape val shape = this@toPrettyString.shape
val linearStructure = this@toPrettyString.linearStructure val linearStructure = this@toPrettyString.linearStructure
val vectorSize = shape.last() val vectorSize = shape.last()
val initString = "DoubleTensor(\n" append("DoubleTensor(\n")
append(initString)
var charOffset = 3 var charOffset = 3
for (vector in vectorSequence()) { for (vector in vectorSequence()) {
repeat(charOffset) { append(' ') } repeat(charOffset) { append(' ') }

View File

@ -135,7 +135,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
@Test @Test
fun testCholesky() = DoubleLinearOpsTensorAlgebra.invoke { fun testCholesky() = DoubleLinearOpsTensorAlgebra.invoke {
val tensor = randNormal(intArrayOf(2, 5, 5), 0) val tensor = randomNormal(intArrayOf(2, 5, 5), 0)
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding( val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 }) fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
) )
@ -163,7 +163,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
@Test @Test
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra.invoke { fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra.invoke {
val tensor = randNormal(intArrayOf(2, 5, 3), 0) val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd() val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensor.eq(tensorSVD)) assertTrue(tensor.eq(tensorSVD))
@ -171,7 +171,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
@Test @Test
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra.invoke { fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra.invoke {
val tensor = randNormal(shape = intArrayOf(2, 3, 3), 0) val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
val tensorSigma = tensor + tensor.transpose() val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig() val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())