v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
Showing only changes of commit daa0777182 - Show all commits

View File

@ -34,26 +34,6 @@ internal inline fun <T, TensorType : TensorStructure<T>,
} }
internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
val sa = a.shape
val sb = b.shape
val na = sa.size
val nb = sb.size
var status: Boolean
if (nb == 1) {
status = sa.last() == sb[0]
} else {
status = sa.last() == sb[nb - 2]
if ((na > 2) and (nb > 2)) {
status = status and
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
}
}
check(status) { "Incompatible shapes ${sa.toList()} and ${sb.toList()} provided for dot product" }
}
internal inline fun <T, TensorType : TensorStructure<T>, internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>> TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit = TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =