altavir/diff #494
@ -10,7 +10,7 @@ allprojects {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group = "space.kscience"
|
group = "space.kscience"
|
||||||
version = "0.3.0-dev-17"
|
version = "0.3.0-dev-18"
|
||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
|
@ -164,8 +164,6 @@ public open class FunctionalExpressionExtendedField<T, out A : ExtendedField<T>>
|
|||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionField>.binaryOperationFunction(operation)
|
super<FunctionalExpressionField>.binaryOperationFunction(operation)
|
||||||
|
|
||||||
override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <T, A : Group<T>> A.expressionInGroup(
|
public inline fun <T, A : Group<T>> A.expressionInGroup(
|
||||||
|
@ -272,7 +272,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: Aut
|
|||||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: Double,
|
y: Double,
|
||||||
): AutoDiffValue<T> = derive(const { x.value.pow(y)}) { z ->
|
): AutoDiffValue<T> = derive(const { x.value.pow(y) }) { z ->
|
||||||
x.d += z.d * y * x.value.pow(y - 1)
|
x.d += z.d * y * x.value.pow(y - 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,10 +343,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: Au
|
|||||||
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
||||||
context: F,
|
context: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>,
|
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
|
||||||
SimpleAutoDiffField<T, F>(context, bindings) {
|
|
||||||
|
|
||||||
override fun bindSymbol(value: String): AutoDiffValue<T> = super<SimpleAutoDiffField>.bindSymbol(value)
|
|
||||||
|
|
||||||
override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
|
override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
|
||||||
|
|
||||||
|
@ -241,6 +241,16 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
|||||||
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
|
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
|
||||||
).actualTensor
|
).actualTensor
|
||||||
|
|
||||||
|
// private val symbolCache = HashMap<String, TensorFlowOutput<T, TT>>()
|
||||||
|
//
|
||||||
|
// override fun bindSymbolOrNull(value: String): TensorFlowOutput<T, TT>? {
|
||||||
|
// return symbolCache.getOrPut(value){ops.var}
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// public fun StructureND<T>.grad(
|
||||||
|
//
|
||||||
|
// )= operate { ops.gradients() }
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
override fun export(arg: StructureND<T>): StructureND<T> =
|
override fun export(arg: StructureND<T>): StructureND<T> =
|
||||||
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
||||||
|
Loading…
Reference in New Issue
Block a user