Add multik dot for tensors
This commit is contained in:
parent
98bbc8349c
commit
40c02f4bd7
@ -3,9 +3,14 @@
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
@file:Suppress("unused")
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||
import org.jetbrains.kotlinx.multik.api.zeros
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||
@ -30,8 +35,20 @@ public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) :
|
||||
}
|
||||
}
|
||||
|
||||
private fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
|
||||
if (this is NDArray<T, D>)
|
||||
return this.asD1Array()
|
||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||
}
|
||||
|
||||
public class MultikTensorAlgebra<T> internal constructor(
|
||||
|
||||
private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
||||
if (this is NDArray<T, D>)
|
||||
return this.asD2Array()
|
||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||
}
|
||||
|
||||
public class MultikTensorAlgebra<T : Number> internal constructor(
|
||||
public val type: DataType,
|
||||
public val elementAlgebra: Ring<T>,
|
||||
public val comparator: Comparator<T>
|
||||
@ -162,8 +179,17 @@ public class MultikTensorAlgebra<T> internal constructor(
|
||||
|
||||
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
|
||||
|
||||
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> =
|
||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||
Multik.ndarrayOf(
|
||||
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
|
||||
).asDNArray().wrap()
|
||||
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
|
||||
} else if(this.shape.size == 2 && other.shape.size == 1) {
|
||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
|
||||
} else {
|
||||
TODO("Not implemented for broadcasting")
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||
|
Loading…
Reference in New Issue
Block a user