Make ND4J float algebra extended
This commit is contained in:
parent
477f75270c
commit
9c353f4a0d
@ -10,6 +10,7 @@
|
||||
- Blocking chains and Statistics
|
||||
- Multiplatform integration
|
||||
- Integration for any Field element
|
||||
- Extendend operations for ND4J fields
|
||||
|
||||
### Changed
|
||||
- Exponential operations merged with hyperbolic functions
|
||||
|
@ -6,7 +6,10 @@
|
||||
package space.kscience.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.Pow
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.*
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import org.nd4j.linalg.ops.transforms.Transforms
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.*
|
||||
@ -207,7 +210,8 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
|
||||
/**
|
||||
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
|
||||
*/
|
||||
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField> {
|
||||
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField>,
|
||||
ExtendedField<StructureND<Double>> {
|
||||
public override val elementContext: DoubleField get() = DoubleField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
|
||||
@ -239,14 +243,31 @@ public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArr
|
||||
public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
|
||||
override fun sin(arg: StructureND<Double>): StructureND<Double> = Transforms.sin(arg.ndArray).wrap()
|
||||
|
||||
override fun cos(arg: StructureND<Double>): StructureND<Double> = Transforms.cos(arg.ndArray).wrap()
|
||||
|
||||
override fun asin(arg: StructureND<Double>): StructureND<Double> = Transforms.asin(arg.ndArray).wrap()
|
||||
|
||||
override fun acos(arg: StructureND<Double>): StructureND<Double> = Transforms.acos(arg.ndArray).wrap()
|
||||
|
||||
override fun atan(arg: StructureND<Double>): StructureND<Double> = Transforms.atan(arg.ndArray).wrap()
|
||||
|
||||
override fun power(arg: StructureND<Double>, pow: Number): StructureND<Double> =
|
||||
Transforms.pow(arg.ndArray,pow).wrap()
|
||||
|
||||
override fun exp(arg: StructureND<Double>): StructureND<Double> = Transforms.exp(arg.ndArray).wrap()
|
||||
|
||||
override fun ln(arg: StructureND<Double>): StructureND<Double> = Transforms.log(arg.ndArray).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
|
||||
*/
|
||||
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> {
|
||||
public override val elementContext: FloatField
|
||||
get() = FloatField
|
||||
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField>,
|
||||
ExtendedField<StructureND<Float>> {
|
||||
public override val elementContext: FloatField get() = FloatField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
||||
|
||||
@ -270,6 +291,23 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
|
||||
|
||||
public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
|
||||
arg.ndArray.rsub(this).wrap()
|
||||
|
||||
override fun sin(arg: StructureND<Float>): StructureND<Float> = Sin(arg.ndArray).z().wrap()
|
||||
|
||||
override fun cos(arg: StructureND<Float>): StructureND<Float> = Cos(arg.ndArray).z().wrap()
|
||||
|
||||
override fun asin(arg: StructureND<Float>): StructureND<Float> = ASin(arg.ndArray).z().wrap()
|
||||
|
||||
override fun acos(arg: StructureND<Float>): StructureND<Float> = ACos(arg.ndArray).z().wrap()
|
||||
|
||||
override fun atan(arg: StructureND<Float>): StructureND<Float> = ATan(arg.ndArray).z().wrap()
|
||||
|
||||
override fun power(arg: StructureND<Float>, pow: Number): StructureND<Float> =
|
||||
Pow(arg.ndArray, pow.toDouble()).z().wrap()
|
||||
|
||||
override fun exp(arg: StructureND<Float>): StructureND<Float> = Exp(arg.ndArray).z().wrap()
|
||||
|
||||
override fun ln(arg: StructureND<Float>): StructureND<Float> = Log(arg.ndArray).z().wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -6,8 +6,12 @@
|
||||
package space.kscience.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.math.PI
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
import kotlin.test.fail
|
||||
|
||||
internal class Nd4jArrayAlgebraTest {
|
||||
@ -43,4 +47,14 @@ internal class Nd4jArrayAlgebraTest {
|
||||
expected[intArrayOf(1, 1)] = 26
|
||||
assertEquals(expected, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
|
||||
val initial = produce { (i, j) -> if (i == j) PI/2 else 0.0 }
|
||||
val transformed = sin(initial)
|
||||
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
|
||||
|
||||
println(transformed)
|
||||
assertTrue { StructureND.contentEquals(transformed, expected) }
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user