v0.3.0-dev-9 #324
@ -34,26 +34,6 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
|
||||||
val sa = a.shape
|
|
||||||
val sb = b.shape
|
|
||||||
val na = sa.size
|
|
||||||
val nb = sb.size
|
|
||||||
var status: Boolean
|
|
||||||
if (nb == 1) {
|
|
||||||
status = sa.last() == sb[0]
|
|
||||||
} else {
|
|
||||||
status = sa.last() == sb[nb - 2]
|
|
||||||
if ((na > 2) and (nb > 2)) {
|
|
||||||
status = status and
|
|
||||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
check(status) { "Incompatible shapes ${sa.toList()} and ${sb.toList()} provided for dot product" }
|
|
||||||
}
|
|
||||||
|
|
||||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||||
|
Loading…
Reference in New Issue
Block a user