fix NDArray cast
This commit is contained in:
parent
cfd3f3b7e1
commit
47aeb36979
@ -171,7 +171,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
|
|||||||
|
|
||||||
val mt = asMultik().array
|
val mt = asMultik().array
|
||||||
return if (mt.shape.contentEquals(shape)) {
|
return if (mt.shape.contentEquals(shape)) {
|
||||||
(this as MultikTensor<T>).array
|
mt
|
||||||
} else {
|
} else {
|
||||||
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
|
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
|
||||||
}.wrap()
|
}.wrap()
|
||||||
@ -186,7 +186,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
|
|||||||
).asDNArray().wrap()
|
).asDNArray().wrap()
|
||||||
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
|
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
|
||||||
} else if(this.shape.size == 2 && other.shape.size == 1) {
|
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
||||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
|
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
|
||||||
} else {
|
} else {
|
||||||
TODO("Not implemented for broadcasting")
|
TODO("Not implemented for broadcasting")
|
||||||
|
Loading…
Reference in New Issue
Block a user