forked from kscience/kmath
Fixed missing TF basic operations
This commit is contained in:
parent
b509dc917d
commit
7e4d223044
@ -6,11 +6,9 @@ import org.tensorflow.Operand
|
|||||||
import org.tensorflow.Output
|
import org.tensorflow.Output
|
||||||
import org.tensorflow.Session
|
import org.tensorflow.Session
|
||||||
import org.tensorflow.ndarray.NdArray
|
import org.tensorflow.ndarray.NdArray
|
||||||
|
import org.tensorflow.ndarray.index.Indices
|
||||||
import org.tensorflow.op.Ops
|
import org.tensorflow.op.Ops
|
||||||
import org.tensorflow.op.core.Constant
|
import org.tensorflow.op.core.*
|
||||||
import org.tensorflow.op.core.Max
|
|
||||||
import org.tensorflow.op.core.Min
|
|
||||||
import org.tensorflow.op.core.Sum
|
|
||||||
import org.tensorflow.types.TInt32
|
import org.tensorflow.types.TInt32
|
||||||
import org.tensorflow.types.family.TNumber
|
import org.tensorflow.types.family.TNumber
|
||||||
import org.tensorflow.types.family.TType
|
import org.tensorflow.types.family.TType
|
||||||
@ -182,7 +180,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
|||||||
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg)
|
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg)
|
||||||
|
|
||||||
override fun Tensor<T>.get(i: Int): Tensor<T> = operate {
|
override fun Tensor<T>.get(i: Int): Tensor<T> = operate {
|
||||||
TODO("Not yet implemented")
|
StridedSliceHelper.stridedSlice(ops.scope(), it, Indices.at(i.toLong()))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = operate {
|
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = operate {
|
||||||
@ -210,7 +208,13 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
|||||||
dim1: Int,
|
dim1: Int,
|
||||||
dim2: Int,
|
dim2: Int,
|
||||||
): TensorFlowOutput<T, TT> = diagonalEntries.operate {
|
): TensorFlowOutput<T, TT> = diagonalEntries.operate {
|
||||||
TODO("Not yet implemented")
|
ops.linalg.matrixDiagV3(
|
||||||
|
/* diagonal = */ it,
|
||||||
|
/* k = */ ops.constant(offset),
|
||||||
|
/* numRows = */ ops.constant(dim1),
|
||||||
|
/* numCols = */ ops.constant(dim2),
|
||||||
|
/* paddingValue = */ const(elementAlgebra.zero)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.sum(): T = operate {
|
override fun StructureND<T>.sum(): T = operate {
|
||||||
|
@ -123,27 +123,27 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> :
|
|||||||
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
|
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
|
||||||
public fun StructureND<T>.floor(): Tensor<T>
|
public fun StructureND<T>.floor(): Tensor<T>
|
||||||
|
|
||||||
override fun sin(arg: StructureND<T>): StructureND<T> = arg.sin()
|
override fun sin(arg: StructureND<T>): Tensor<T> = arg.sin()
|
||||||
|
|
||||||
override fun cos(arg: StructureND<T>): StructureND<T> = arg.cos()
|
override fun cos(arg: StructureND<T>): Tensor<T> = arg.cos()
|
||||||
|
|
||||||
override fun asin(arg: StructureND<T>): StructureND<T> = arg.asin()
|
override fun asin(arg: StructureND<T>): Tensor<T> = arg.asin()
|
||||||
|
|
||||||
override fun acos(arg: StructureND<T>): StructureND<T> = arg.acos()
|
override fun acos(arg: StructureND<T>): Tensor<T> = arg.acos()
|
||||||
|
|
||||||
override fun atan(arg: StructureND<T>): StructureND<T> = arg.atan()
|
override fun atan(arg: StructureND<T>): Tensor<T> = arg.atan()
|
||||||
|
|
||||||
override fun exp(arg: StructureND<T>): StructureND<T> = arg.exp()
|
override fun exp(arg: StructureND<T>): Tensor<T> = arg.exp()
|
||||||
|
|
||||||
override fun ln(arg: StructureND<T>): StructureND<T> = arg.ln()
|
override fun ln(arg: StructureND<T>): Tensor<T> = arg.ln()
|
||||||
|
|
||||||
override fun sinh(arg: StructureND<T>): StructureND<T> = arg.sinh()
|
override fun sinh(arg: StructureND<T>): Tensor<T> = arg.sinh()
|
||||||
|
|
||||||
override fun cosh(arg: StructureND<T>): StructureND<T> = arg.cosh()
|
override fun cosh(arg: StructureND<T>): Tensor<T> = arg.cosh()
|
||||||
|
|
||||||
override fun asinh(arg: StructureND<T>): StructureND<T> = arg.asinh()
|
override fun asinh(arg: StructureND<T>): Tensor<T> = arg.asinh()
|
||||||
|
|
||||||
override fun acosh(arg: StructureND<T>): StructureND<T> = arg.acosh()
|
override fun acosh(arg: StructureND<T>): Tensor<T> = arg.acosh()
|
||||||
|
|
||||||
override fun atanh(arg: StructureND<T>): StructureND<T> = arg.atanh()
|
override fun atanh(arg: StructureND<T>): Tensor<T> = arg.atanh()
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user