altavir/diff #494

Merged
altavir merged 71 commits from altavir/diff into commandertvis/diff 2022-07-16 09:58:44 +03:00
5 changed files with 14 additions and 9 deletions
Showing only changes of commit 408443989c - Show all commits

View File

@ -10,7 +10,7 @@ allprojects {
} }
group = "space.kscience" group = "space.kscience"
version = "0.3.0-dev-17" version = "0.3.0-dev-18"
} }
subprojects { subprojects {

View File

@ -164,8 +164,6 @@ public open class FunctionalExpressionExtendedField<T, out A : ExtendedField<T>>
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> = override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.binaryOperationFunction(operation) super<FunctionalExpressionField>.binaryOperationFunction(operation)
override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
} }
public inline fun <T, A : Group<T>> A.expressionInGroup( public inline fun <T, A : Group<T>> A.expressionInGroup(

View File

@ -272,7 +272,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: Aut
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>, x: AutoDiffValue<T>,
y: Double, y: Double,
): AutoDiffValue<T> = derive(const { x.value.pow(y)}) { z -> ): AutoDiffValue<T> = derive(const { x.value.pow(y) }) { z ->
x.d += z.d * y * x.value.pow(y - 1) x.d += z.d * y * x.value.pow(y - 1)
} }
@ -343,10 +343,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: Au
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>( public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
context: F, context: F,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>, ) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
SimpleAutoDiffField<T, F>(context, bindings) {
override fun bindSymbol(value: String): AutoDiffValue<T> = super<SimpleAutoDiffField>.bindSymbol(value)
override fun number(value: Number): AutoDiffValue<T> = const { number(value) } override fun number(value: Number): AutoDiffValue<T> = const { number(value) }

View File

@ -241,6 +241,16 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output() ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
).actualTensor ).actualTensor
// private val symbolCache = HashMap<String, TensorFlowOutput<T, TT>>()
//
// override fun bindSymbolOrNull(value: String): TensorFlowOutput<T, TT>? {
// return symbolCache.getOrPut(value){ops.var}
// }
//
// public fun StructureND<T>.grad(
//
// )= operate { ops.gradients() }
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
override fun export(arg: StructureND<T>): StructureND<T> = override fun export(arg: StructureND<T>): StructureND<T> =
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg

View File

@ -20,4 +20,4 @@ public fun <T, TT : TNumber, A> TensorFlowAlgebra<T, TT, A>.sin(
public fun <T, TT : TNumber, A> TensorFlowAlgebra<T, TT, A>.cos( public fun <T, TT : TNumber, A> TensorFlowAlgebra<T, TT, A>.cos(
arg: StructureND<T>, arg: StructureND<T>,
): TensorFlowOutput<T, TT> where A : TrigonometricOperations<T>, A : Ring<T> = arg.operate { ops.math.cos(it) } ): TensorFlowOutput<T, TT> where A : TrigonometricOperations<T>, A : Ring<T> = arg.operate { ops.math.cos(it) }