fix NDArray cast

This commit is contained in:
Alexander Nozik 2021-10-20 17:04:00 +03:00
parent cfd3f3b7e1
commit 47aeb36979

View File

@ -171,7 +171,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
val mt = asMultik().array
return if (mt.shape.contentEquals(shape)) {
(this as MultikTensor<T>).array
mt
} else {
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
}.wrap()
@ -186,7 +186,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
).asDNArray().wrap()
} else if (this.shape.size == 2 && other.shape.size == 2) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
} else if(this.shape.size == 2 && other.shape.size == 1) {
} else if (this.shape.size == 2 && other.shape.size == 1) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
} else {
TODO("Not implemented for broadcasting")