tensors dimensions changed
This commit is contained in:
parent
13d063c1b1
commit
7749b72f24
@ -886,9 +886,9 @@ public open class DoubleTensorAlgebra :
|
|||||||
val size = tensor.dimension
|
val size = tensor.dimension
|
||||||
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
||||||
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
||||||
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
|
val uTensor = zeros(commonShape + intArrayOf(m, n))
|
||||||
val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
|
val sTensor = zeros(commonShape + intArrayOf(m))
|
||||||
val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
|
val vTensor = zeros(commonShape + intArrayOf(m, m))
|
||||||
|
|
||||||
val matrices = tensor.matrices
|
val matrices = tensor.matrices
|
||||||
val uTensors = uTensor.matrices
|
val uTensors = uTensor.matrices
|
||||||
|
Loading…
Reference in New Issue
Block a user