forked from kscience/kmath
Add TensorFlow prototype
This commit is contained in:
parent
41fc6b4dd9
commit
7bb66f6a00
@ -18,6 +18,7 @@
|
|||||||
- Integration between `MST` and Symja `IExpr`
|
- Integration between `MST` and Symja `IExpr`
|
||||||
- Complex power
|
- Complex power
|
||||||
- Separate methods for UInt, Int and Number powers. NaN safety.
|
- Separate methods for UInt, Int and Number powers. NaN safety.
|
||||||
|
- Tensorflow prototype
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Exponential operations merged with hyperbolic functions
|
- Exponential operations merged with hyperbolic functions
|
||||||
|
29
README.md
29
README.md
@ -50,35 +50,6 @@ module definitions below. The module stability could have the following levels:
|
|||||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||||
* **STABLE**. The API stabilized. Breaking changes are allowed only in major releases.
|
* **STABLE**. The API stabilized. Breaking changes are allowed only in major releases.
|
||||||
|
|
||||||
<!--Current feature list is [here](/docs/features.md)-->
|
|
||||||
|
|
||||||
|
|
||||||
<!--* **Array-like structures** Full support of many-dimensional array-like structures -->
|
|
||||||
<!--including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).-->
|
|
||||||
|
|
||||||
<!--* **Histograms** Fast multi-dimensional histograms.-->
|
|
||||||
|
|
||||||
<!--* **Streaming** Streaming operations on mathematical objects and objects buffers.-->
|
|
||||||
|
|
||||||
<!--* **Type-safe dimensions** Type-safe dimensions for matrix operations.-->
|
|
||||||
|
|
||||||
<!--* **Commons-math wrapper** It is planned to gradually wrap most parts of -->
|
|
||||||
<!--[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some -->
|
|
||||||
<!--parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to -->
|
|
||||||
<!--submit a feature request if you want something to be implemented first.-->
|
|
||||||
<!-- -->
|
|
||||||
<!--## Planned features-->
|
|
||||||
|
|
||||||
<!--* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.-->
|
|
||||||
|
|
||||||
<!--* **Array statistics** -->
|
|
||||||
|
|
||||||
<!--* **Integration** Univariate and multivariate integration framework.-->
|
|
||||||
|
|
||||||
<!--* **Probability and distributions**-->
|
|
||||||
|
|
||||||
<!--* **Fitting** Non-linear curve fitting facilities-->
|
|
||||||
|
|
||||||
## Modules
|
## Modules
|
||||||
|
|
||||||
<hr/>
|
<hr/>
|
||||||
|
29
docs/templates/README-TEMPLATE.md
vendored
29
docs/templates/README-TEMPLATE.md
vendored
@ -50,35 +50,6 @@ module definitions below. The module stability could have the following levels:
|
|||||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||||
* **STABLE**. The API stabilized. Breaking changes are allowed only in major releases.
|
* **STABLE**. The API stabilized. Breaking changes are allowed only in major releases.
|
||||||
|
|
||||||
<!--Current feature list is [here](/docs/features.md)-->
|
|
||||||
|
|
||||||
|
|
||||||
<!--* **Array-like structures** Full support of many-dimensional array-like structures -->
|
|
||||||
<!--including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).-->
|
|
||||||
|
|
||||||
<!--* **Histograms** Fast multi-dimensional histograms.-->
|
|
||||||
|
|
||||||
<!--* **Streaming** Streaming operations on mathematical objects and objects buffers.-->
|
|
||||||
|
|
||||||
<!--* **Type-safe dimensions** Type-safe dimensions for matrix operations.-->
|
|
||||||
|
|
||||||
<!--* **Commons-math wrapper** It is planned to gradually wrap most parts of -->
|
|
||||||
<!--[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some -->
|
|
||||||
<!--parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to -->
|
|
||||||
<!--submit a feature request if you want something to be implemented first.-->
|
|
||||||
<!-- -->
|
|
||||||
<!--## Planned features-->
|
|
||||||
|
|
||||||
<!--* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.-->
|
|
||||||
|
|
||||||
<!--* **Array statistics** -->
|
|
||||||
|
|
||||||
<!--* **Integration** Univariate and multivariate integration framework.-->
|
|
||||||
|
|
||||||
<!--* **Probability and distributions**-->
|
|
||||||
|
|
||||||
<!--* **Fitting** Non-linear curve fitting facilities-->
|
|
||||||
|
|
||||||
## Modules
|
## Modules
|
||||||
|
|
||||||
$modules
|
$modules
|
||||||
|
@ -12,4 +12,4 @@ org.gradle.configureondemand=true
|
|||||||
org.gradle.parallel=true
|
org.gradle.parallel=true
|
||||||
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1G
|
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1G
|
||||||
|
|
||||||
toolsVersion=0.10.9-kotlin-1.6.10
|
toolsVersion=0.11.1-kotlin-1.6.10
|
||||||
|
@ -6,8 +6,8 @@ description = "Google tensorflow connector"
|
|||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-tensors"))
|
api(project(":kmath-tensors"))
|
||||||
api("org.tensorflow:tensorflow-core-api:0.3.3")
|
api("org.tensorflow:tensorflow-core-api:0.4.0")
|
||||||
testImplementation("org.tensorflow:tensorflow-core-platform:0.3.3")
|
testImplementation("org.tensorflow:tensorflow-core-platform:0.4.0")
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
@ -11,6 +11,7 @@ import space.kscience.kmath.nd.DefaultStrides
|
|||||||
import space.kscience.kmath.nd.Shape
|
import space.kscience.kmath.nd.Shape
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.PowerOperations
|
||||||
|
|
||||||
public class DoubleTensorFlowOutput(
|
public class DoubleTensorFlowOutput(
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
@ -23,7 +24,7 @@ public class DoubleTensorFlowOutput(
|
|||||||
|
|
||||||
public class DoubleTensorFlowAlgebra internal constructor(
|
public class DoubleTensorFlowAlgebra internal constructor(
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
) : TensorFlowAlgebra<Double, TFloat64, DoubleField>(graph) {
|
) : TensorFlowAlgebra<Double, TFloat64, DoubleField>(graph), PowerOperations<StructureND<Double>> {
|
||||||
|
|
||||||
override val elementAlgebra: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
|
||||||
@ -57,9 +58,22 @@ public class DoubleTensorFlowAlgebra internal constructor(
|
|||||||
|
|
||||||
override fun const(value: Double): Constant<TFloat64> = ops.constant(value)
|
override fun const(value: Double): Constant<TFloat64> = ops.constant(value)
|
||||||
|
|
||||||
|
override fun divide(
|
||||||
|
left: StructureND<Double>,
|
||||||
|
right: StructureND<Double>,
|
||||||
|
): TensorFlowOutput<Double, TFloat64> = left.operate(right) { l, r ->
|
||||||
|
ops.math.div(l, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun power(arg: StructureND<Double>, pow: Number): TensorFlowOutput<Double, TFloat64> =
|
||||||
|
arg.operate { ops.math.pow(it, const(pow.toDouble())) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute a tensor with TensorFlow in a single run.
|
||||||
|
*
|
||||||
|
* The resulting tensor is available outside of scope
|
||||||
|
*/
|
||||||
public fun DoubleField.produceWithTF(
|
public fun DoubleField.produceWithTF(
|
||||||
block: DoubleTensorFlowAlgebra.() -> StructureND<Double>,
|
block: DoubleTensorFlowAlgebra.() -> StructureND<Double>,
|
||||||
): StructureND<Double> = Graph().use { graph ->
|
): StructureND<Double> = Graph().use { graph ->
|
||||||
@ -67,6 +81,11 @@ public fun DoubleField.produceWithTF(
|
|||||||
scope.export(scope.block())
|
scope.export(scope.block())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute several outputs with TensorFlow in a single run.
|
||||||
|
*
|
||||||
|
* The resulting tensors are available outside of scope
|
||||||
|
*/
|
||||||
public fun DoubleField.produceMapWithTF(
|
public fun DoubleField.produceMapWithTF(
|
||||||
block: DoubleTensorFlowAlgebra.() -> Map<Symbol, StructureND<Double>>,
|
block: DoubleTensorFlowAlgebra.() -> Map<Symbol, StructureND<Double>>,
|
||||||
): Map<Symbol, StructureND<Double>> = Graph().use { graph ->
|
): Map<Symbol, StructureND<Double>> = Graph().use { graph ->
|
||||||
|
@ -12,6 +12,7 @@ import org.tensorflow.op.core.Max
|
|||||||
import org.tensorflow.op.core.Min
|
import org.tensorflow.op.core.Min
|
||||||
import org.tensorflow.op.core.Sum
|
import org.tensorflow.op.core.Sum
|
||||||
import org.tensorflow.types.TInt32
|
import org.tensorflow.types.TInt32
|
||||||
|
import org.tensorflow.types.family.TNumber
|
||||||
import org.tensorflow.types.family.TType
|
import org.tensorflow.types.family.TType
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
@ -29,6 +30,9 @@ internal val <T> NdArray<T>.scalar: T get() = getObject()
|
|||||||
|
|
||||||
public sealed interface TensorFlowTensor<T> : Tensor<T>
|
public sealed interface TensorFlowTensor<T> : Tensor<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Static (eager) in-memory TensorFlow tensor
|
||||||
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class TensorFlowArray<T>(public val tensor: NdArray<T>) : Tensor<T> {
|
public value class TensorFlowArray<T>(public val tensor: NdArray<T>) : Tensor<T> {
|
||||||
override val shape: Shape get() = tensor.shape().asArray().toIntArray()
|
override val shape: Shape get() = tensor.shape().asArray().toIntArray()
|
||||||
@ -42,6 +46,11 @@ public value class TensorFlowArray<T>(public val tensor: NdArray<T>) : Tensor<T>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lazy graph-based TensorFlow tensor. The tensor is actualized on call.
|
||||||
|
*
|
||||||
|
* If the tensor is used for intermediate operations, actualizing it could impact performance.
|
||||||
|
*/
|
||||||
public abstract class TensorFlowOutput<T, TT : TType>(
|
public abstract class TensorFlowOutput<T, TT : TType>(
|
||||||
protected val graph: Graph,
|
protected val graph: Graph,
|
||||||
output: Output<TT>,
|
output: Output<TT>,
|
||||||
@ -72,11 +81,11 @@ public abstract class TensorFlowOutput<T, TT : TType>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal constructor(
|
public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal constructor(
|
||||||
protected val graph: Graph,
|
protected val graph: Graph,
|
||||||
) : TensorAlgebra<T, A> {
|
) : TensorAlgebra<T, A> {
|
||||||
|
|
||||||
protected val ops: Ops by lazy { Ops.create(graph) }
|
public val ops: Ops by lazy { Ops.create(graph) }
|
||||||
|
|
||||||
protected abstract fun StructureND<T>.asTensorFlow(): TensorFlowOutput<T, TT>
|
protected abstract fun StructureND<T>.asTensorFlow(): TensorFlowOutput<T, TT>
|
||||||
|
|
||||||
@ -87,7 +96,10 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1))
|
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1))
|
||||||
get(Shape(0)) else null
|
get(Shape(0)) else null
|
||||||
|
|
||||||
private inline fun StructureND<T>.biOp(
|
/**
|
||||||
|
* Perform binary lazy operation on tensor. Both arguments are implicitly converted
|
||||||
|
*/
|
||||||
|
public fun StructureND<T>.operate(
|
||||||
other: StructureND<T>,
|
other: StructureND<T>,
|
||||||
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
): TensorFlowOutput<T, TT> {
|
): TensorFlowOutput<T, TT> {
|
||||||
@ -96,7 +108,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
return operation(left, right).asOutput().wrap()
|
return operation(left, right).asOutput().wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun T.biOp(
|
public fun T.operate(
|
||||||
other: StructureND<T>,
|
other: StructureND<T>,
|
||||||
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
): TensorFlowOutput<T, TT> {
|
): TensorFlowOutput<T, TT> {
|
||||||
@ -105,7 +117,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
return operation(left, right).asOutput().wrap()
|
return operation(left, right).asOutput().wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun StructureND<T>.biOp(
|
public fun StructureND<T>.operate(
|
||||||
value: T,
|
value: T,
|
||||||
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
): TensorFlowOutput<T, TT> {
|
): TensorFlowOutput<T, TT> {
|
||||||
@ -114,7 +126,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
return operation(left, right).asOutput().wrap()
|
return operation(left, right).asOutput().wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun Tensor<T>.inPlaceOp(
|
public fun Tensor<T>.operateInPlace(
|
||||||
other: StructureND<T>,
|
other: StructureND<T>,
|
||||||
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
): Unit {
|
): Unit {
|
||||||
@ -124,7 +136,7 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
origin.output = operation(left, right).asOutput()
|
origin.output = operation(left, right).asOutput()
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun Tensor<T>.inPlaceOp(
|
public fun Tensor<T>.operateInPlace(
|
||||||
value: T,
|
value: T,
|
||||||
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
): Unit {
|
): Unit {
|
||||||
@ -134,58 +146,58 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
origin.output = operation(left, right).asOutput()
|
origin.output = operation(left, right).asOutput()
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun StructureND<T>.unOp(operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
public fun StructureND<T>.operate(operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
||||||
operation(asTensorFlow().output).asOutput().wrap()
|
operation(asTensorFlow().output).asOutput().wrap()
|
||||||
|
|
||||||
override fun T.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
|
override fun T.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = operate(arg, ops.math::add)
|
||||||
|
|
||||||
override fun StructureND<T>.plus(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
|
override fun StructureND<T>.plus(arg: T): TensorFlowOutput<T, TT> = operate(arg, ops.math::add)
|
||||||
|
|
||||||
override fun StructureND<T>.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
|
override fun StructureND<T>.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = operate(arg, ops.math::add)
|
||||||
|
|
||||||
override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add)
|
override fun Tensor<T>.plusAssign(value: T): Unit = operateInPlace(value, ops.math::add)
|
||||||
|
|
||||||
override fun Tensor<T>.plusAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::add)
|
override fun Tensor<T>.plusAssign(arg: StructureND<T>): Unit = operateInPlace(arg, ops.math::add)
|
||||||
|
|
||||||
override fun StructureND<T>.minus(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::sub)
|
override fun StructureND<T>.minus(arg: T): TensorFlowOutput<T, TT> = operate(arg, ops.math::sub)
|
||||||
|
|
||||||
override fun StructureND<T>.minus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::sub)
|
override fun StructureND<T>.minus(arg: StructureND<T>): TensorFlowOutput<T, TT> = operate(arg, ops.math::sub)
|
||||||
|
|
||||||
override fun T.minus(arg: StructureND<T>): Tensor<T> = biOp(arg, ops.math::sub)
|
override fun T.minus(arg: StructureND<T>): Tensor<T> = operate(arg, ops.math::sub)
|
||||||
|
|
||||||
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
|
override fun Tensor<T>.minusAssign(value: T): Unit = operateInPlace(value, ops.math::sub)
|
||||||
|
|
||||||
override fun Tensor<T>.minusAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::sub)
|
override fun Tensor<T>.minusAssign(arg: StructureND<T>): Unit = operateInPlace(arg, ops.math::sub)
|
||||||
|
|
||||||
override fun T.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
override fun T.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = operate(arg, ops.math::mul)
|
||||||
|
|
||||||
override fun StructureND<T>.times(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
override fun StructureND<T>.times(arg: T): TensorFlowOutput<T, TT> = operate(arg, ops.math::mul)
|
||||||
|
|
||||||
override fun StructureND<T>.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
override fun StructureND<T>.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = operate(arg, ops.math::mul)
|
||||||
|
|
||||||
override fun Tensor<T>.timesAssign(value: T): Unit = inPlaceOp(value, ops.math::mul)
|
override fun Tensor<T>.timesAssign(value: T): Unit = operateInPlace(value, ops.math::mul)
|
||||||
|
|
||||||
override fun Tensor<T>.timesAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::mul)
|
override fun Tensor<T>.timesAssign(arg: StructureND<T>): Unit = operateInPlace(arg, ops.math::mul)
|
||||||
|
|
||||||
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(ops.math::neg)
|
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg)
|
||||||
|
|
||||||
override fun Tensor<T>.get(i: Int): Tensor<T> = unOp {
|
override fun Tensor<T>.get(i: Int): Tensor<T> = operate {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = unOp {
|
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = operate {
|
||||||
ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
|
ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = unOp {
|
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = operate {
|
||||||
ops.reshape(it, ops.constant(shape))
|
ops.reshape(it, ops.constant(shape))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.viewAs(other: StructureND<T>): Tensor<T> = biOp(other) { l, r ->
|
override fun Tensor<T>.viewAs(other: StructureND<T>): Tensor<T> = operate(other) { l, r ->
|
||||||
ops.reshape(l, ops.shape(r))
|
ops.reshape(l, ops.shape(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
|
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = operate(other) { l, r ->
|
||||||
ops.linalg.matMul(
|
ops.linalg.matMul(
|
||||||
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
||||||
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r)
|
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r)
|
||||||
@ -196,31 +208,31 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
offset: Int,
|
offset: Int,
|
||||||
dim1: Int,
|
dim1: Int,
|
||||||
dim2: Int,
|
dim2: Int,
|
||||||
): TensorFlowOutput<T, TT> = diagonalEntries.unOp {
|
): TensorFlowOutput<T, TT> = diagonalEntries.operate {
|
||||||
TODO()
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.sum(): T = unOp {
|
override fun StructureND<T>.sum(): T = operate {
|
||||||
ops.sum(it, ops.constant(intArrayOf()))
|
ops.sum(it, ops.constant(intArrayOf()))
|
||||||
}.value()
|
}.value()
|
||||||
|
|
||||||
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): TensorFlowOutput<T, TT> = unOp {
|
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): TensorFlowOutput<T, TT> = operate {
|
||||||
ops.sum(it, ops.constant(dim), Sum.keepDims(keepDim))
|
ops.sum(it, ops.constant(dim), Sum.keepDims(keepDim))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.min(): T = unOp {
|
override fun StructureND<T>.min(): T = operate {
|
||||||
ops.min(it, ops.constant(intArrayOf()))
|
ops.min(it, ops.constant(intArrayOf()))
|
||||||
}.value()
|
}.value()
|
||||||
|
|
||||||
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> = unOp {
|
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> = operate {
|
||||||
ops.min(it, ops.constant(dim), Min.keepDims(keepDim))
|
ops.min(it, ops.constant(dim), Min.keepDims(keepDim))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.max(): T = unOp {
|
override fun StructureND<T>.max(): T = operate {
|
||||||
ops.max(it, ops.constant(intArrayOf()))
|
ops.max(it, ops.constant(intArrayOf()))
|
||||||
}.value()
|
}.value()
|
||||||
|
|
||||||
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> = unOp {
|
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> = operate {
|
||||||
ops.max(it, ops.constant(dim), Max.keepDims(keepDim))
|
ops.max(it, ops.constant(dim), Max.keepDims(keepDim))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.tensorflow
|
||||||
|
|
||||||
|
import org.tensorflow.types.family.TNumber
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.operations.TrigonometricOperations
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
|
// TODO add other operations
|
||||||
|
|
||||||
|
public fun <T, TT : TNumber, A> TensorFlowAlgebra<T, TT, A>.sin(
|
||||||
|
arg: StructureND<T>,
|
||||||
|
): TensorFlowOutput<T, TT> where A : TrigonometricOperations<T>, A : Ring<T> = arg.operate { ops.math.sin(it) }
|
||||||
|
|
||||||
|
public fun <T, TT : TNumber, A> TensorFlowAlgebra<T, TT, A>.cos(
|
||||||
|
arg: StructureND<T>,
|
||||||
|
): TensorFlowOutput<T, TT> where A : TrigonometricOperations<T>, A : Ring<T> = arg.operate { ops.math.cos(it) }
|
@ -1,9 +1,10 @@
|
|||||||
package space.kscience.kmath.tensorflow
|
package space.kscience.kmath.tensorflow
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.get
|
||||||
import space.kscience.kmath.nd.structureND
|
import space.kscience.kmath.nd.structureND
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class DoubleTensorFlowOps {
|
class DoubleTensorFlowOps {
|
||||||
@Test
|
@Test
|
||||||
@ -13,7 +14,20 @@ class DoubleTensorFlowOps {
|
|||||||
|
|
||||||
initial + (initial * 2.0)
|
initial + (initial * 2.0)
|
||||||
}
|
}
|
||||||
println(StructureND.toString(res))
|
//println(StructureND.toString(res))
|
||||||
|
assertEquals(3.0, res[0, 0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun extensionOps(){
|
||||||
|
val res = DoubleField.produceWithTF {
|
||||||
|
val i = structureND(2, 2) { 0.5 }
|
||||||
|
|
||||||
|
sin(i).pow(2) + cos(i).pow(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1.0, res[0,0],0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user