Merge branch 'feature/tensorflow' into dev
This commit is contained in:
commit
41fc6b4dd9
@ -8,7 +8,7 @@ package space.kscience.kmath.histogram
|
|||||||
import kotlinx.atomicfu.atomic
|
import kotlinx.atomicfu.atomic
|
||||||
import kotlinx.atomicfu.getAndUpdate
|
import kotlinx.atomicfu.getAndUpdate
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Group
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Common representation for atomic counters
|
* Common representation for atomic counters
|
||||||
@ -18,7 +18,7 @@ public interface Counter<T : Any> {
|
|||||||
public val value: T
|
public val value: T
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun real(): ObjectCounter<Double> = ObjectCounter(DoubleField)
|
public fun double(): ObjectCounter<Double> = ObjectCounter(DoubleField)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,6 +32,16 @@ public class IntCounter : Counter<Int> {
|
|||||||
override val value: Int get() = innerValue.value
|
override val value: Int get() = innerValue.value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public operator fun IntCounter.inc(): IntCounter {
|
||||||
|
add(1)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun IntCounter.dec(): IntCounter {
|
||||||
|
add(-1)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
public class LongCounter : Counter<Long> {
|
public class LongCounter : Counter<Long> {
|
||||||
private val innerValue = atomic(0L)
|
private val innerValue = atomic(0L)
|
||||||
|
|
||||||
@ -42,7 +52,17 @@ public class LongCounter : Counter<Long> {
|
|||||||
override val value: Long get() = innerValue.value
|
override val value: Long get() = innerValue.value
|
||||||
}
|
}
|
||||||
|
|
||||||
public class ObjectCounter<T : Any>(public val group: Ring<T>) : Counter<T> {
|
public operator fun LongCounter.inc(): LongCounter {
|
||||||
|
add(1L)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun LongCounter.dec(): LongCounter {
|
||||||
|
add(-1L)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
public class ObjectCounter<T : Any>(private val group: Group<T>) : Counter<T> {
|
||||||
private val innerValue = atomic(group.zero)
|
private val innerValue = atomic(group.zero)
|
||||||
|
|
||||||
override fun add(delta: T) {
|
override fun add(delta: T) {
|
||||||
|
@ -74,7 +74,7 @@ public class DoubleHistogramSpace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
|
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
|
||||||
val ndCounter = StructureND.auto(shape) { Counter.real() }
|
val ndCounter = StructureND.auto(shape) { Counter.double() }
|
||||||
val hBuilder = HistogramBuilder<Double> { point, value ->
|
val hBuilder = HistogramBuilder<Double> { point, value ->
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
ndCounter[index].add(value.toDouble())
|
ndCounter[index].add(value.toDouble())
|
||||||
|
@ -41,7 +41,6 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
|||||||
get() = DefaultStrides(context.shape).asSequence().map {
|
get() = DefaultStrides(context.shape).asSequence().map {
|
||||||
context.produceBin(it, values[it])
|
context.produceBin(it, values[it])
|
||||||
}.asIterable()
|
}.asIterable()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -39,7 +39,7 @@ public class TreeHistogram(
|
|||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
|
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
|
||||||
|
|
||||||
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) :
|
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.double()) :
|
||||||
ClosedFloatingPointRange<Double> by domain.range
|
ClosedFloatingPointRange<Double> by domain.range
|
||||||
|
|
||||||
private val bins: TreeMap<Double, BinCounter> = TreeMap()
|
private val bins: TreeMap<Double, BinCounter> = TreeMap()
|
||||||
|
@ -42,7 +42,7 @@ public interface UnivariateHistogram : Histogram<Double, UnivariateBin> {
|
|||||||
/**
|
/**
|
||||||
* Build and fill a [UnivariateHistogram]. Returns a read-only histogram.
|
* Build and fill a [UnivariateHistogram]. Returns a read-only histogram.
|
||||||
*/
|
*/
|
||||||
public fun uniform(
|
public inline fun uniform(
|
||||||
binSize: Double,
|
binSize: Double,
|
||||||
start: Double = 0.0,
|
start: Double = 0.0,
|
||||||
builder: UnivariateHistogramBuilder.() -> Unit,
|
builder: UnivariateHistogramBuilder.() -> Unit,
|
||||||
@ -51,7 +51,7 @@ public interface UnivariateHistogram : Histogram<Double, UnivariateBin> {
|
|||||||
/**
|
/**
|
||||||
* Build and fill a histogram with custom borders. Returns a read-only histogram.
|
* Build and fill a histogram with custom borders. Returns a read-only histogram.
|
||||||
*/
|
*/
|
||||||
public fun custom(
|
public inline fun custom(
|
||||||
borders: DoubleArray,
|
borders: DoubleArray,
|
||||||
builder: UnivariateHistogramBuilder.() -> Unit,
|
builder: UnivariateHistogramBuilder.() -> Unit,
|
||||||
): UnivariateHistogram = TreeHistogramSpace.custom(borders).fill(builder)
|
): UnivariateHistogram = TreeHistogramSpace.custom(borders).fill(builder)
|
||||||
|
15
kmath-tensorflow/build.gradle.kts
Normal file
15
kmath-tensorflow/build.gradle.kts
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.gradle.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
|
description = "Google tensorflow connector"
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-tensors"))
|
||||||
|
api("org.tensorflow:tensorflow-core-api:0.3.3")
|
||||||
|
testImplementation("org.tensorflow:tensorflow-core-platform:0.3.3")
|
||||||
|
}
|
||||||
|
|
||||||
|
readme {
|
||||||
|
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
package space.kscience.kmath.tensorflow
|
||||||
|
|
||||||
|
import org.tensorflow.Graph
|
||||||
|
import org.tensorflow.Output
|
||||||
|
import org.tensorflow.ndarray.NdArray
|
||||||
|
import org.tensorflow.op.core.Constant
|
||||||
|
import org.tensorflow.types.TFloat64
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.nd.DefaultStrides
|
||||||
|
import space.kscience.kmath.nd.Shape
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
|
||||||
|
public class DoubleTensorFlowOutput(
|
||||||
|
graph: Graph,
|
||||||
|
output: Output<TFloat64>,
|
||||||
|
) : TensorFlowOutput<Double, TFloat64>(graph, output) {
|
||||||
|
|
||||||
|
override fun org.tensorflow.Tensor.actualizeTensor(): NdArray<Double> = this as TFloat64
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class DoubleTensorFlowAlgebra internal constructor(
|
||||||
|
graph: Graph,
|
||||||
|
) : TensorFlowAlgebra<Double, TFloat64, DoubleField>(graph) {
|
||||||
|
|
||||||
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
|
||||||
|
override fun structureND(
|
||||||
|
shape: Shape,
|
||||||
|
initializer: DoubleField.(IntArray) -> Double,
|
||||||
|
): StructureND<Double> {
|
||||||
|
val res = TFloat64.tensorOf(org.tensorflow.ndarray.Shape.of(*shape.toLongArray())) { array ->
|
||||||
|
DefaultStrides(shape).forEach { index ->
|
||||||
|
array.setDouble(elementAlgebra.initializer(index), *index.toLongArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return DoubleTensorFlowOutput(graph, ops.constant(res).asOutput())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.asTensorFlow(): TensorFlowOutput<Double, TFloat64> =
|
||||||
|
if (this is TensorFlowOutput<Double, *> && output.type() == TFloat64::class.java) {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
this as TensorFlowOutput<Double, TFloat64>
|
||||||
|
} else {
|
||||||
|
val res = TFloat64.tensorOf(org.tensorflow.ndarray.Shape.of(*shape.toLongArray())) { array ->
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
|
elements().forEach { (index, value) ->
|
||||||
|
array.setDouble(value, *index.toLongArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DoubleTensorFlowOutput(graph, ops.constant(res).asOutput())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Output<TFloat64>.wrap(): TensorFlowOutput<Double, TFloat64> = DoubleTensorFlowOutput(graph, this)
|
||||||
|
|
||||||
|
override fun const(value: Double): Constant<TFloat64> = ops.constant(value)
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DoubleField.produceWithTF(
|
||||||
|
block: DoubleTensorFlowAlgebra.() -> StructureND<Double>,
|
||||||
|
): StructureND<Double> = Graph().use { graph ->
|
||||||
|
val scope = DoubleTensorFlowAlgebra(graph)
|
||||||
|
scope.export(scope.block())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DoubleField.produceMapWithTF(
|
||||||
|
block: DoubleTensorFlowAlgebra.() -> Map<Symbol, StructureND<Double>>,
|
||||||
|
): Map<Symbol, StructureND<Double>> = Graph().use { graph ->
|
||||||
|
val scope = DoubleTensorFlowAlgebra(graph)
|
||||||
|
scope.block().mapValues { scope.export(it.value) }
|
||||||
|
}
|
@ -0,0 +1,21 @@
|
|||||||
|
package space.kscience.kmath.tensorflow
|
||||||
|
|
||||||
|
import org.tensorflow.Graph
|
||||||
|
import org.tensorflow.Output
|
||||||
|
import org.tensorflow.ndarray.NdArray
|
||||||
|
import org.tensorflow.types.TInt32
|
||||||
|
import org.tensorflow.types.TInt64
|
||||||
|
|
||||||
|
public class IntTensorFlowOutput(
|
||||||
|
graph: Graph,
|
||||||
|
output: Output<TInt32>,
|
||||||
|
) : TensorFlowOutput<Int, TInt32>(graph, output) {
|
||||||
|
override fun org.tensorflow.Tensor.actualizeTensor(): NdArray<Int> = this as TInt32
|
||||||
|
}
|
||||||
|
|
||||||
|
public class LongTensorFlowOutput(
|
||||||
|
graph: Graph,
|
||||||
|
output: Output<TInt64>,
|
||||||
|
) : TensorFlowOutput<Long, TInt64>(graph, output) {
|
||||||
|
override fun org.tensorflow.Tensor.actualizeTensor(): NdArray<Long> = this as TInt64
|
||||||
|
}
|
@ -0,0 +1,237 @@
|
|||||||
|
package space.kscience.kmath.tensorflow
|
||||||
|
|
||||||
|
|
||||||
|
import org.tensorflow.Graph
|
||||||
|
import org.tensorflow.Operand
|
||||||
|
import org.tensorflow.Output
|
||||||
|
import org.tensorflow.Session
|
||||||
|
import org.tensorflow.ndarray.NdArray
|
||||||
|
import org.tensorflow.op.Ops
|
||||||
|
import org.tensorflow.op.core.Constant
|
||||||
|
import org.tensorflow.op.core.Max
|
||||||
|
import org.tensorflow.op.core.Min
|
||||||
|
import org.tensorflow.op.core.Sum
|
||||||
|
import org.tensorflow.types.TInt32
|
||||||
|
import org.tensorflow.types.family.TType
|
||||||
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.nd.Shape
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
|
|
||||||
|
internal fun IntArray.toLongArray() = LongArray(size) { get(it).toLong() }
|
||||||
|
internal fun LongArray.toIntArray() = IntArray(size) { get(it).toInt() }
|
||||||
|
|
||||||
|
internal val <T> NdArray<T>.scalar: T get() = getObject()
|
||||||
|
|
||||||
|
|
||||||
|
public sealed interface TensorFlowTensor<T> : Tensor<T>
|
||||||
|
|
||||||
|
@JvmInline
|
||||||
|
public value class TensorFlowArray<T>(public val tensor: NdArray<T>) : Tensor<T> {
|
||||||
|
override val shape: Shape get() = tensor.shape().asArray().toIntArray()
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T = tensor.getObject(*index.toLongArray())
|
||||||
|
|
||||||
|
//TODO implement native element sequence
|
||||||
|
|
||||||
|
override fun set(index: IntArray, value: T) {
|
||||||
|
tensor.setObject(value, *index.toLongArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract class TensorFlowOutput<T, TT : TType>(
|
||||||
|
protected val graph: Graph,
|
||||||
|
output: Output<TT>,
|
||||||
|
) : TensorFlowTensor<T> {
|
||||||
|
|
||||||
|
public var output: Output<TT> = output
|
||||||
|
internal set
|
||||||
|
|
||||||
|
override val shape: Shape get() = output.shape().asArray().toIntArray()
|
||||||
|
|
||||||
|
protected abstract fun org.tensorflow.Tensor.actualizeTensor(): NdArray<T>
|
||||||
|
|
||||||
|
internal val actualTensor by lazy {
|
||||||
|
Session(graph).use { session ->
|
||||||
|
TensorFlowArray(session.runner().fetch(output).run().first().actualizeTensor())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T = actualTensor[index]
|
||||||
|
|
||||||
|
@PerformancePitfall
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = actualTensor.elements()
|
||||||
|
|
||||||
|
override fun set(index: IntArray, value: T) {
|
||||||
|
actualTensor[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal constructor(
|
||||||
|
protected val graph: Graph,
|
||||||
|
) : TensorAlgebra<T, A> {
|
||||||
|
|
||||||
|
protected val ops: Ops by lazy { Ops.create(graph) }
|
||||||
|
|
||||||
|
protected abstract fun StructureND<T>.asTensorFlow(): TensorFlowOutput<T, TT>
|
||||||
|
|
||||||
|
protected abstract fun Output<TT>.wrap(): TensorFlowOutput<T, TT>
|
||||||
|
|
||||||
|
protected abstract fun const(value: T): Constant<TT>
|
||||||
|
|
||||||
|
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1))
|
||||||
|
get(Shape(0)) else null
|
||||||
|
|
||||||
|
private inline fun StructureND<T>.biOp(
|
||||||
|
other: StructureND<T>,
|
||||||
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
|
): TensorFlowOutput<T, TT> {
|
||||||
|
val left = asTensorFlow().output
|
||||||
|
val right = other.asTensorFlow().output
|
||||||
|
return operation(left, right).asOutput().wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun T.biOp(
|
||||||
|
other: StructureND<T>,
|
||||||
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
|
): TensorFlowOutput<T, TT> {
|
||||||
|
val left = const(this)
|
||||||
|
val right = other.asTensorFlow().output
|
||||||
|
return operation(left, right).asOutput().wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun StructureND<T>.biOp(
|
||||||
|
value: T,
|
||||||
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
|
): TensorFlowOutput<T, TT> {
|
||||||
|
val left = asTensorFlow().output
|
||||||
|
val right = const(value)
|
||||||
|
return operation(left, right).asOutput().wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun Tensor<T>.inPlaceOp(
|
||||||
|
other: StructureND<T>,
|
||||||
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
|
): Unit {
|
||||||
|
val origin = asTensorFlow()
|
||||||
|
val left = origin.output
|
||||||
|
val right = other.asTensorFlow().output
|
||||||
|
origin.output = operation(left, right).asOutput()
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun Tensor<T>.inPlaceOp(
|
||||||
|
value: T,
|
||||||
|
operation: (left: Operand<TT>, right: Operand<TT>) -> Operand<TT>,
|
||||||
|
): Unit {
|
||||||
|
val origin = asTensorFlow()
|
||||||
|
val left = origin.output
|
||||||
|
val right = const(value)
|
||||||
|
origin.output = operation(left, right).asOutput()
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun StructureND<T>.unOp(operation: (Operand<TT>) -> Operand<TT>): TensorFlowOutput<T, TT> =
|
||||||
|
operation(asTensorFlow().output).asOutput().wrap()
|
||||||
|
|
||||||
|
override fun T.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
|
||||||
|
|
||||||
|
override fun StructureND<T>.plus(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
|
||||||
|
|
||||||
|
override fun StructureND<T>.plus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::add)
|
||||||
|
|
||||||
|
override fun Tensor<T>.plusAssign(value: T): Unit = inPlaceOp(value, ops.math::add)
|
||||||
|
|
||||||
|
override fun Tensor<T>.plusAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::add)
|
||||||
|
|
||||||
|
override fun StructureND<T>.minus(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::sub)
|
||||||
|
|
||||||
|
override fun StructureND<T>.minus(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::sub)
|
||||||
|
|
||||||
|
override fun T.minus(arg: StructureND<T>): Tensor<T> = biOp(arg, ops.math::sub)
|
||||||
|
|
||||||
|
override fun Tensor<T>.minusAssign(value: T): Unit = inPlaceOp(value, ops.math::sub)
|
||||||
|
|
||||||
|
override fun Tensor<T>.minusAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::sub)
|
||||||
|
|
||||||
|
override fun T.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
||||||
|
|
||||||
|
override fun StructureND<T>.times(arg: T): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
||||||
|
|
||||||
|
override fun StructureND<T>.times(arg: StructureND<T>): TensorFlowOutput<T, TT> = biOp(arg, ops.math::mul)
|
||||||
|
|
||||||
|
override fun Tensor<T>.timesAssign(value: T): Unit = inPlaceOp(value, ops.math::mul)
|
||||||
|
|
||||||
|
override fun Tensor<T>.timesAssign(arg: StructureND<T>): Unit = inPlaceOp(arg, ops.math::mul)
|
||||||
|
|
||||||
|
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = unOp(ops.math::neg)
|
||||||
|
|
||||||
|
override fun Tensor<T>.get(i: Int): Tensor<T> = unOp {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = unOp {
|
||||||
|
ops.linalg.transpose(it, ops.constant(intArrayOf(i, j)))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = unOp {
|
||||||
|
ops.reshape(it, ops.constant(shape))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.viewAs(other: StructureND<T>): Tensor<T> = biOp(other) { l, r ->
|
||||||
|
ops.reshape(l, ops.shape(r))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
|
||||||
|
ops.linalg.matMul(
|
||||||
|
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l,ops.constant(0)) else l,
|
||||||
|
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r,ops.constant(-1)) else r)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun diagonalEmbedding(
|
||||||
|
diagonalEntries: Tensor<T>,
|
||||||
|
offset: Int,
|
||||||
|
dim1: Int,
|
||||||
|
dim2: Int,
|
||||||
|
): TensorFlowOutput<T, TT> = diagonalEntries.unOp {
|
||||||
|
TODO()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.sum(): T = unOp {
|
||||||
|
ops.sum(it, ops.constant(intArrayOf()))
|
||||||
|
}.value()
|
||||||
|
|
||||||
|
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): TensorFlowOutput<T, TT> = unOp {
|
||||||
|
ops.sum(it, ops.constant(dim), Sum.keepDims(keepDim))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.min(): T = unOp {
|
||||||
|
ops.min(it, ops.constant(intArrayOf()))
|
||||||
|
}.value()
|
||||||
|
|
||||||
|
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> = unOp {
|
||||||
|
ops.min(it, ops.constant(dim), Min.keepDims(keepDim))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.max(): T = unOp {
|
||||||
|
ops.max(it, ops.constant(intArrayOf()))
|
||||||
|
}.value()
|
||||||
|
|
||||||
|
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> = unOp {
|
||||||
|
ops.max(it, ops.constant(dim), Max.keepDims(keepDim))
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> = IntTensorFlowOutput(
|
||||||
|
graph,
|
||||||
|
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
|
||||||
|
).actualTensor
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
override fun export(arg: StructureND<T>): StructureND<T> =
|
||||||
|
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO add TensorFlow expressions
|
@ -0,0 +1,19 @@
|
|||||||
|
package space.kscience.kmath.tensorflow
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.nd.structureND
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
|
||||||
|
class DoubleTensorFlowOps {
|
||||||
|
@Test
|
||||||
|
fun basicOps() {
|
||||||
|
val res = DoubleField.produceWithTF {
|
||||||
|
val initial = structureND(2, 2) { 1.0 }
|
||||||
|
|
||||||
|
initial + (initial * 2.0)
|
||||||
|
}
|
||||||
|
println(StructureND.toString(res))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -208,7 +208,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
|
|||||||
*
|
*
|
||||||
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
|
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
|
||||||
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
|
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
|
||||||
* After the matrix multiply, the prepended dimension is removed.
|
* After the matrix multiply, depending on the implementation the prepended dimension might be removed.
|
||||||
*
|
*
|
||||||
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
|
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
|
||||||
* the matrix-vector product is returned.
|
* the matrix-vector product is returned.
|
||||||
|
@ -424,13 +424,13 @@ public open class DoubleTensorAlgebra :
|
|||||||
dotTo(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
dotTo(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (penultimateDim) {
|
return if (penultimateDim) {
|
||||||
return resTensor.view(resTensor.shape.dropLast(2).toIntArray() + intArrayOf(resTensor.shape.last()))
|
resTensor.view(resTensor.shape.dropLast(2).toIntArray() + intArrayOf(resTensor.shape.last()))
|
||||||
|
} else if (lastDim) {
|
||||||
|
resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
||||||
|
} else {
|
||||||
|
resTensor
|
||||||
}
|
}
|
||||||
if (lastDim) {
|
|
||||||
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
|
|
||||||
}
|
|
||||||
return resTensor
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:OptIn(PerformancePitfall::class)
|
||||||
|
|
||||||
package space.kscience.kmath.viktor
|
package space.kscience.kmath.viktor
|
||||||
|
|
||||||
import org.jetbrains.bio.viktor.F64Array
|
import org.jetbrains.bio.viktor.F64Array
|
||||||
|
@ -29,6 +29,7 @@ include(
|
|||||||
":kmath-commons",
|
":kmath-commons",
|
||||||
":kmath-viktor",
|
":kmath-viktor",
|
||||||
":kmath-multik",
|
":kmath-multik",
|
||||||
|
":kmath-tensorflow",
|
||||||
":kmath-optimization",
|
":kmath-optimization",
|
||||||
":kmath-stat",
|
":kmath-stat",
|
||||||
":kmath-nd4j",
|
":kmath-nd4j",
|
||||||
|
Loading…
Reference in New Issue
Block a user