Initial implementation of TensorFlow connector
This commit is contained in:
parent
6c4741ede6
commit
cc114041c4
@ -0,0 +1,40 @@
|
|||||||
|
package space.kscience.kmath.tensorflow
|
||||||
|
|
||||||
|
import org.tensorflow.Graph
|
||||||
|
import org.tensorflow.Output
|
||||||
|
import org.tensorflow.ndarray.NdArray
|
||||||
|
import org.tensorflow.ndarray.Shape
|
||||||
|
import org.tensorflow.op.core.Constant
|
||||||
|
import org.tensorflow.types.TFloat64
|
||||||
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
|
||||||
|
public class DoubleTensorFlowOutput(
|
||||||
|
graph: Graph,
|
||||||
|
output: Output<TFloat64>
|
||||||
|
) : TensorFlowOutput<Double, TFloat64>(graph, output) {
|
||||||
|
override fun org.tensorflow.Tensor.actualizeTensor(): NdArray<Double> = output.asTensor()
|
||||||
|
}
|
||||||
|
|
||||||
|
public class DoubleTensorFlowAlgebra internal constructor(
|
||||||
|
graph: Graph
|
||||||
|
) : TensorFlowAlgebra<Double, TFloat64>(graph) {
|
||||||
|
|
||||||
|
override fun Tensor<Double>.asTensorFlow(): TensorFlowOutput<Double, TFloat64> =
|
||||||
|
if (this is TensorFlowOutput<Double, *> && output.type() == TFloat64::class.java) {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
this as TensorFlowOutput<Double, TFloat64>
|
||||||
|
} else {
|
||||||
|
val res = TFloat64.tensorOf(Shape.of(*shape.toLongArray())) { array ->
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
|
elements().forEach { (index, value) ->
|
||||||
|
array.setDouble(value, *index.toLongArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DoubleTensorFlowOutput(graph, ops.constant(res).asOutput())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Output<TFloat64>.wrap(): TensorFlowOutput<Double, TFloat64> = DoubleTensorFlowOutput(graph, this)
|
||||||
|
|
||||||
|
override fun const(value: Double): Constant<TFloat64> = ops.constant(value)
|
||||||
|
}
|
@ -14,11 +14,10 @@ import space.kscience.kmath.nd.Shape
|
|||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
|
|
||||||
private fun IntArray.toLongArray() = LongArray(size) { get(it).toLong() }
|
internal fun IntArray.toLongArray() = LongArray(size) { get(it).toLong() }
|
||||||
private fun LongArray.toIntArray() = IntArray(size) { get(it).toInt() }
|
internal fun LongArray.toIntArray() = IntArray(size) { get(it).toInt() }
|
||||||
|
|
||||||
private val <T> NdArray<T>.scalar: T
|
internal val <T> NdArray<T>.scalar: T get() = getObject()
|
||||||
get() = getObject()
|
|
||||||
|
|
||||||
|
|
||||||
public sealed interface TensorFlowTensor<T> : Tensor<T>
|
public sealed interface TensorFlowTensor<T> : Tensor<T>
|
||||||
@ -72,14 +71,12 @@ public abstract class TensorFlowOutput<T, TT : TType>(
|
|||||||
|
|
||||||
|
|
||||||
public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
||||||
private val graph: Graph
|
protected val graph: Graph
|
||||||
) : TensorAlgebra<T> {
|
) : TensorAlgebra<T> {
|
||||||
|
|
||||||
private val ops by lazy { Ops.create(graph) }
|
protected val ops: Ops by lazy { Ops.create(graph) }
|
||||||
|
|
||||||
protected fun Tensor<T>.asTensorFlow(): TensorFlowOutput<T, TT> = if (this is TensorFlowOutput<T, TT>) this else {
|
protected abstract fun Tensor<T>.asTensorFlow(): TensorFlowOutput<T, TT>
|
||||||
TODO()
|
|
||||||
}
|
|
||||||
|
|
||||||
protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT>
|
protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT>
|
||||||
|
|
||||||
@ -138,27 +135,29 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
|||||||
private inline fun unOp(value: Tensor<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
private inline fun unOp(value: Tensor<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
||||||
operation(value.asTensorFlow().output).asOutput().wrap()
|
operation(value.asTensorFlow().output).asOutput().wrap()
|
||||||
|
|
||||||
override fun T.plus(other: Tensor<T>) = biOp(other, ops.math::add)
|
override fun T.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add)
|
||||||
|
|
||||||
override fun Tensor<T>.plus(value: T) = biOp(value, ops.math::add)
|
override fun Tensor<T>.plus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::add)
|
||||||
|
|
||||||
override fun Tensor<T>.plus(other: Tensor<T>) = biOp(other, ops.math::add)
|
override fun Tensor<T>.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add)
|
||||||
|
|
||||||
override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add)
|
override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add)
|
||||||
|
|
||||||
override fun Tensor<T>.plusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::add)
|
override fun Tensor<T>.plusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::add)
|
||||||
|
|
||||||
override fun Tensor<T>.minus(value: T) = biOp(value, ops.math::sub)
|
override fun Tensor<T>.minus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::sub)
|
||||||
|
|
||||||
override fun Tensor<T>.minus(other: Tensor<T>) = biOp(other, ops.math::sub)
|
override fun Tensor<T>.minus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::sub)
|
||||||
|
|
||||||
|
override fun T.minus(other: Tensor<T>): Tensor<T> = biOp(other, ops.math::sub)
|
||||||
|
|
||||||
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
|
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
|
||||||
|
|
||||||
override fun Tensor<T>.minusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::sub)
|
override fun Tensor<T>.minusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::sub)
|
||||||
|
|
||||||
override fun T.times(other: Tensor<T>) = biOp(other, ops.math::mul)
|
override fun T.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
|
||||||
|
|
||||||
override fun Tensor<T>.times(value: T) = biOp(value, ops.math::mul)
|
override fun Tensor<T>.times(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::mul)
|
||||||
|
|
||||||
override fun Tensor<T>.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
|
override fun Tensor<T>.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
|
||||||
|
|
||||||
@ -166,14 +165,14 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
|||||||
|
|
||||||
override fun Tensor<T>.timesAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::mul)
|
override fun Tensor<T>.timesAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::mul)
|
||||||
|
|
||||||
override fun Tensor<T>.unaryMinus() = unOp(this, ops.math::neg)
|
override fun Tensor<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(this, ops.math::neg)
|
||||||
|
|
||||||
override fun Tensor<T>.get(i: Int): Tensor<T>{
|
override fun Tensor<T>.get(i: Int): Tensor<T> {
|
||||||
ops.
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> {
|
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = unOp(this) {
|
||||||
TODO("Not yet implemented")
|
ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.view(shape: IntArray): Tensor<T> {
|
override fun Tensor<T>.view(shape: IntArray): Tensor<T> {
|
||||||
@ -184,15 +183,15 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.dot(other: Tensor<T>) = biOp(other, ops.math.)
|
override fun Tensor<T>.dot(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
|
||||||
|
ops.linalg.matMul(l, r)
|
||||||
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run {
|
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.sum(): T {
|
override fun Tensor<T>.sum(): T = TODO("Not yet implemented")
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> {
|
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
|
@ -22,7 +22,7 @@ import kotlin.math.*
|
|||||||
public open class DoubleTensorAlgebra :
|
public open class DoubleTensorAlgebra :
|
||||||
TensorPartialDivisionAlgebra<Double>,
|
TensorPartialDivisionAlgebra<Double>,
|
||||||
AnalyticTensorAlgebra<Double>,
|
AnalyticTensorAlgebra<Double>,
|
||||||
LinearOpsTensorAlgebra<Double>{
|
LinearOpsTensorAlgebra<Double> {
|
||||||
|
|
||||||
public companion object : DoubleTensorAlgebra()
|
public companion object : DoubleTensorAlgebra()
|
||||||
|
|
||||||
@ -361,16 +361,16 @@ public open class DoubleTensorAlgebra :
|
|||||||
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (penultimateDim) {
|
return if (penultimateDim) {
|
||||||
return resTensor.view(
|
resTensor.view(
|
||||||
resTensor.shape.dropLast(2).toIntArray() +
|
resTensor.shape.dropLast(2).toIntArray() +
|
||||||
intArrayOf(resTensor.shape.last())
|
intArrayOf(resTensor.shape.last())
|
||||||
)
|
)
|
||||||
|
} else if (lastDim) {
|
||||||
|
resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
||||||
|
} else {
|
||||||
|
resTensor
|
||||||
}
|
}
|
||||||
if (lastDim) {
|
|
||||||
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
|
||||||
}
|
|
||||||
return resTensor
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
|
Loading…
Reference in New Issue
Block a user