forked from kscience/kmath
Add multik dot for tensors
This commit is contained in:
parent
98bbc8349c
commit
40c02f4bd7
@ -3,9 +3,14 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:Suppress("unused")
|
||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
|
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
||||||
import org.jetbrains.kotlinx.multik.api.mk
|
import org.jetbrains.kotlinx.multik.api.mk
|
||||||
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
import org.jetbrains.kotlinx.multik.api.zeros
|
import org.jetbrains.kotlinx.multik.api.zeros
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
@ -30,8 +35,20 @@ public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
|
||||||
|
if (this is NDArray<T, D>)
|
||||||
|
return this.asD1Array()
|
||||||
|
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||||
|
}
|
||||||
|
|
||||||
public class MultikTensorAlgebra<T> internal constructor(
|
|
||||||
|
private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
||||||
|
if (this is NDArray<T, D>)
|
||||||
|
return this.asD2Array()
|
||||||
|
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||||
|
}
|
||||||
|
|
||||||
|
public class MultikTensorAlgebra<T : Number> internal constructor(
|
||||||
public val type: DataType,
|
public val type: DataType,
|
||||||
public val elementAlgebra: Ring<T>,
|
public val elementAlgebra: Ring<T>,
|
||||||
public val comparator: Comparator<T>
|
public val comparator: Comparator<T>
|
||||||
@ -162,9 +179,18 @@ public class MultikTensorAlgebra<T> internal constructor(
|
|||||||
|
|
||||||
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
|
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
|
||||||
|
|
||||||
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
|
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> =
|
||||||
TODO("Not yet implemented")
|
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||||
}
|
Multik.ndarrayOf(
|
||||||
|
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
|
||||||
|
).asDNArray().wrap()
|
||||||
|
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||||
|
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
|
||||||
|
} else if(this.shape.size == 2 && other.shape.size == 1) {
|
||||||
|
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
|
||||||
|
} else {
|
||||||
|
TODO("Not implemented for broadcasting")
|
||||||
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||||
TODO("Diagonal embedding not implemented")
|
TODO("Diagonal embedding not implemented")
|
||||||
|
Loading…
Reference in New Issue
Block a user