diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index c229a44bf..ed5575b93 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -3,9 +3,14 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ +@file:Suppress("unused") + package space.kscience.kmath.multik +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.linalg.dot import org.jetbrains.kotlinx.multik.api.mk +import org.jetbrains.kotlinx.multik.api.ndarrayOf import org.jetbrains.kotlinx.multik.api.zeros import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.operations.* @@ -30,8 +35,20 @@ public value class MultikTensor(public val array: MutableMultiArray) : } } +private fun MultiArray.asD1Array(): D1Array { + if (this is NDArray) + return this.asD1Array() + else throw ClassCastException("Cannot cast MultiArray to NDArray.") +} -public class MultikTensorAlgebra internal constructor( + +private fun MultiArray.asD2Array(): D2Array { + if (this is NDArray) + return this.asD2Array() + else throw ClassCastException("Cannot cast MultiArray to NDArray.") +} + +public class MultikTensorAlgebra internal constructor( public val type: DataType, public val elementAlgebra: Ring, public val comparator: Comparator @@ -162,9 +179,18 @@ public class MultikTensorAlgebra internal constructor( override fun Tensor.viewAs(other: Tensor): MultikTensor = view(other.shape) - override fun Tensor.dot(other: Tensor): MultikTensor { - TODO("Not yet implemented") - } + override fun Tensor.dot(other: Tensor): MultikTensor = + if (this.shape.size == 1 && other.shape.size == 1) { + Multik.ndarrayOf( + asMultik().array.asD1Array() dot other.asMultik().array.asD1Array() + ).asDNArray().wrap() + } else if (this.shape.size == 2 && other.shape.size == 2) { + (asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap() + } else if(this.shape.size == 2 && other.shape.size == 1) { + (asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap() + } else { + TODO("Not implemented for broadcasting") + } override fun diagonalEmbedding(diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int): MultikTensor { TODO("Diagonal embedding not implemented")