forked from kscience/kmath
fixes
This commit is contained in:
parent
1b1a078dea
commit
14ca7cdd31
@ -23,8 +23,8 @@ public interface LinearOpsTensorAlgebra<T> :
|
||||
|
||||
/**
|
||||
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
|
||||
* Given a square matrix `a`, return the matrix `aInv` satisfying
|
||||
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
|
||||
* Given a square matrix `A`, return the matrix `AInv` satisfying
|
||||
* `A dot AInv = AInv dot A = eye(a.shape[0])`.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
||||
*
|
||||
* @return the multiplicative inverse of a matrix.
|
||||
@ -37,7 +37,7 @@ public interface LinearOpsTensorAlgebra<T> :
|
||||
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
|
||||
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
|
||||
* Each decomposition has the form:
|
||||
* Given a tensor `input`, return the tensor `L` satisfying ``input = L * L.H``,
|
||||
* Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`,
|
||||
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
|
||||
* which is just a transpose for the case of real-valued input matrices.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
|
||||
@ -50,7 +50,7 @@ public interface LinearOpsTensorAlgebra<T> :
|
||||
* QR decomposition.
|
||||
*
|
||||
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
|
||||
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
|
||||
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``,
|
||||
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
|
||||
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||
@ -63,7 +63,7 @@ public interface LinearOpsTensorAlgebra<T> :
|
||||
* LUP decomposition
|
||||
*
|
||||
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
||||
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
|
||||
* Given a tensor `input`, return tensors (P, L, U) satisfying `P dot input = L dot U`,
|
||||
* with `P` being a permutation matrix or batch of matrices,
|
||||
* `L` being a lower triangular matrix or batch of matrices,
|
||||
* `U` being an upper triangular matrix or batch of matrices.
|
||||
@ -77,7 +77,8 @@ public interface LinearOpsTensorAlgebra<T> :
|
||||
*
|
||||
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
|
||||
* The singular value decomposition is represented as a triple `(U, S, V)`,
|
||||
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
|
||||
* such that `input = U dot diagonalEmbedding(S) dot V.H`,
|
||||
* where V.H is the conjugate transpose of V.
|
||||
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
||||
*
|
||||
@ -94,4 +95,4 @@ public interface LinearOpsTensorAlgebra<T> :
|
||||
*/
|
||||
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -288,7 +288,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the maximum value of all elements in the input tensor.
|
||||
* Returns the maximum value of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.max(): T
|
||||
|
||||
|
@ -343,7 +343,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
||||
val m2 = newOther.shape[newOther.shape.size - 2]
|
||||
val n = newOther.shape[newOther.shape.size - 1]
|
||||
check(m1 == m2) {
|
||||
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)")
|
||||
"Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)"
|
||||
}
|
||||
|
||||
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
|
||||
@ -436,9 +436,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
||||
* @param epsilon permissible error when comparing two Double values.
|
||||
* @return true if two tensors have the same shape and elements, false otherwise.
|
||||
*/
|
||||
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean {
|
||||
return tensor.eq(other) { x, y -> abs(x - y) < epsilon }
|
||||
}
|
||||
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
|
||||
tensor.eq(other) { x, y -> abs(x - y) < epsilon }
|
||||
|
||||
/**
|
||||
* Compares element-wise two tensors.
|
||||
@ -510,7 +509,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Build tensor from rows of input tensor
|
||||
* Builds tensor from rows of input tensor
|
||||
*
|
||||
* @param indices the [IntArray] of 1-dimensional indices
|
||||
* @return tensor with rows corresponding to rows by [indices]
|
||||
|
@ -23,7 +23,6 @@ internal fun stridesFromShape(shape: IntArray): IntArray {
|
||||
current--
|
||||
}
|
||||
return res
|
||||
|
||||
}
|
||||
|
||||
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
@ -55,4 +54,4 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides {
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -33,4 +33,4 @@ public fun IntTensor.toIntArray(): IntArray {
|
||||
return IntArray(numElements) { i ->
|
||||
mutableBuffer[bufferStart + i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -39,4 +39,4 @@ internal val Tensor<Int>.tensor: IntTensor
|
||||
get() = when (this) {
|
||||
is IntTensor -> this
|
||||
else -> this.toBufferedTensor().asTensor()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user