Add multik dot for tensors

This commit is contained in:
Alexander Nozik 2021-10-19 10:50:13 +03:00
parent 98bbc8349c
commit 40c02f4bd7

View File

@ -3,9 +3,14 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/ */
@file:Suppress("unused")
package space.kscience.kmath.multik package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.linalg.dot
import org.jetbrains.kotlinx.multik.api.mk import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.ndarrayOf
import org.jetbrains.kotlinx.multik.api.zeros import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.* import org.jetbrains.kotlinx.multik.ndarray.operations.*
@ -30,8 +35,20 @@ public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) :
} }
} }
private fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
if (this is NDArray<T, D>)
return this.asD1Array()
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
}
public class MultikTensorAlgebra<T> internal constructor(
private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
if (this is NDArray<T, D>)
return this.asD2Array()
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
}
public class MultikTensorAlgebra<T : Number> internal constructor(
public val type: DataType, public val type: DataType,
public val elementAlgebra: Ring<T>, public val elementAlgebra: Ring<T>,
public val comparator: Comparator<T> public val comparator: Comparator<T>
@ -162,8 +179,17 @@ public class MultikTensorAlgebra<T> internal constructor(
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape) override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> { override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> =
TODO("Not yet implemented") if (this.shape.size == 1 && other.shape.size == 1) {
Multik.ndarrayOf(
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
).asDNArray().wrap()
} else if (this.shape.size == 2 && other.shape.size == 2) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
} else if(this.shape.size == 2 && other.shape.size == 1) {
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
} else {
TODO("Not implemented for broadcasting")
} }
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> { override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {