|
|
|
@ -8,8 +8,13 @@ import org.tensorflow.Session
|
|
|
|
|
import org.tensorflow.ndarray.NdArray
|
|
|
|
|
import org.tensorflow.op.Ops
|
|
|
|
|
import org.tensorflow.op.core.Constant
|
|
|
|
|
import org.tensorflow.op.core.Max
|
|
|
|
|
import org.tensorflow.op.core.Min
|
|
|
|
|
import org.tensorflow.op.core.Sum
|
|
|
|
|
import org.tensorflow.types.TInt32
|
|
|
|
|
import org.tensorflow.types.family.TType
|
|
|
|
|
import space.kscience.kmath.misc.PerformancePitfall
|
|
|
|
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
|
|
|
import space.kscience.kmath.nd.Shape
|
|
|
|
|
import space.kscience.kmath.nd.StructureND
|
|
|
|
|
import space.kscience.kmath.operations.Ring
|
|
|
|
@ -38,8 +43,8 @@ public value class TensorFlowArray<T>(public val tensor: NdArray<T>) : Tensor<T>
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public abstract class TensorFlowOutput<T, TT : TType>(
|
|
|
|
|
private val graph: Graph,
|
|
|
|
|
output: Output<TT>
|
|
|
|
|
protected val graph: Graph,
|
|
|
|
|
output: Output<TT>,
|
|
|
|
|
) : TensorFlowTensor<T> {
|
|
|
|
|
|
|
|
|
|
public var output: Output<TT> = output
|
|
|
|
@ -49,10 +54,11 @@ public abstract class TensorFlowOutput<T, TT : TType>(
|
|
|
|
|
|
|
|
|
|
protected abstract fun org.tensorflow.Tensor.actualizeTensor(): NdArray<T>
|
|
|
|
|
|
|
|
|
|
private val actualTensor by lazy {
|
|
|
|
|
val session = Session(graph)
|
|
|
|
|
internal val actualTensor by lazy {
|
|
|
|
|
Session(graph).use { session ->
|
|
|
|
|
TensorFlowArray(session.runner().fetch(output).run().first().actualizeTensor())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun get(index: IntArray): T = actualTensor[index]
|
|
|
|
|
|
|
|
|
@ -66,9 +72,9 @@ public abstract class TensorFlowOutput<T, TT : TType>(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal constructor(
|
|
|
|
|
protected val graph: Graph
|
|
|
|
|
) : TensorAlgebra<T,A> {
|
|
|
|
|
public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal constructor(
|
|
|
|
|
protected val graph: Graph,
|
|
|
|
|
) : TensorAlgebra<T, A> {
|
|
|
|
|
|
|
|
|
|
protected val ops: Ops by lazy { Ops.create(graph) }
|
|
|
|
|
|
|
|
|
@ -83,7 +89,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|
|
|
|
|
|
|
|
|
private inline fun StructureND<T>.biOp(
|
|
|
|
|
other: StructureND<T>,
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
|
|
|
|
): TensorFlowOutput<T, TT> {
|
|
|
|
|
val left = asTensorFlow().output
|
|
|
|
|
val right = other.asTensorFlow().output
|
|
|
|
@ -92,7 +98,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|
|
|
|
|
|
|
|
|
private inline fun T.biOp(
|
|
|
|
|
other: StructureND<T>,
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
|
|
|
|
): TensorFlowOutput<T, TT> {
|
|
|
|
|
val left = const(this)
|
|
|
|
|
val right = other.asTensorFlow().output
|
|
|
|
@ -101,7 +107,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|
|
|
|
|
|
|
|
|
private inline fun StructureND<T>.biOp(
|
|
|
|
|
value: T,
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
|
|
|
|
): TensorFlowOutput<T, TT> {
|
|
|
|
|
val left = asTensorFlow().output
|
|
|
|
|
val right = const(value)
|
|
|
|
@ -110,7 +116,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|
|
|
|
|
|
|
|
|
private inline fun Tensor<T>.inPlaceOp(
|
|
|
|
|
other: StructureND<T>,
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
|
|
|
|
): Unit {
|
|
|
|
|
val origin = asTensorFlow()
|
|
|
|
|
val left = origin.output
|
|
|
|
@ -120,7 +126,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|
|
|
|
|
|
|
|
|
private inline fun Tensor<T>.inPlaceOp(
|
|
|
|
|
value: T,
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
|
|
|
|
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
|
|
|
|
): Unit {
|
|
|
|
|
val origin = asTensorFlow()
|
|
|
|
|
val left = origin.output
|
|
|
|
@ -128,8 +134,8 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|
|
|
|
origin.output = operation(left, right).asOutput()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private inline fun unOp(value: StructureND<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
|
|
|
|
operation(value.asTensorFlow().output).asOutput().wrap()
|
|
|
|
|
private inline fun StructureND<T>.unOp(operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
|
|
|
|
operation(asTensorFlow().output).asOutput().wrap()
|
|
|
|
|
|
|
|
|
|
override fun T.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
|
|
|
|
|
|
|
|
|
@ -149,67 +155,79 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|
|
|
|
|
|
|
|
|
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
|
|
|
|
|
|
|
|
|
|
override fun Tensor<T>.minusAssign(other: StructureND<T>): Unit = inPlaceOp(other, ops.math::sub)
|
|
|
|
|
override fun Tensor<T>.minusAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::sub)
|
|
|
|
|
|
|
|
|
|
override fun T.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.times(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.times(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
|
|
|
|
|
override fun StructureND<T>.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
|
|
|
|
|
|
|
|
|
override fun Tensor<T>.timesAssign(value: T): Unit = inPlaceOp(value, ops.math::mul)
|
|
|
|
|
|
|
|
|
|
override fun Tensor<T>.timesAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::mul)
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(this, ops.math::neg)
|
|
|
|
|
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(ops.math::neg)
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.get(i: Int): Tensor<T> {
|
|
|
|
|
override fun Tensor<T>.get(i: Int): Tensor<T> = unOp {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.transpose(i: Int, j: Int): Tensor<T> = unOp(this) {
|
|
|
|
|
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = unOp {
|
|
|
|
|
ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun Tensor<T>.view(shape: IntArray): Tensor<T> {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = unOp {
|
|
|
|
|
ops.reshape(it, ops.constant(shape))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun Tensor<T>.viewAs(other: StructureND<T>): Tensor<T> {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
override fun Tensor<T>.viewAs(other: StructureND<T>): Tensor<T> = biOp(other) { l, r ->
|
|
|
|
|
ops.reshape(l, ops.shape(r))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
|
|
|
|
|
ops.linalg.matMul(l, r)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
override fun diagonalEmbedding(
|
|
|
|
|
diagonalEntries: Tensor<T>,
|
|
|
|
|
offset: Int,
|
|
|
|
|
dim1: Int,
|
|
|
|
|
dim2: Int,
|
|
|
|
|
): TensorFlowOutput<T, TT> = diagonalEntries.unOp {
|
|
|
|
|
TODO()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.sum(): T = TODO("Not yet implemented")
|
|
|
|
|
override fun StructureND<T>.sum(): T = unOp {
|
|
|
|
|
ops.sum(it, ops.constant(intArrayOf()))
|
|
|
|
|
}.value()
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): TensorFlowOutput<T, TT> = unOp {
|
|
|
|
|
ops.sum(it, ops.constant(dim), Sum.keepDims(keepDim))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.min(): T {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
override fun StructureND<T>.min(): T = unOp {
|
|
|
|
|
ops.min(it, ops.constant(intArrayOf()))
|
|
|
|
|
}.value()
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> = unOp {
|
|
|
|
|
ops.min(it, ops.constant(dim), Min.keepDims(keepDim))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
override fun StructureND<T>.max(): T = unOp {
|
|
|
|
|
ops.max(it, ops.constant(intArrayOf()))
|
|
|
|
|
}.value()
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> = unOp {
|
|
|
|
|
ops.max(it, ops.constant(dim), Max.keepDims(keepDim))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.max(): T {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
}
|
|
|
|
|
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> = IntTensorFlowOutput(
|
|
|
|
|
graph,
|
|
|
|
|
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
|
|
|
|
|
).actualTensor
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> {
|
|
|
|
|
TODO("Not yet implemented")
|
|
|
|
|
}
|
|
|
|
|
@OptIn(UnstableKMathAPI::class)
|
|
|
|
|
override fun export(arg: StructureND<T>): StructureND<T> =
|
|
|
|
|
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
|
|
|
|
}
|