v0.3.0-dev-18 #459

Merged
altavir merged 64 commits from dev into master 2022-02-13 17:50:34 +03:00
Showing only changes of commit 47aeb36979 - Show all commits

View File

@ -171,7 +171,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
val mt = asMultik().array val mt = asMultik().array
return if (mt.shape.contentEquals(shape)) { return if (mt.shape.contentEquals(shape)) {
(this as MultikTensor<T>).array mt
} else { } else {
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt) NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
}.wrap() }.wrap()
@ -186,7 +186,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
).asDNArray().wrap() ).asDNArray().wrap()
} else if (this.shape.size == 2 && other.shape.size == 2) { } else if (this.shape.size == 2 && other.shape.size == 2) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap() (asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
} else if(this.shape.size == 2 && other.shape.size == 1) { } else if (this.shape.size == 2 && other.shape.size == 1) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap() (asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
} else { } else {
TODO("Not implemented for broadcasting") TODO("Not implemented for broadcasting")