RealNDField
This commit is contained in:
parent
20a7c6c4f1
commit
be18014d54
@ -0,0 +1,20 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
//
|
||||
//class LazyStructureField<T: Any>(val field: Field<T>): Field<LazyStructure<T>>{
|
||||
//
|
||||
//}
|
||||
//
|
||||
//class LazyStructure<T : Any> : NDStructure<T> {
|
||||
//
|
||||
// override val shape: IntArray
|
||||
// get() = TODO("not implemented") //To change initializer of created properties use File | Settings | File Templates.
|
||||
//
|
||||
// override fun get(index: IntArray): T {
|
||||
// TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||
// }
|
||||
//
|
||||
// override fun iterator(): Iterator<Pair<IntArray, T>> {
|
||||
// TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||
// }
|
||||
//}
|
@ -1,6 +1,5 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.DoubleField
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
|
||||
@ -107,14 +106,15 @@ abstract class NDField<T>(val shape: IntArray, val field: Field<T>) : Field<NDAr
|
||||
/**
|
||||
* Immutable [NDStructure] coupled to the context. Emulates Python ndarray
|
||||
*/
|
||||
data class NDArray<T>(override val context: NDField<T>, private val structure: NDStructure<T>) : FieldElement<NDArray<T>, NDField<T>>, NDStructure<T> by structure {
|
||||
class NDArray<T>(override val context: NDField<T>, private val structure: NDStructure<T>) : FieldElement<NDArray<T>, NDField<T>>, NDStructure<T> by structure {
|
||||
|
||||
//TODO ensure structure is immutable
|
||||
|
||||
override val self: NDArray<T>
|
||||
get() = this
|
||||
|
||||
fun transform(action: (IntArray, T) -> T): NDArray<T> = context.produce { action(it, get(*it)) }
|
||||
inline fun transform(crossinline action: (IntArray, T) -> T): NDArray<T> = context.produce { action(it, get(*it)) }
|
||||
inline fun transform(crossinline action: (T) -> T): NDArray<T> = context.produce { action(get(*it)) }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -173,7 +173,7 @@ object NDArrays {
|
||||
* Create a platform-optimized NDArray of doubles
|
||||
*/
|
||||
fun realNDArray(shape: IntArray, initializer: (IntArray) -> Double = { 0.0 }): NDArray<Double> {
|
||||
return GenericNDField(shape, DoubleField).produce(initializer)
|
||||
return RealNDField(shape).produce(initializer)
|
||||
}
|
||||
|
||||
fun real1DArray(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDArray<Double> {
|
||||
@ -188,6 +188,8 @@ object NDArrays {
|
||||
return realNDArray(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
}
|
||||
|
||||
inline fun produceReal(shape: IntArray, block: RealNDField.() -> RealNDArray) = RealNDField(shape).run(block)
|
||||
|
||||
// /**
|
||||
// * Simple boxing NDField
|
||||
// */
|
||||
|
@ -0,0 +1,46 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.DoubleField
|
||||
import scientifik.kmath.operations.ExponentialOperations
|
||||
import scientifik.kmath.operations.PowerOperations
|
||||
import scientifik.kmath.operations.TrigonometricOperations
|
||||
import kotlin.math.*
|
||||
|
||||
typealias RealNDArray = NDArray<Double>
|
||||
|
||||
|
||||
class RealNDField(shape: IntArray) : NDField<Double>(shape, DoubleField),
|
||||
TrigonometricOperations<RealNDArray>,
|
||||
PowerOperations<RealNDArray>,
|
||||
ExponentialOperations<RealNDArray> {
|
||||
|
||||
override fun produceStructure(initializer: (IntArray) -> Double): NDStructure<Double> {
|
||||
return genericNdStructure(shape, initializer)
|
||||
}
|
||||
|
||||
override fun power(arg: RealNDArray, pow: Double): RealNDArray {
|
||||
return arg.transform { d -> d.pow(pow) }
|
||||
}
|
||||
|
||||
override fun exp(arg: RealNDArray): RealNDArray {
|
||||
return arg.transform { d -> exp(d) }
|
||||
}
|
||||
|
||||
override fun ln(arg: RealNDArray): RealNDArray {
|
||||
return arg.transform { d -> ln(d) }
|
||||
}
|
||||
|
||||
override fun sin(arg: RealNDArray): RealNDArray {
|
||||
return arg.transform { d -> sin(d) }
|
||||
}
|
||||
|
||||
override fun cos(arg: RealNDArray): RealNDArray {
|
||||
return arg.transform { d -> cos(d) }
|
||||
}
|
||||
|
||||
fun abs(arg: RealNDArray): RealNDArray {
|
||||
return arg.transform { d -> abs(d) }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,4 +39,15 @@ class FieldExpressionContextTest {
|
||||
val expression = FieldExpressionContext(DoubleField).expression()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun valueExpression() {
|
||||
val expressionBuilder: FieldExpressionContext<Double>.()->Expression<Double> = {
|
||||
val x = variable("x")
|
||||
x * x + 2 * x + 1.0
|
||||
}
|
||||
|
||||
val expression = FieldExpressionContext(DoubleField).expressionBuilder()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
}
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.structures.NDArrays.produceReal
|
||||
import scientifik.kmath.structures.NDArrays.real2DArray
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
@ -40,4 +42,18 @@ class RealNDFieldTest {
|
||||
val result = function(array1) + 1.0
|
||||
assertEquals(10.0, result[1,1])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLibraryFunction() {
|
||||
val abs: (Double) -> Double = ::abs
|
||||
val result = abs(array1)
|
||||
assertEquals(10.0, result[1,1])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAbs(){
|
||||
val res = produceReal(array1.shape){
|
||||
1 + abs(array1) + exp(array2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user