forked from kscience/kmath
Make ND4J float algebra extended
This commit is contained in:
parent
477f75270c
commit
9c353f4a0d
@ -10,6 +10,7 @@
|
|||||||
- Blocking chains and Statistics
|
- Blocking chains and Statistics
|
||||||
- Multiplatform integration
|
- Multiplatform integration
|
||||||
- Integration for any Field element
|
- Integration for any Field element
|
||||||
|
- Extendend operations for ND4J fields
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Exponential operations merged with hyperbolic functions
|
- Exponential operations merged with hyperbolic functions
|
||||||
|
@ -6,7 +6,10 @@
|
|||||||
package space.kscience.kmath.nd4j
|
package space.kscience.kmath.nd4j
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import org.nd4j.linalg.api.ops.impl.scalar.Pow
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.*
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import org.nd4j.linalg.ops.transforms.Transforms
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
@ -207,7 +210,8 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
|
|||||||
/**
|
/**
|
||||||
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
|
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
|
||||||
*/
|
*/
|
||||||
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField> {
|
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField>,
|
||||||
|
ExtendedField<StructureND<Double>> {
|
||||||
public override val elementContext: DoubleField get() = DoubleField
|
public override val elementContext: DoubleField get() = DoubleField
|
||||||
|
|
||||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
|
||||||
@ -239,14 +243,31 @@ public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArr
|
|||||||
public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
||||||
return arg.ndArray.rsub(this).wrap()
|
return arg.ndArray.rsub(this).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: StructureND<Double>): StructureND<Double> = Transforms.sin(arg.ndArray).wrap()
|
||||||
|
|
||||||
|
override fun cos(arg: StructureND<Double>): StructureND<Double> = Transforms.cos(arg.ndArray).wrap()
|
||||||
|
|
||||||
|
override fun asin(arg: StructureND<Double>): StructureND<Double> = Transforms.asin(arg.ndArray).wrap()
|
||||||
|
|
||||||
|
override fun acos(arg: StructureND<Double>): StructureND<Double> = Transforms.acos(arg.ndArray).wrap()
|
||||||
|
|
||||||
|
override fun atan(arg: StructureND<Double>): StructureND<Double> = Transforms.atan(arg.ndArray).wrap()
|
||||||
|
|
||||||
|
override fun power(arg: StructureND<Double>, pow: Number): StructureND<Double> =
|
||||||
|
Transforms.pow(arg.ndArray,pow).wrap()
|
||||||
|
|
||||||
|
override fun exp(arg: StructureND<Double>): StructureND<Double> = Transforms.exp(arg.ndArray).wrap()
|
||||||
|
|
||||||
|
override fun ln(arg: StructureND<Double>): StructureND<Double> = Transforms.log(arg.ndArray).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
|
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
|
||||||
*/
|
*/
|
||||||
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> {
|
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField>,
|
||||||
public override val elementContext: FloatField
|
ExtendedField<StructureND<Float>> {
|
||||||
get() = FloatField
|
public override val elementContext: FloatField get() = FloatField
|
||||||
|
|
||||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
||||||
|
|
||||||
@ -270,6 +291,23 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
|
|||||||
|
|
||||||
public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
|
public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
|
||||||
arg.ndArray.rsub(this).wrap()
|
arg.ndArray.rsub(this).wrap()
|
||||||
|
|
||||||
|
override fun sin(arg: StructureND<Float>): StructureND<Float> = Sin(arg.ndArray).z().wrap()
|
||||||
|
|
||||||
|
override fun cos(arg: StructureND<Float>): StructureND<Float> = Cos(arg.ndArray).z().wrap()
|
||||||
|
|
||||||
|
override fun asin(arg: StructureND<Float>): StructureND<Float> = ASin(arg.ndArray).z().wrap()
|
||||||
|
|
||||||
|
override fun acos(arg: StructureND<Float>): StructureND<Float> = ACos(arg.ndArray).z().wrap()
|
||||||
|
|
||||||
|
override fun atan(arg: StructureND<Float>): StructureND<Float> = ATan(arg.ndArray).z().wrap()
|
||||||
|
|
||||||
|
override fun power(arg: StructureND<Float>, pow: Number): StructureND<Float> =
|
||||||
|
Pow(arg.ndArray, pow.toDouble()).z().wrap()
|
||||||
|
|
||||||
|
override fun exp(arg: StructureND<Float>): StructureND<Float> = Exp(arg.ndArray).z().wrap()
|
||||||
|
|
||||||
|
override fun ln(arg: StructureND<Float>): StructureND<Float> = Log(arg.ndArray).z().wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -6,8 +6,12 @@
|
|||||||
package space.kscience.kmath.nd4j
|
package space.kscience.kmath.nd4j
|
||||||
|
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import kotlin.math.PI
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
import kotlin.test.fail
|
import kotlin.test.fail
|
||||||
|
|
||||||
internal class Nd4jArrayAlgebraTest {
|
internal class Nd4jArrayAlgebraTest {
|
||||||
@ -43,4 +47,14 @@ internal class Nd4jArrayAlgebraTest {
|
|||||||
expected[intArrayOf(1, 1)] = 26
|
expected[intArrayOf(1, 1)] = 26
|
||||||
assertEquals(expected, res)
|
assertEquals(expected, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
|
||||||
|
val initial = produce { (i, j) -> if (i == j) PI/2 else 0.0 }
|
||||||
|
val transformed = sin(initial)
|
||||||
|
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
|
||||||
|
|
||||||
|
println(transformed)
|
||||||
|
assertTrue { StructureND.contentEquals(transformed, expected) }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user