Fixed missing TF basic operations
This commit is contained in:
parent
b509dc917d
commit
7e4d223044
@ -6,11 +6,9 @@ import org.tensorflow.Operand
|
||||
import org.tensorflow.Output
|
||||
import org.tensorflow.Session
|
||||
import org.tensorflow.ndarray.NdArray
|
||||
import org.tensorflow.ndarray.index.Indices
|
||||
import org.tensorflow.op.Ops
|
||||
import org.tensorflow.op.core.Constant
|
||||
import org.tensorflow.op.core.Max
|
||||
import org.tensorflow.op.core.Min
|
||||
import org.tensorflow.op.core.Sum
|
||||
import org.tensorflow.op.core.*
|
||||
import org.tensorflow.types.TInt32
|
||||
import org.tensorflow.types.family.TNumber
|
||||
import org.tensorflow.types.family.TType
|
||||
@ -182,7 +180,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
||||
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg)
|
||||
|
||||
override fun Tensor<T>.get(i: Int): Tensor<T> = operate {
|
||||
TODO("Not yet implemented")
|
||||
StridedSliceHelper.stridedSlice(ops.scope(), it, Indices.at(i.toLong()))
|
||||
}
|
||||
|
||||
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = operate {
|
||||
@ -210,7 +208,13 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
||||
dim1: Int,
|
||||
dim2: Int,
|
||||
): TensorFlowOutput<T, TT> = diagonalEntries.operate {
|
||||
TODO("Not yet implemented")
|
||||
ops.linalg.matrixDiagV3(
|
||||
/* diagonal = */ it,
|
||||
/* k = */ ops.constant(offset),
|
||||
/* numRows = */ ops.constant(dim1),
|
||||
/* numCols = */ ops.constant(dim2),
|
||||
/* paddingValue = */ const(elementAlgebra.zero)
|
||||
)
|
||||
}
|
||||
|
||||
override fun StructureND<T>.sum(): T = operate {
|
||||
|
@ -123,27 +123,27 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> :
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
|
||||
public fun StructureND<T>.floor(): Tensor<T>
|
||||
|
||||
override fun sin(arg: StructureND<T>): StructureND<T> = arg.sin()
|
||||
override fun sin(arg: StructureND<T>): Tensor<T> = arg.sin()
|
||||
|
||||
override fun cos(arg: StructureND<T>): StructureND<T> = arg.cos()
|
||||
override fun cos(arg: StructureND<T>): Tensor<T> = arg.cos()
|
||||
|
||||
override fun asin(arg: StructureND<T>): StructureND<T> = arg.asin()
|
||||
override fun asin(arg: StructureND<T>): Tensor<T> = arg.asin()
|
||||
|
||||
override fun acos(arg: StructureND<T>): StructureND<T> = arg.acos()
|
||||
override fun acos(arg: StructureND<T>): Tensor<T> = arg.acos()
|
||||
|
||||
override fun atan(arg: StructureND<T>): StructureND<T> = arg.atan()
|
||||
override fun atan(arg: StructureND<T>): Tensor<T> = arg.atan()
|
||||
|
||||
override fun exp(arg: StructureND<T>): StructureND<T> = arg.exp()
|
||||
override fun exp(arg: StructureND<T>): Tensor<T> = arg.exp()
|
||||
|
||||
override fun ln(arg: StructureND<T>): StructureND<T> = arg.ln()
|
||||
override fun ln(arg: StructureND<T>): Tensor<T> = arg.ln()
|
||||
|
||||
override fun sinh(arg: StructureND<T>): StructureND<T> = arg.sinh()
|
||||
override fun sinh(arg: StructureND<T>): Tensor<T> = arg.sinh()
|
||||
|
||||
override fun cosh(arg: StructureND<T>): StructureND<T> = arg.cosh()
|
||||
override fun cosh(arg: StructureND<T>): Tensor<T> = arg.cosh()
|
||||
|
||||
override fun asinh(arg: StructureND<T>): StructureND<T> = arg.asinh()
|
||||
override fun asinh(arg: StructureND<T>): Tensor<T> = arg.asinh()
|
||||
|
||||
override fun acosh(arg: StructureND<T>): StructureND<T> = arg.acosh()
|
||||
override fun acosh(arg: StructureND<T>): Tensor<T> = arg.acosh()
|
||||
|
||||
override fun atanh(arg: StructureND<T>): StructureND<T> = arg.atanh()
|
||||
override fun atanh(arg: StructureND<T>): Tensor<T> = arg.atanh()
|
||||
}
|
Loading…
Reference in New Issue
Block a user