Make ND4J float algebra extended

This commit is contained in:
Alexander Nozik 2021-04-20 22:48:09 +03:00
parent 477f75270c
commit 9c353f4a0d
3 changed files with 57 additions and 4 deletions

View File

@ -10,6 +10,7 @@
- Blocking chains and Statistics - Blocking chains and Statistics
- Multiplatform integration - Multiplatform integration
- Integration for any Field element - Integration for any Field element
- Extendend operations for ND4J fields
### Changed ### Changed
- Exponential operations merged with hyperbolic functions - Exponential operations merged with hyperbolic functions

View File

@ -6,7 +6,10 @@
package space.kscience.kmath.nd4j package space.kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.scalar.Pow
import org.nd4j.linalg.api.ops.impl.transforms.strict.*
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
@ -207,7 +210,8 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
/** /**
* Represents [FieldND] over [Nd4jArrayDoubleStructure]. * Represents [FieldND] over [Nd4jArrayDoubleStructure].
*/ */
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField> { public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField>,
ExtendedField<StructureND<Double>> {
public override val elementContext: DoubleField get() = DoubleField public override val elementContext: DoubleField get() = DoubleField
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure() public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
@ -239,14 +243,31 @@ public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArr
public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> { public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
return arg.ndArray.rsub(this).wrap() return arg.ndArray.rsub(this).wrap()
} }
override fun sin(arg: StructureND<Double>): StructureND<Double> = Transforms.sin(arg.ndArray).wrap()
override fun cos(arg: StructureND<Double>): StructureND<Double> = Transforms.cos(arg.ndArray).wrap()
override fun asin(arg: StructureND<Double>): StructureND<Double> = Transforms.asin(arg.ndArray).wrap()
override fun acos(arg: StructureND<Double>): StructureND<Double> = Transforms.acos(arg.ndArray).wrap()
override fun atan(arg: StructureND<Double>): StructureND<Double> = Transforms.atan(arg.ndArray).wrap()
override fun power(arg: StructureND<Double>, pow: Number): StructureND<Double> =
Transforms.pow(arg.ndArray,pow).wrap()
override fun exp(arg: StructureND<Double>): StructureND<Double> = Transforms.exp(arg.ndArray).wrap()
override fun ln(arg: StructureND<Double>): StructureND<Double> = Transforms.log(arg.ndArray).wrap()
} }
/** /**
* Represents [FieldND] over [Nd4jArrayStructure] of [Float]. * Represents [FieldND] over [Nd4jArrayStructure] of [Float].
*/ */
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> { public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField>,
public override val elementContext: FloatField ExtendedField<StructureND<Float>> {
get() = FloatField public override val elementContext: FloatField get() = FloatField
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure() public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
@ -270,6 +291,23 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> = public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
override fun sin(arg: StructureND<Float>): StructureND<Float> = Sin(arg.ndArray).z().wrap()
override fun cos(arg: StructureND<Float>): StructureND<Float> = Cos(arg.ndArray).z().wrap()
override fun asin(arg: StructureND<Float>): StructureND<Float> = ASin(arg.ndArray).z().wrap()
override fun acos(arg: StructureND<Float>): StructureND<Float> = ACos(arg.ndArray).z().wrap()
override fun atan(arg: StructureND<Float>): StructureND<Float> = ATan(arg.ndArray).z().wrap()
override fun power(arg: StructureND<Float>, pow: Number): StructureND<Float> =
Pow(arg.ndArray, pow.toDouble()).z().wrap()
override fun exp(arg: StructureND<Float>): StructureND<Float> = Exp(arg.ndArray).z().wrap()
override fun ln(arg: StructureND<Float>): StructureND<Float> = Log(arg.ndArray).z().wrap()
} }
/** /**

View File

@ -6,8 +6,12 @@
package space.kscience.kmath.nd4j package space.kscience.kmath.nd4j
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.invoke
import kotlin.math.PI
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.test.fail import kotlin.test.fail
internal class Nd4jArrayAlgebraTest { internal class Nd4jArrayAlgebraTest {
@ -43,4 +47,14 @@ internal class Nd4jArrayAlgebraTest {
expected[intArrayOf(1, 1)] = 26 expected[intArrayOf(1, 1)] = 26
assertEquals(expected, res) assertEquals(expected, res)
} }
@Test
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
val initial = produce { (i, j) -> if (i == j) PI/2 else 0.0 }
val transformed = sin(initial)
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
println(transformed)
assertTrue { StructureND.contentEquals(transformed, expected) }
}
} }