From 64629561afed6465ef74bc8f51b892c94d17adac Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 28 Oct 2021 10:52:40 +0300 Subject: [PATCH] [WIP] TensorFlow refactoring --- .../tensorflow/DoubleTensorFlowAlgebra.kt | 26 +++++-- .../kmath/tensorflow/TensorFlowAlgebra.kt | 78 +++++++++---------- 2 files changed, 58 insertions(+), 46 deletions(-) diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowAlgebra.kt index 864205e17..44420cd76 100644 --- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowAlgebra.kt +++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/DoubleTensorFlowAlgebra.kt @@ -3,11 +3,13 @@ package space.kscience.kmath.tensorflow import org.tensorflow.Graph import org.tensorflow.Output import org.tensorflow.ndarray.NdArray -import org.tensorflow.ndarray.Shape import org.tensorflow.op.core.Constant import org.tensorflow.types.TFloat64 import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.tensors.api.Tensor +import space.kscience.kmath.nd.DefaultStrides +import space.kscience.kmath.nd.Shape +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.DoubleField public class DoubleTensorFlowOutput( graph: Graph, @@ -18,14 +20,28 @@ public class DoubleTensorFlowOutput( public class DoubleTensorFlowAlgebra internal constructor( graph: Graph -) : TensorFlowAlgebra(graph) { +) : TensorFlowAlgebra(graph) { - override fun Tensor.asTensorFlow(): TensorFlowOutput = + override val elementAlgebra: DoubleField get() = DoubleField + + override fun structureND( + shape: Shape, + initializer: DoubleField.(IntArray) -> Double + ): StructureND { + val res = TFloat64.tensorOf(org.tensorflow.ndarray.Shape.of(*shape.toLongArray())) { array -> + DefaultStrides(shape).forEach { index -> + array.setDouble(elementAlgebra.initializer(index), *index.toLongArray()) + } + } + return DoubleTensorFlowOutput(graph, ops.constant(res).asOutput()) + } + + override fun StructureND.asTensorFlow(): TensorFlowOutput = if (this is TensorFlowOutput && output.type() == TFloat64::class.java) { @Suppress("UNCHECKED_CAST") this as TensorFlowOutput } else { - val res = TFloat64.tensorOf(Shape.of(*shape.toLongArray())) { array -> + val res = TFloat64.tensorOf(org.tensorflow.ndarray.Shape.of(*shape.toLongArray())) { array -> @OptIn(PerformancePitfall::class) elements().forEach { (index, value) -> array.setDouble(value, *index.toLongArray()) diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt index 12c1211db..e73620d01 100644 --- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt +++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt @@ -11,6 +11,8 @@ import org.tensorflow.op.core.Constant import org.tensorflow.types.family.TType import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.Shape +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Ring import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra @@ -28,13 +30,7 @@ public value class TensorFlowArray(public val tensor: NdArray) : Tensor override fun get(index: IntArray): T = tensor.getObject(*index.toLongArray()) - @PerformancePitfall - override fun elements(): Sequence> = sequence { - tensor.scalars().forEachIndexed { index: LongArray, ndArray: NdArray -> - //yield(index.toIntArray() to ndArray.scalar) - TODO() - } - } + //TODO implement native element sequence override fun set(index: IntArray, value: T) { tensor.setObject(value, *index.toLongArray()) @@ -70,23 +66,23 @@ public abstract class TensorFlowOutput( } -public abstract class TensorFlowAlgebra internal constructor( +public abstract class TensorFlowAlgebra> internal constructor( protected val graph: Graph -) : TensorAlgebra { +) : TensorAlgebra { protected val ops: Ops by lazy { Ops.create(graph) } - protected abstract fun Tensor.asTensorFlow(): TensorFlowOutput + protected abstract fun StructureND.asTensorFlow(): TensorFlowOutput protected abstract fun Output.wrap(): TensorFlowOutput protected abstract fun const(value: T): Constant - override fun Tensor.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) + override fun StructureND.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) get(Shape(0)) else null - private inline fun Tensor.biOp( - other: Tensor, + private inline fun StructureND.biOp( + other: StructureND, operation: (left: Operand, right: Operand) -> Operand ): TensorFlowOutput { val left = asTensorFlow().output @@ -95,7 +91,7 @@ public abstract class TensorFlowAlgebra internal constructor( } private inline fun T.biOp( - other: Tensor, + other: StructureND, operation: (left: Operand, right: Operand) -> Operand ): TensorFlowOutput { val left = const(this) @@ -103,7 +99,7 @@ public abstract class TensorFlowAlgebra internal constructor( return operation(left, right).asOutput().wrap() } - private inline fun Tensor.biOp( + private inline fun StructureND.biOp( value: T, operation: (left: Operand, right: Operand) -> Operand ): TensorFlowOutput { @@ -113,7 +109,7 @@ public abstract class TensorFlowAlgebra internal constructor( } private inline fun Tensor.inPlaceOp( - other: Tensor, + other: StructureND, operation: (left: Operand, right: Operand) -> Operand ): Unit { val origin = asTensorFlow() @@ -132,46 +128,46 @@ public abstract class TensorFlowAlgebra internal constructor( origin.output = operation(left, right).asOutput() } - private inline fun unOp(value: Tensor, operation: (Operand) -> Operand): TensorFlowOutput = + private inline fun unOp(value: StructureND, operation: (Operand) -> Operand): TensorFlowOutput = operation(value.asTensorFlow().output).asOutput().wrap() - override fun T.plus(other: Tensor): TensorFlowOutput = biOp(other, ops.math::add) + override fun T.plus(arg: StructureND): TensorFlowOutput = biOp(arg, ops.math::add) - override fun Tensor.plus(value: T): TensorFlowOutput = biOp(value, ops.math::add) + override fun StructureND.plus(arg: T): TensorFlowOutput = biOp(arg, ops.math::add) - override fun Tensor.plus(other: Tensor): TensorFlowOutput = biOp(other, ops.math::add) + override fun StructureND.plus(arg: StructureND): TensorFlowOutput = biOp(arg, ops.math::add) override fun Tensor.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add) - override fun Tensor.plusAssign(other: Tensor): Unit = inPlaceOp(other, ops.math::add) + override fun Tensor.plusAssign(arg: StructureND): Unit = inPlaceOp(arg, ops.math::add) - override fun Tensor.minus(value: T): TensorFlowOutput = biOp(value, ops.math::sub) + override fun StructureND.minus(arg: T): TensorFlowOutput = biOp(arg, ops.math::sub) - override fun Tensor.minus(other: Tensor): TensorFlowOutput = biOp(other, ops.math::sub) + override fun StructureND.minus(arg: StructureND): TensorFlowOutput = biOp(arg, ops.math::sub) - override fun T.minus(other: Tensor): Tensor = biOp(other, ops.math::sub) + override fun T.minus(arg: StructureND): Tensor = biOp(arg, ops.math::sub) override fun Tensor.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub) - override fun Tensor.minusAssign(other: Tensor): Unit = inPlaceOp(other, ops.math::sub) + override fun Tensor.minusAssign(other: StructureND): Unit = inPlaceOp(other, ops.math::sub) - override fun T.times(other: Tensor): TensorFlowOutput = biOp(other, ops.math::mul) + override fun T.times(arg: StructureND): TensorFlowOutput = biOp(arg, ops.math::mul) - override fun Tensor.times(value: T): TensorFlowOutput = biOp(value, ops.math::mul) + override fun StructureND.times(arg: T): TensorFlowOutput = biOp(arg, ops.math::mul) - override fun Tensor.times(other: Tensor): TensorFlowOutput = biOp(other, ops.math::mul) + override fun StructureND.times(other: StructureND): TensorFlowOutput = biOp(other, ops.math::mul) override fun Tensor.timesAssign(value: T): Unit = inPlaceOp(value, ops.math::mul) - override fun Tensor.timesAssign(other: Tensor): Unit = inPlaceOp(other, ops.math::mul) + override fun Tensor.timesAssign(arg: StructureND): Unit = inPlaceOp(arg, ops.math::mul) - override fun Tensor.unaryMinus(): TensorFlowOutput = unOp(this, ops.math::neg) + override fun StructureND.unaryMinus(): TensorFlowOutput = unOp(this, ops.math::neg) - override fun Tensor.get(i: Int): Tensor { + override fun StructureND.get(i: Int): Tensor { TODO("Not yet implemented") } - override fun Tensor.transpose(i: Int, j: Int): Tensor = unOp(this) { + override fun StructureND.transpose(i: Int, j: Int): Tensor = unOp(this) { ops.linalg.transpose(it, ops.constant(intArrayOf(i, j))) } @@ -179,11 +175,11 @@ public abstract class TensorFlowAlgebra internal constructor( TODO("Not yet implemented") } - override fun Tensor.viewAs(other: Tensor): Tensor { + override fun Tensor.viewAs(other: StructureND): Tensor { TODO("Not yet implemented") } - override fun Tensor.dot(other: Tensor): TensorFlowOutput = biOp(other) { l, r -> + override fun StructureND.dot(other: StructureND): TensorFlowOutput = biOp(other) { l, r -> ops.linalg.matMul(l, r) } @@ -191,29 +187,29 @@ public abstract class TensorFlowAlgebra internal constructor( TODO("Not yet implemented") } - override fun Tensor.sum(): T = TODO("Not yet implemented") + override fun StructureND.sum(): T = TODO("Not yet implemented") - override fun Tensor.sum(dim: Int, keepDim: Boolean): Tensor { + override fun StructureND.sum(dim: Int, keepDim: Boolean): Tensor { TODO("Not yet implemented") } - override fun Tensor.min(): T { + override fun StructureND.min(): T { TODO("Not yet implemented") } - override fun Tensor.min(dim: Int, keepDim: Boolean): Tensor { + override fun StructureND.min(dim: Int, keepDim: Boolean): Tensor { TODO("Not yet implemented") } - override fun Tensor.max(): T { + override fun StructureND.max(): T { TODO("Not yet implemented") } - override fun Tensor.max(dim: Int, keepDim: Boolean): Tensor { + override fun StructureND.max(dim: Int, keepDim: Boolean): Tensor { TODO("Not yet implemented") } - override fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor { + override fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor { TODO("Not yet implemented") } } \ No newline at end of file