v0.3.0-dev-9 #324
@ -34,26 +34,6 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
}
|
||||
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||
val sa = a.shape
|
||||
val sb = b.shape
|
||||
val na = sa.size
|
||||
val nb = sb.size
|
||||
var status: Boolean
|
||||
if (nb == 1) {
|
||||
status = sa.last() == sb[0]
|
||||
} else {
|
||||
status = sa.last() == sb[nb - 2]
|
||||
if ((na > 2) and (nb > 2)) {
|
||||
status = status and
|
||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||
}
|
||||
}
|
||||
check(status) { "Incompatible shapes ${sa.toList()} and ${sb.toList()} provided for dot product" }
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
|
Loading…
Reference in New Issue
Block a user