Initial implementation of TensorFlow connector

This commit is contained in:
Alexander Nozik 2021-10-24 18:33:39 +03:00
parent 6c4741ede6
commit cc114041c4
3 changed files with 71 additions and 32 deletions

View File

@ -0,0 +1,40 @@
package space.kscience.kmath.tensorflow
import org.tensorflow.Graph
import org.tensorflow.Output
import org.tensorflow.ndarray.NdArray
import org.tensorflow.ndarray.Shape
import org.tensorflow.op.core.Constant
import org.tensorflow.types.TFloat64
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.tensors.api.Tensor
public class DoubleTensorFlowOutput(
graph: Graph,
output: Output<TFloat64>
) : TensorFlowOutput<Double, TFloat64>(graph, output) {
override fun org.tensorflow.Tensor.actualizeTensor(): NdArray<Double> = output.asTensor()
}
public class DoubleTensorFlowAlgebra internal constructor(
graph: Graph
) : TensorFlowAlgebra<Double, TFloat64>(graph) {
override fun Tensor<Double>.asTensorFlow(): TensorFlowOutput<Double, TFloat64> =
if (this is TensorFlowOutput<Double, *> && output.type() == TFloat64::class.java) {
@Suppress("UNCHECKED_CAST")
this as TensorFlowOutput<Double, TFloat64>
} else {
val res = TFloat64.tensorOf(Shape.of(*shape.toLongArray())) { array ->
@OptIn(PerformancePitfall::class)
elements().forEach { (index, value) ->
array.setDouble(value, *index.toLongArray())
}
}
DoubleTensorFlowOutput(graph, ops.constant(res).asOutput())
}
override fun Output<TFloat64>.wrap(): TensorFlowOutput<Double, TFloat64> = DoubleTensorFlowOutput(graph, this)
override fun const(value: Double): Constant<TFloat64> = ops.constant(value)
}

View File

@ -14,11 +14,10 @@ import space.kscience.kmath.nd.Shape
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
private fun IntArray.toLongArray() = LongArray(size) { get(it).toLong() } internal fun IntArray.toLongArray() = LongArray(size) { get(it).toLong() }
private fun LongArray.toIntArray() = IntArray(size) { get(it).toInt() } internal fun LongArray.toIntArray() = IntArray(size) { get(it).toInt() }
private val <T> NdArray<T>.scalar: T internal val <T> NdArray<T>.scalar: T get() = getObject()
get() = getObject()
public sealed interface TensorFlowTensor<T> : Tensor<T> public sealed interface TensorFlowTensor<T> : Tensor<T>
@ -72,14 +71,12 @@ public abstract class TensorFlowOutput<T, TT : TType>(
public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor( public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
private val graph: Graph protected val graph: Graph
) : TensorAlgebra<T> { ) : TensorAlgebra<T> {
private val ops by lazy { Ops.create(graph) } protected val ops: Ops by lazy { Ops.create(graph) }
protected fun Tensor<T>.asTensorFlow(): TensorFlowOutput<T, TT> = if (this is TensorFlowOutput<T, TT>) this else { protected abstract fun Tensor<T>.asTensorFlow(): TensorFlowOutput<T, TT>
TODO()
}
protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT> protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT>
@ -138,27 +135,29 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
private inline fun unOp(value: Tensor<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> = private inline fun unOp(value: Tensor<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
operation(value.asTensorFlow().output).asOutput().wrap() operation(value.asTensorFlow().output).asOutput().wrap()
override fun T.plus(other: Tensor<T>) = biOp(other, ops.math::add) override fun T.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add)
override fun Tensor<T>.plus(value: T) = biOp(value, ops.math::add) override fun Tensor<T>.plus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::add)
override fun Tensor<T>.plus(other: Tensor<T>) = biOp(other, ops.math::add) override fun Tensor<T>.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add)
override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add) override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add)
override fun Tensor<T>.plusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::add) override fun Tensor<T>.plusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::add)
override fun Tensor<T>.minus(value: T) = biOp(value, ops.math::sub) override fun Tensor<T>.minus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::sub)
override fun Tensor<T>.minus(other: Tensor<T>) = biOp(other, ops.math::sub) override fun Tensor<T>.minus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::sub)
override fun T.minus(other: Tensor<T>): Tensor<T> = biOp(other, ops.math::sub)
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub) override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
override fun Tensor<T>.minusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::sub) override fun Tensor<T>.minusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::sub)
override fun T.times(other: Tensor<T>) = biOp(other, ops.math::mul) override fun T.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
override fun Tensor<T>.times(value: T) = biOp(value, ops.math::mul) override fun Tensor<T>.times(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::mul)
override fun Tensor<T>.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul) override fun Tensor<T>.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
@ -166,14 +165,14 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
override fun Tensor<T>.timesAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::mul) override fun Tensor<T>.timesAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::mul)
override fun Tensor<T>.unaryMinus() = unOp(this, ops.math::neg) override fun Tensor<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(this, ops.math::neg)
override fun Tensor<T>.get(i: Int): Tensor<T> { override fun Tensor<T>.get(i: Int): Tensor<T> {
ops. TODO("Not yet implemented")
} }
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> { override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = unOp(this) {
TODO("Not yet implemented") ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
} }
override fun Tensor<T>.view(shape: IntArray): Tensor<T> { override fun Tensor<T>.view(shape: IntArray): Tensor<T> {
@ -184,15 +183,15 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.dot(other: Tensor<T>) = biOp(other, ops.math.) override fun Tensor<T>.dot(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
ops.linalg.matMul(l, r)
}
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run { override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.sum(): T { override fun Tensor<T>.sum(): T = TODO("Not yet implemented")
TODO("Not yet implemented")
}
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> { override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")

View File

@ -361,16 +361,16 @@ public open class DoubleTensorAlgebra :
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n) dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
} }
if (penultimateDim) { return if (penultimateDim) {
return resTensor.view( resTensor.view(
resTensor.shape.dropLast(2).toIntArray() + resTensor.shape.dropLast(2).toIntArray() +
intArrayOf(resTensor.shape.last()) intArrayOf(resTensor.shape.last())
) )
} else if (lastDim) {
resTensor.view(resTensor.shape.dropLast(1).toIntArray())
} else {
resTensor
} }
if (lastDim) {
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
}
return resTensor
} }
override fun diagonalEmbedding( override fun diagonalEmbedding(