Fixed missing TF basic operations

This commit is contained in:
Alexander Nozik 2022-04-11 14:56:48 +03:00
parent b509dc917d
commit 7e4d223044
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
2 changed files with 22 additions and 18 deletions

View File

@ -6,11 +6,9 @@ import org.tensorflow.Operand
import org.tensorflow.Output
import org.tensorflow.Session
import org.tensorflow.ndarray.NdArray
import org.tensorflow.ndarray.index.Indices
import org.tensorflow.op.Ops
import org.tensorflow.op.core.Constant
import org.tensorflow.op.core.Max
import org.tensorflow.op.core.Min
import org.tensorflow.op.core.Sum
import org.tensorflow.op.core.*
import org.tensorflow.types.TInt32
import org.tensorflow.types.family.TNumber
import org.tensorflow.types.family.TType
@ -182,7 +180,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg)
override fun Tensor<T>.get(i: Int): Tensor<T> = operate {
TODO("Not yet implemented")
StridedSliceHelper.stridedSlice(ops.scope(), it, Indices.at(i.toLong()))
}
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = operate {
@ -210,7 +208,13 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
dim1: Int,
dim2: Int,
): TensorFlowOutput<T, TT> = diagonalEntries.operate {
TODO("Not yet implemented")
ops.linalg.matrixDiagV3(
/* diagonal = */ it,
/* k = */ ops.constant(offset),
/* numRows = */ ops.constant(dim1),
/* numCols = */ ops.constant(dim2),
/* paddingValue = */ const(elementAlgebra.zero)
)
}
override fun StructureND<T>.sum(): T = operate {

View File

@ -123,27 +123,27 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> :
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
public fun StructureND<T>.floor(): Tensor<T>
override fun sin(arg: StructureND<T>): StructureND<T> = arg.sin()
override fun sin(arg: StructureND<T>): Tensor<T> = arg.sin()
override fun cos(arg: StructureND<T>): StructureND<T> = arg.cos()
override fun cos(arg: StructureND<T>): Tensor<T> = arg.cos()
override fun asin(arg: StructureND<T>): StructureND<T> = arg.asin()
override fun asin(arg: StructureND<T>): Tensor<T> = arg.asin()
override fun acos(arg: StructureND<T>): StructureND<T> = arg.acos()
override fun acos(arg: StructureND<T>): Tensor<T> = arg.acos()
override fun atan(arg: StructureND<T>): StructureND<T> = arg.atan()
override fun atan(arg: StructureND<T>): Tensor<T> = arg.atan()
override fun exp(arg: StructureND<T>): StructureND<T> = arg.exp()
override fun exp(arg: StructureND<T>): Tensor<T> = arg.exp()
override fun ln(arg: StructureND<T>): StructureND<T> = arg.ln()
override fun ln(arg: StructureND<T>): Tensor<T> = arg.ln()
override fun sinh(arg: StructureND<T>): StructureND<T> = arg.sinh()
override fun sinh(arg: StructureND<T>): Tensor<T> = arg.sinh()
override fun cosh(arg: StructureND<T>): StructureND<T> = arg.cosh()
override fun cosh(arg: StructureND<T>): Tensor<T> = arg.cosh()
override fun asinh(arg: StructureND<T>): StructureND<T> = arg.asinh()
override fun asinh(arg: StructureND<T>): Tensor<T> = arg.asinh()
override fun acosh(arg: StructureND<T>): StructureND<T> = arg.acosh()
override fun acosh(arg: StructureND<T>): Tensor<T> = arg.acosh()
override fun atanh(arg: StructureND<T>): StructureND<T> = arg.atanh()
override fun atanh(arg: StructureND<T>): Tensor<T> = arg.atanh()
}