v0.3.0-dev-9 #324
@ -1,7 +1,7 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
|
||||||
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, IndexTensorType: TensorStructure<Int>> :
|
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, IndexTensorType : TensorStructure<Int>> :
|
||||||
TensorPartialDivisionAlgebra<T, TensorType> {
|
TensorPartialDivisionAlgebra<T, TensorType> {
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
||||||
@ -26,6 +26,6 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
|||||||
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.symeig.html
|
//https://pytorch.org/docs/stable/generated/torch.symeig.html
|
||||||
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType>
|
public fun TensorType.symEig(): Pair<TensorType, TensorType>
|
||||||
|
|
||||||
}
|
}
|
@ -94,8 +94,10 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
return qTensor to rTensor
|
return qTensor to rTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
|
||||||
|
svd(epsilon = 1e-10)
|
||||||
|
|
||||||
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
public fun DoubleTensor.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
val size = this.linearStructure.dim
|
val size = this.linearStructure.dim
|
||||||
val commonShape = this.shape.sliceArray(0 until size - 2)
|
val commonShape = this.shape.sliceArray(0 until size - 2)
|
||||||
val (n, m) = this.shape.sliceArray(size - 2 until size)
|
val (n, m) = this.shape.sliceArray(size - 2 until size)
|
||||||
@ -110,17 +112,20 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
matrix.shape,
|
matrix.shape,
|
||||||
matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + size).toDoubleArray()
|
matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + size).toDoubleArray()
|
||||||
)
|
)
|
||||||
svdHelper(curMatrix, USV, m, n)
|
svdHelper(curMatrix, USV, m, n, epsilon)
|
||||||
}
|
}
|
||||||
return Triple(resU.transpose(), resS, resV.transpose())
|
return Triple(resU.transpose(), resS, resV.transpose())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.symEig(): Pair<DoubleTensor, DoubleTensor> =
|
||||||
|
symEig(epsilon = 1e-15)
|
||||||
|
|
||||||
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
|
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
|
||||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
public fun DoubleTensor.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
|
||||||
checkSymmetric(this)
|
checkSymmetric(this, epsilon)
|
||||||
val (u, s, v) = this.svd()
|
val (u, s, v) = this.svd(epsilon)
|
||||||
val shp = s.shape + intArrayOf(1)
|
val shp = s.shape + intArrayOf(1)
|
||||||
val utv = (u.transpose() dot v).map { if (abs(it) < 0.99) 0.0 else sign(it) }
|
val utv = (u.transpose() dot v).map { if (abs(it) < 0.9) 0.0 else sign(it) }
|
||||||
val eig = (utv dot s.view(shp)).view(s.shape)
|
val eig = (utv dot s.view(shp)).view(s.shape)
|
||||||
return Pair(eig, v)
|
return Pair(eig, v)
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor): Unit =
|
internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, epsilon: Double = 1e-6): Unit =
|
||||||
check(tensor.eq(tensor.transpose())){
|
check(tensor.eq(tensor.transpose(), epsilon)) {
|
||||||
"Tensor is not symmetric about the last 2 dimensions"
|
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||||
}
|
}
|
@ -246,7 +246,7 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svd1d(a: DoubleTensor, epsilon:
|
|||||||
internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
|
internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
|
||||||
matrix: DoubleTensor,
|
matrix: DoubleTensor,
|
||||||
USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>,
|
USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>,
|
||||||
m: Int, n: Int
|
m: Int, n: Int, epsilon: Double
|
||||||
): Unit {
|
): Unit {
|
||||||
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
|
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
|
||||||
val (matrixU, SV) = USV
|
val (matrixU, SV) = USV
|
||||||
@ -267,12 +267,12 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
|
|||||||
var u: DoubleTensor
|
var u: DoubleTensor
|
||||||
var norm: Double
|
var norm: Double
|
||||||
if (n > m) {
|
if (n > m) {
|
||||||
v = svd1d(a)
|
v = svd1d(a, epsilon)
|
||||||
u = matrix.dot(v)
|
u = matrix.dot(v)
|
||||||
norm = DoubleAnalyticTensorAlgebra { (u dot u).sqrt().value() }
|
norm = DoubleAnalyticTensorAlgebra { (u dot u).sqrt().value() }
|
||||||
u = u.times(1.0 / norm)
|
u = u.times(1.0 / norm)
|
||||||
} else {
|
} else {
|
||||||
u = svd1d(a)
|
u = svd1d(a, epsilon)
|
||||||
v = matrix.transpose(0, 1).dot(u)
|
v = matrix.transpose(0, 1).dot(u)
|
||||||
norm = DoubleAnalyticTensorAlgebra { (v dot v).sqrt().value() }
|
norm = DoubleAnalyticTensorAlgebra { (v dot v).sqrt().value() }
|
||||||
v = v.times(1.0 / norm)
|
v = v.times(1.0 / norm)
|
||||||
|
@ -143,7 +143,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
val tensorSigma = tensor + tensor.transpose()
|
val tensorSigma = tensor + tensor.transpose()
|
||||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||||
assertTrue(tensorSigma.eq(tensorSigmaCalc, 0.01))
|
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user