From 7e4d2230444ff8902ef50c667bbeb75c7b1eb900 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 11 Apr 2022 14:56:48 +0300 Subject: [PATCH] Fixed missing TF basic operations --- .../kmath/tensorflow/TensorFlowAlgebra.kt | 16 ++++++++----- .../tensors/api/AnalyticTensorAlgebra.kt | 24 +++++++++---------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt index b40739ee0..7185b84d6 100644 --- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt +++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt @@ -6,11 +6,9 @@ import org.tensorflow.Operand import org.tensorflow.Output import org.tensorflow.Session import org.tensorflow.ndarray.NdArray +import org.tensorflow.ndarray.index.Indices import org.tensorflow.op.Ops -import org.tensorflow.op.core.Constant -import org.tensorflow.op.core.Max -import org.tensorflow.op.core.Min -import org.tensorflow.op.core.Sum +import org.tensorflow.op.core.* import org.tensorflow.types.TInt32 import org.tensorflow.types.family.TNumber import org.tensorflow.types.family.TType @@ -182,7 +180,7 @@ public abstract class TensorFlowAlgebra> internal c override fun StructureND.unaryMinus(): TensorFlowOutput = operate(ops.math::neg) override fun Tensor.get(i: Int): Tensor = operate { - TODO("Not yet implemented") + StridedSliceHelper.stridedSlice(ops.scope(), it, Indices.at(i.toLong())) } override fun Tensor.transpose(i: Int, j: Int): Tensor = operate { @@ -210,7 +208,13 @@ public abstract class TensorFlowAlgebra> internal c dim1: Int, dim2: Int, ): TensorFlowOutput = diagonalEntries.operate { - TODO("Not yet implemented") + ops.linalg.matrixDiagV3( + /* diagonal = */ it, + /* k = */ ops.constant(offset), + /* numRows = */ ops.constant(dim1), + /* numCols = */ ops.constant(dim2), + /* paddingValue = */ const(elementAlgebra.zero) + ) } override fun StructureND.sum(): T = operate { diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index 3ed34ae5e..b32fcf608 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -123,27 +123,27 @@ public interface AnalyticTensorAlgebra> : //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor public fun StructureND.floor(): Tensor - override fun sin(arg: StructureND): StructureND = arg.sin() + override fun sin(arg: StructureND): Tensor = arg.sin() - override fun cos(arg: StructureND): StructureND = arg.cos() + override fun cos(arg: StructureND): Tensor = arg.cos() - override fun asin(arg: StructureND): StructureND = arg.asin() + override fun asin(arg: StructureND): Tensor = arg.asin() - override fun acos(arg: StructureND): StructureND = arg.acos() + override fun acos(arg: StructureND): Tensor = arg.acos() - override fun atan(arg: StructureND): StructureND = arg.atan() + override fun atan(arg: StructureND): Tensor = arg.atan() - override fun exp(arg: StructureND): StructureND = arg.exp() + override fun exp(arg: StructureND): Tensor = arg.exp() - override fun ln(arg: StructureND): StructureND = arg.ln() + override fun ln(arg: StructureND): Tensor = arg.ln() - override fun sinh(arg: StructureND): StructureND = arg.sinh() + override fun sinh(arg: StructureND): Tensor = arg.sinh() - override fun cosh(arg: StructureND): StructureND = arg.cosh() + override fun cosh(arg: StructureND): Tensor = arg.cosh() - override fun asinh(arg: StructureND): StructureND = arg.asinh() + override fun asinh(arg: StructureND): Tensor = arg.asinh() - override fun acosh(arg: StructureND): StructureND = arg.acosh() + override fun acosh(arg: StructureND): Tensor = arg.acosh() - override fun atanh(arg: StructureND): StructureND = arg.atanh() + override fun atanh(arg: StructureND): Tensor = arg.atanh() } \ No newline at end of file