Make ND4J float algebra extended

This commit is contained in:
Alexander Nozik 2021-04-20 22:48:09 +03:00
parent 477f75270c
commit 9c353f4a0d
3 changed files with 57 additions and 4 deletions

View File

@ -10,6 +10,7 @@
- Blocking chains and Statistics
- Multiplatform integration
- Integration for any Field element
- Extendend operations for ND4J fields
### Changed
- Exponential operations merged with hyperbolic functions

View File

@ -6,7 +6,10 @@
package space.kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.scalar.Pow
import org.nd4j.linalg.api.ops.impl.transforms.strict.*
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
@ -207,7 +210,8 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
/**
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
*/
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField> {
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField>,
ExtendedField<StructureND<Double>> {
public override val elementContext: DoubleField get() = DoubleField
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
@ -239,14 +243,31 @@ public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArr
public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
return arg.ndArray.rsub(this).wrap()
}
override fun sin(arg: StructureND<Double>): StructureND<Double> = Transforms.sin(arg.ndArray).wrap()
override fun cos(arg: StructureND<Double>): StructureND<Double> = Transforms.cos(arg.ndArray).wrap()
override fun asin(arg: StructureND<Double>): StructureND<Double> = Transforms.asin(arg.ndArray).wrap()
override fun acos(arg: StructureND<Double>): StructureND<Double> = Transforms.acos(arg.ndArray).wrap()
override fun atan(arg: StructureND<Double>): StructureND<Double> = Transforms.atan(arg.ndArray).wrap()
override fun power(arg: StructureND<Double>, pow: Number): StructureND<Double> =
Transforms.pow(arg.ndArray,pow).wrap()
override fun exp(arg: StructureND<Double>): StructureND<Double> = Transforms.exp(arg.ndArray).wrap()
override fun ln(arg: StructureND<Double>): StructureND<Double> = Transforms.log(arg.ndArray).wrap()
}
/**
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
*/
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> {
public override val elementContext: FloatField
get() = FloatField
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField>,
ExtendedField<StructureND<Float>> {
public override val elementContext: FloatField get() = FloatField
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
@ -270,6 +291,23 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
arg.ndArray.rsub(this).wrap()
override fun sin(arg: StructureND<Float>): StructureND<Float> = Sin(arg.ndArray).z().wrap()
override fun cos(arg: StructureND<Float>): StructureND<Float> = Cos(arg.ndArray).z().wrap()
override fun asin(arg: StructureND<Float>): StructureND<Float> = ASin(arg.ndArray).z().wrap()
override fun acos(arg: StructureND<Float>): StructureND<Float> = ACos(arg.ndArray).z().wrap()
override fun atan(arg: StructureND<Float>): StructureND<Float> = ATan(arg.ndArray).z().wrap()
override fun power(arg: StructureND<Float>, pow: Number): StructureND<Float> =
Pow(arg.ndArray, pow.toDouble()).z().wrap()
override fun exp(arg: StructureND<Float>): StructureND<Float> = Exp(arg.ndArray).z().wrap()
override fun ln(arg: StructureND<Float>): StructureND<Float> = Log(arg.ndArray).z().wrap()
}
/**

View File

@ -6,8 +6,12 @@
package space.kscience.kmath.nd4j
import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.invoke
import kotlin.math.PI
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.test.fail
internal class Nd4jArrayAlgebraTest {
@ -43,4 +47,14 @@ internal class Nd4jArrayAlgebraTest {
expected[intArrayOf(1, 1)] = 26
assertEquals(expected, res)
}
@Test
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
val initial = produce { (i, j) -> if (i == j) PI/2 else 0.0 }
val transformed = sin(initial)
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
println(transformed)
assertTrue { StructureND.contentEquals(transformed, expected) }
}
}