v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
2 changed files with 14 additions and 2 deletions
Showing only changes of commit 1f19ac88ae - Show all commits

View File

@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
public fun TensorType.cholesky(): TensorType
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
public fun TensorType.qr(): TensorType
public fun TensorType.qr(): Pair<TensorType, TensorType>
//https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>

View File

@ -1,8 +1,10 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.Structure1D
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.asND
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
import kotlin.math.sqrt
@ -141,7 +143,15 @@ public class DoubleLinearOpsTensorAlgebra :
return lTensor
}
override fun DoubleTensor.qr(): DoubleTensor {
private fun matrixQR(
matrix: Structure2D<Double>,
q: MutableStructure2D<Double>,
r: MutableStructure2D<Double>
) {
}
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
TODO("ANDREI")
}
@ -154,6 +164,7 @@ public class DoubleLinearOpsTensorAlgebra :
}
private fun luMatrixDet(lu: Structure2D<Double>, pivots: Structure1D<Int>): Double {
// todo check
val m = lu.shape[0]
val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
@ -184,6 +195,7 @@ public class DoubleLinearOpsTensorAlgebra :
pivots: Structure1D<Int>,
invMatrix : MutableStructure2D<Double>
): Unit {
//todo check
val m = lu.shape[0]
for (j in 0 until m) {