Initial implementation of TensorFlow connector
This commit is contained in:
parent
6c4741ede6
commit
cc114041c4
@ -0,0 +1,40 @@
|
||||
package space.kscience.kmath.tensorflow
|
||||
|
||||
import org.tensorflow.Graph
|
||||
import org.tensorflow.Output
|
||||
import org.tensorflow.ndarray.NdArray
|
||||
import org.tensorflow.ndarray.Shape
|
||||
import org.tensorflow.op.core.Constant
|
||||
import org.tensorflow.types.TFloat64
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
|
||||
public class DoubleTensorFlowOutput(
|
||||
graph: Graph,
|
||||
output: Output<TFloat64>
|
||||
) : TensorFlowOutput<Double, TFloat64>(graph, output) {
|
||||
override fun org.tensorflow.Tensor.actualizeTensor(): NdArray<Double> = output.asTensor()
|
||||
}
|
||||
|
||||
public class DoubleTensorFlowAlgebra internal constructor(
|
||||
graph: Graph
|
||||
) : TensorFlowAlgebra<Double, TFloat64>(graph) {
|
||||
|
||||
override fun Tensor<Double>.asTensorFlow(): TensorFlowOutput<Double, TFloat64> =
|
||||
if (this is TensorFlowOutput<Double, *> && output.type() == TFloat64::class.java) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
this as TensorFlowOutput<Double, TFloat64>
|
||||
} else {
|
||||
val res = TFloat64.tensorOf(Shape.of(*shape.toLongArray())) { array ->
|
||||
@OptIn(PerformancePitfall::class)
|
||||
elements().forEach { (index, value) ->
|
||||
array.setDouble(value, *index.toLongArray())
|
||||
}
|
||||
}
|
||||
DoubleTensorFlowOutput(graph, ops.constant(res).asOutput())
|
||||
}
|
||||
|
||||
override fun Output<TFloat64>.wrap(): TensorFlowOutput<Double, TFloat64> = DoubleTensorFlowOutput(graph, this)
|
||||
|
||||
override fun const(value: Double): Constant<TFloat64> = ops.constant(value)
|
||||
}
|
@ -14,11 +14,10 @@ import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
|
||||
private fun IntArray.toLongArray() = LongArray(size) { get(it).toLong() }
|
||||
private fun LongArray.toIntArray() = IntArray(size) { get(it).toInt() }
|
||||
internal fun IntArray.toLongArray() = LongArray(size) { get(it).toLong() }
|
||||
internal fun LongArray.toIntArray() = IntArray(size) { get(it).toInt() }
|
||||
|
||||
private val <T> NdArray<T>.scalar: T
|
||||
get() = getObject()
|
||||
internal val <T> NdArray<T>.scalar: T get() = getObject()
|
||||
|
||||
|
||||
public sealed interface TensorFlowTensor<T> : Tensor<T>
|
||||
@ -72,14 +71,12 @@ public abstract class TensorFlowOutput<T, TT : TType>(
|
||||
|
||||
|
||||
public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
||||
private val graph: Graph
|
||||
protected val graph: Graph
|
||||
) : TensorAlgebra<T> {
|
||||
|
||||
private val ops by lazy { Ops.create(graph) }
|
||||
protected val ops: Ops by lazy { Ops.create(graph) }
|
||||
|
||||
protected fun Tensor<T>.asTensorFlow(): TensorFlowOutput<T, TT> = if (this is TensorFlowOutput<T, TT>) this else {
|
||||
TODO()
|
||||
}
|
||||
protected abstract fun Tensor<T>.asTensorFlow(): TensorFlowOutput<T, TT>
|
||||
|
||||
protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT>
|
||||
|
||||
@ -138,27 +135,29 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
||||
private inline fun unOp(value: Tensor<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
||||
operation(value.asTensorFlow().output).asOutput().wrap()
|
||||
|
||||
override fun T.plus(other: Tensor<T>) = biOp(other, ops.math::add)
|
||||
override fun T.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add)
|
||||
|
||||
override fun Tensor<T>.plus(value: T) = biOp(value, ops.math::add)
|
||||
override fun Tensor<T>.plus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::add)
|
||||
|
||||
override fun Tensor<T>.plus(other: Tensor<T>) = biOp(other, ops.math::add)
|
||||
override fun Tensor<T>.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add)
|
||||
|
||||
override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add)
|
||||
|
||||
override fun Tensor<T>.plusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::add)
|
||||
|
||||
override fun Tensor<T>.minus(value: T) = biOp(value, ops.math::sub)
|
||||
override fun Tensor<T>.minus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::sub)
|
||||
|
||||
override fun Tensor<T>.minus(other: Tensor<T>) = biOp(other, ops.math::sub)
|
||||
override fun Tensor<T>.minus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::sub)
|
||||
|
||||
override fun T.minus(other: Tensor<T>): Tensor<T> = biOp(other, ops.math::sub)
|
||||
|
||||
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
|
||||
|
||||
override fun Tensor<T>.minusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::sub)
|
||||
|
||||
override fun T.times(other: Tensor<T>) = biOp(other, ops.math::mul)
|
||||
override fun T.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
|
||||
|
||||
override fun Tensor<T>.times(value: T) = biOp(value, ops.math::mul)
|
||||
override fun Tensor<T>.times(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::mul)
|
||||
|
||||
override fun Tensor<T>.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
|
||||
|
||||
@ -166,14 +165,14 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
||||
|
||||
override fun Tensor<T>.timesAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::mul)
|
||||
|
||||
override fun Tensor<T>.unaryMinus() = unOp(this, ops.math::neg)
|
||||
override fun Tensor<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(this, ops.math::neg)
|
||||
|
||||
override fun Tensor<T>.get(i: Int): Tensor<T>{
|
||||
ops.
|
||||
override fun Tensor<T>.get(i: Int): Tensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = unOp(this) {
|
||||
ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
|
||||
}
|
||||
|
||||
override fun Tensor<T>.view(shape: IntArray): Tensor<T> {
|
||||
@ -184,15 +183,15 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.dot(other: Tensor<T>) = biOp(other, ops.math.)
|
||||
override fun Tensor<T>.dot(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
|
||||
ops.linalg.matMul(l, r)
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.sum(): T {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
override fun Tensor<T>.sum(): T = TODO("Not yet implemented")
|
||||
|
||||
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
|
@ -22,7 +22,7 @@ import kotlin.math.*
|
||||
public open class DoubleTensorAlgebra :
|
||||
TensorPartialDivisionAlgebra<Double>,
|
||||
AnalyticTensorAlgebra<Double>,
|
||||
LinearOpsTensorAlgebra<Double>{
|
||||
LinearOpsTensorAlgebra<Double> {
|
||||
|
||||
public companion object : DoubleTensorAlgebra()
|
||||
|
||||
@ -361,16 +361,16 @@ public open class DoubleTensorAlgebra :
|
||||
dotHelper(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
||||
}
|
||||
|
||||
if (penultimateDim) {
|
||||
return resTensor.view(
|
||||
return if (penultimateDim) {
|
||||
resTensor.view(
|
||||
resTensor.shape.dropLast(2).toIntArray() +
|
||||
intArrayOf(resTensor.shape.last())
|
||||
)
|
||||
} else if (lastDim) {
|
||||
resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
||||
} else {
|
||||
resTensor
|
||||
}
|
||||
if (lastDim) {
|
||||
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
||||
}
|
||||
return resTensor
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
|
Loading…
Reference in New Issue
Block a user