RealNDField
This commit is contained in:
parent
20a7c6c4f1
commit
be18014d54
@ -0,0 +1,20 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
//
|
||||||
|
//class LazyStructureField<T: Any>(val field: Field<T>): Field<LazyStructure<T>>{
|
||||||
|
//
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//class LazyStructure<T : Any> : NDStructure<T> {
|
||||||
|
//
|
||||||
|
// override val shape: IntArray
|
||||||
|
// get() = TODO("not implemented") //To change initializer of created properties use File | Settings | File Templates.
|
||||||
|
//
|
||||||
|
// override fun get(index: IntArray): T {
|
||||||
|
// TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// override fun iterator(): Iterator<Pair<IntArray, T>> {
|
||||||
|
// TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||||
|
// }
|
||||||
|
//}
|
@ -1,6 +1,5 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.DoubleField
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.FieldElement
|
import scientifik.kmath.operations.FieldElement
|
||||||
|
|
||||||
@ -107,14 +106,15 @@ abstract class NDField<T>(val shape: IntArray, val field: Field<T>) : Field<NDAr
|
|||||||
/**
|
/**
|
||||||
* Immutable [NDStructure] coupled to the context. Emulates Python ndarray
|
* Immutable [NDStructure] coupled to the context. Emulates Python ndarray
|
||||||
*/
|
*/
|
||||||
data class NDArray<T>(override val context: NDField<T>, private val structure: NDStructure<T>) : FieldElement<NDArray<T>, NDField<T>>, NDStructure<T> by structure {
|
class NDArray<T>(override val context: NDField<T>, private val structure: NDStructure<T>) : FieldElement<NDArray<T>, NDField<T>>, NDStructure<T> by structure {
|
||||||
|
|
||||||
//TODO ensure structure is immutable
|
//TODO ensure structure is immutable
|
||||||
|
|
||||||
override val self: NDArray<T>
|
override val self: NDArray<T>
|
||||||
get() = this
|
get() = this
|
||||||
|
|
||||||
fun transform(action: (IntArray, T) -> T): NDArray<T> = context.produce { action(it, get(*it)) }
|
inline fun transform(crossinline action: (IntArray, T) -> T): NDArray<T> = context.produce { action(it, get(*it)) }
|
||||||
|
inline fun transform(crossinline action: (T) -> T): NDArray<T> = context.produce { action(get(*it)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -173,7 +173,7 @@ object NDArrays {
|
|||||||
* Create a platform-optimized NDArray of doubles
|
* Create a platform-optimized NDArray of doubles
|
||||||
*/
|
*/
|
||||||
fun realNDArray(shape: IntArray, initializer: (IntArray) -> Double = { 0.0 }): NDArray<Double> {
|
fun realNDArray(shape: IntArray, initializer: (IntArray) -> Double = { 0.0 }): NDArray<Double> {
|
||||||
return GenericNDField(shape, DoubleField).produce(initializer)
|
return RealNDField(shape).produce(initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun real1DArray(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDArray<Double> {
|
fun real1DArray(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDArray<Double> {
|
||||||
@ -188,6 +188,8 @@ object NDArrays {
|
|||||||
return realNDArray(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
return realNDArray(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline fun produceReal(shape: IntArray, block: RealNDField.() -> RealNDArray) = RealNDField(shape).run(block)
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
// * Simple boxing NDField
|
// * Simple boxing NDField
|
||||||
// */
|
// */
|
||||||
|
@ -0,0 +1,46 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.DoubleField
|
||||||
|
import scientifik.kmath.operations.ExponentialOperations
|
||||||
|
import scientifik.kmath.operations.PowerOperations
|
||||||
|
import scientifik.kmath.operations.TrigonometricOperations
|
||||||
|
import kotlin.math.*
|
||||||
|
|
||||||
|
typealias RealNDArray = NDArray<Double>
|
||||||
|
|
||||||
|
|
||||||
|
class RealNDField(shape: IntArray) : NDField<Double>(shape, DoubleField),
|
||||||
|
TrigonometricOperations<RealNDArray>,
|
||||||
|
PowerOperations<RealNDArray>,
|
||||||
|
ExponentialOperations<RealNDArray> {
|
||||||
|
|
||||||
|
override fun produceStructure(initializer: (IntArray) -> Double): NDStructure<Double> {
|
||||||
|
return genericNdStructure(shape, initializer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: RealNDArray, pow: Double): RealNDArray {
|
||||||
|
return arg.transform { d -> d.pow(pow) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun exp(arg: RealNDArray): RealNDArray {
|
||||||
|
return arg.transform { d -> exp(d) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun ln(arg: RealNDArray): RealNDArray {
|
||||||
|
return arg.transform { d -> ln(d) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: RealNDArray): RealNDArray {
|
||||||
|
return arg.transform { d -> sin(d) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cos(arg: RealNDArray): RealNDArray {
|
||||||
|
return arg.transform { d -> cos(d) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fun abs(arg: RealNDArray): RealNDArray {
|
||||||
|
return arg.transform { d -> abs(d) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -39,4 +39,15 @@ class FieldExpressionContextTest {
|
|||||||
val expression = FieldExpressionContext(DoubleField).expression()
|
val expression = FieldExpressionContext(DoubleField).expression()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression("x" to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun valueExpression() {
|
||||||
|
val expressionBuilder: FieldExpressionContext<Double>.()->Expression<Double> = {
|
||||||
|
val x = variable("x")
|
||||||
|
x * x + 2 * x + 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
val expression = FieldExpressionContext(DoubleField).expressionBuilder()
|
||||||
|
assertEquals(expression("x" to 1.0), 4.0)
|
||||||
|
}
|
||||||
}
|
}
|
@ -1,6 +1,8 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.NDArrays.produceReal
|
||||||
import scientifik.kmath.structures.NDArrays.real2DArray
|
import scientifik.kmath.structures.NDArrays.real2DArray
|
||||||
|
import kotlin.math.abs
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
@ -40,4 +42,18 @@ class RealNDFieldTest {
|
|||||||
val result = function(array1) + 1.0
|
val result = function(array1) + 1.0
|
||||||
assertEquals(10.0, result[1,1])
|
assertEquals(10.0, result[1,1])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLibraryFunction() {
|
||||||
|
val abs: (Double) -> Double = ::abs
|
||||||
|
val result = abs(array1)
|
||||||
|
assertEquals(10.0, result[1,1])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAbs(){
|
||||||
|
val res = produceReal(array1.shape){
|
||||||
|
1 + abs(array1) + exp(array2)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user