fix NDArray cast

This commit is contained in:
Alexander Nozik 2021-10-20 17:04:00 +03:00
parent cfd3f3b7e1
commit 47aeb36979

View File

@ -171,7 +171,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
val mt = asMultik().array val mt = asMultik().array
return if (mt.shape.contentEquals(shape)) { return if (mt.shape.contentEquals(shape)) {
(this as MultikTensor<T>).array mt
} else { } else {
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt) NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
}.wrap() }.wrap()
@ -186,7 +186,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
).asDNArray().wrap() ).asDNArray().wrap()
} else if (this.shape.size == 2 && other.shape.size == 2) { } else if (this.shape.size == 2 && other.shape.size == 2) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap() (asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
} else if(this.shape.size == 2 && other.shape.size == 1) { } else if (this.shape.size == 2 && other.shape.size == 1) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap() (asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
} else { } else {
TODO("Not implemented for broadcasting") TODO("Not implemented for broadcasting")