[WIP] TensorFlow refactoring

This commit is contained in:
Alexander Nozik 2021-10-28 10:52:40 +03:00
parent 7bdc54c818
commit 64629561af
2 changed files with 58 additions and 46 deletions

View File

@ -3,11 +3,13 @@ package space.kscience.kmath.tensorflow
import org.tensorflow.Graph import org.tensorflow.Graph
import org.tensorflow.Output import org.tensorflow.Output
import org.tensorflow.ndarray.NdArray import org.tensorflow.ndarray.NdArray
import org.tensorflow.ndarray.Shape
import org.tensorflow.op.core.Constant import org.tensorflow.op.core.Constant
import org.tensorflow.types.TFloat64 import org.tensorflow.types.TFloat64
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.DoubleField
public class DoubleTensorFlowOutput( public class DoubleTensorFlowOutput(
graph: Graph, graph: Graph,
@ -18,14 +20,28 @@ public class DoubleTensorFlowOutput(
public class DoubleTensorFlowAlgebra internal constructor( public class DoubleTensorFlowAlgebra internal constructor(
graph: Graph graph: Graph
) : TensorFlowAlgebra<Double, TFloat64>(graph) { ) : TensorFlowAlgebra<Double, TFloat64, DoubleField>(graph) {
override fun Tensor<Double>.asTensorFlow(): TensorFlowOutput<Double, TFloat64> = override val elementAlgebra: DoubleField get() = DoubleField
override fun structureND(
shape: Shape,
initializer: DoubleField.(IntArray) -> Double
): StructureND<Double> {
val res = TFloat64.tensorOf(org.tensorflow.ndarray.Shape.of(*shape.toLongArray())) { array ->
DefaultStrides(shape).forEach { index ->
array.setDouble(elementAlgebra.initializer(index), *index.toLongArray())
}
}
return DoubleTensorFlowOutput(graph, ops.constant(res).asOutput())
}
override fun StructureND<Double>.asTensorFlow(): TensorFlowOutput<Double, TFloat64> =
if (this is TensorFlowOutput<Double, *> && output.type() == TFloat64::class.java) { if (this is TensorFlowOutput<Double, *> && output.type() == TFloat64::class.java) {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
this as TensorFlowOutput<Double, TFloat64> this as TensorFlowOutput<Double, TFloat64>
} else { } else {
val res = TFloat64.tensorOf(Shape.of(*shape.toLongArray())) { array -> val res = TFloat64.tensorOf(org.tensorflow.ndarray.Shape.of(*shape.toLongArray())) { array ->
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
elements().forEach { (index, value) -> elements().forEach { (index, value) ->
array.setDouble(value, *index.toLongArray()) array.setDouble(value, *index.toLongArray())

View File

@ -11,6 +11,8 @@ import org.tensorflow.op.core.Constant
import org.tensorflow.types.family.TType import org.tensorflow.types.family.TType
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
@ -28,13 +30,7 @@ public value class TensorFlowArray<T>(public val tensor: NdArray<T>) : Tensor<T>
override fun get(index: IntArray): T = tensor.getObject(*index.toLongArray()) override fun get(index: IntArray): T = tensor.getObject(*index.toLongArray())
@PerformancePitfall //TODO implement native element sequence
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
tensor.scalars().forEachIndexed { index: LongArray, ndArray: NdArray<T> ->
//yield(index.toIntArray() to ndArray.scalar)
TODO()
}
}
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
tensor.setObject(value, *index.toLongArray()) tensor.setObject(value, *index.toLongArray())
@ -70,23 +66,23 @@ public abstract class TensorFlowOutput<T, TT : TType>(
} }
public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor( public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal constructor(
protected val graph: Graph protected val graph: Graph
) : TensorAlgebra<T> { ) : TensorAlgebra<T,A> {
protected val ops: Ops by lazy { Ops.create(graph) } protected val ops: Ops by lazy { Ops.create(graph) }
protected abstract fun Tensor<T>.asTensorFlow(): TensorFlowOutput<T, TT> protected abstract fun StructureND<T>.asTensorFlow(): TensorFlowOutput<T, TT>
protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT> protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT>
protected abstract fun const(value: T): Constant<TT> protected abstract fun const(value: T): Constant<TT>
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1))
get(Shape(0)) else null get(Shape(0)) else null
private inline fun Tensor<T>.biOp( private inline fun StructureND<T>.biOp(
other: Tensor<T>, other: StructureND<T>,
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT> operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
): TensorFlowOutput<T, TT> { ): TensorFlowOutput<T, TT> {
val left = asTensorFlow().output val left = asTensorFlow().output
@ -95,7 +91,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
} }
private inline fun T.biOp( private inline fun T.biOp(
other: Tensor<T>, other: StructureND<T>,
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT> operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
): TensorFlowOutput<T, TT> { ): TensorFlowOutput<T, TT> {
val left = const(this) val left = const(this)
@ -103,7 +99,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
return operation(left, right).asOutput().wrap() return operation(left, right).asOutput().wrap()
} }
private inline fun Tensor<T>.biOp( private inline fun StructureND<T>.biOp(
value: T, value: T,
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT> operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
): TensorFlowOutput<T, TT> { ): TensorFlowOutput<T, TT> {
@ -113,7 +109,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
} }
private inline fun Tensor<T>.inPlaceOp( private inline fun Tensor<T>.inPlaceOp(
other: Tensor<T>, other: StructureND<T>,
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT> operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>
): Unit { ): Unit {
val origin = asTensorFlow() val origin = asTensorFlow()
@ -132,46 +128,46 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
origin.output = operation(left, right).asOutput() origin.output = operation(left, right).asOutput()
} }
private inline fun unOp(value: Tensor<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> = private inline fun unOp(value: StructureND<T>, operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
operation(value.asTensorFlow().output).asOutput().wrap() operation(value.asTensorFlow().output).asOutput().wrap()
override fun T.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add) override fun T.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
override fun Tensor<T>.plus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::add) override fun StructureND<T>.plus(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
override fun Tensor<T>.plus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::add) override fun StructureND<T>.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add) override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add)
override fun Tensor<T>.plusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::add) override fun Tensor<T>.plusAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::add)
override fun Tensor<T>.minus(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::sub) override fun StructureND<T>.minus(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::sub)
override fun Tensor<T>.minus(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::sub) override fun StructureND<T>.minus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::sub)
override fun T.minus(other: Tensor<T>): Tensor<T> = biOp(other, ops.math::sub) override fun T.minus(arg: StructureND<T>): Tensor<T> = biOp(arg, ops.math::sub)
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub) override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
override fun Tensor<T>.minusAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::sub) override fun Tensor<T>.minusAssign(other: StructureND<T>): Unit = inPlaceOp(other, ops.math::sub)
override fun T.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul) override fun T.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
override fun Tensor<T>.times(value: T): TensorFlowOutput<T, TT> = biOp(value, ops.math::mul) override fun StructureND<T>.times(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
override fun Tensor<T>.times(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul) override fun StructureND<T>.times(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other, ops.math::mul)
override fun Tensor<T>.timesAssign(value: T): Unit = inPlaceOp(value, ops.math::mul) override fun Tensor<T>.timesAssign(value: T): Unit = inPlaceOp(value, ops.math::mul)
override fun Tensor<T>.timesAssign(other: Tensor<T>): Unit = inPlaceOp(other, ops.math::mul) override fun Tensor<T>.timesAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::mul)
override fun Tensor<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(this, ops.math::neg) override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(this, ops.math::neg)
override fun Tensor<T>.get(i: Int): Tensor<T> { override fun StructureND<T>.get(i: Int): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = unOp(this) { override fun StructureND<T>.transpose(i: Int, j: Int): Tensor<T> = unOp(this) {
ops.linalg.transpose(it, ops.constant(intArrayOf(i, j))) ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
} }
@ -179,11 +175,11 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> { override fun Tensor<T>.viewAs(other: StructureND<T>): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.dot(other: Tensor<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r -> override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
ops.linalg.matMul(l, r) ops.linalg.matMul(l, r)
} }
@ -191,29 +187,29 @@ public abstract class TensorFlowAlgebra<T, TT : TType> internal constructor(
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.sum(): T = TODO("Not yet implemented") override fun StructureND<T>.sum(): T = TODO("Not yet implemented")
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> { override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.min(): T { override fun StructureND<T>.min(): T {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> { override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.max(): T { override fun StructureND<T>.max(): T {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> { override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> { override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
} }