forked from kscience/kmath
Optimize reverse division for FP INDArrayAlgebra
This commit is contained in:
parent
d7949fdb01
commit
8a8b314d0a
@ -7,8 +7,8 @@ import scientifik.kmath.structures.MutableNDStructure
|
|||||||
import scientifik.kmath.structures.NDField
|
import scientifik.kmath.structures.NDField
|
||||||
import scientifik.kmath.structures.NDRing
|
import scientifik.kmath.structures.NDRing
|
||||||
|
|
||||||
interface INDArrayRing<T, F, N> :
|
interface INDArrayRing<T, R, N> :
|
||||||
NDRing<T, F, N> where F : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
NDRing<T, R, N> where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||||
fun INDArray.wrap(): N
|
fun INDArray.wrap(): N
|
||||||
|
|
||||||
override val zero: N
|
override val zero: N
|
||||||
@ -17,13 +17,13 @@ interface INDArrayRing<T, F, N> :
|
|||||||
override val one: N
|
override val one: N
|
||||||
get() = Nd4j.ones(*shape).wrap()
|
get() = Nd4j.ones(*shape).wrap()
|
||||||
|
|
||||||
override fun produce(initializer: F.(IntArray) -> T): N {
|
override fun produce(initializer: R.(IntArray) -> T): N {
|
||||||
val struct = Nd4j.create(*shape).wrap()
|
val struct = Nd4j.create(*shape).wrap()
|
||||||
struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) }
|
struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) }
|
||||||
return struct
|
return struct
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun map(arg: N, transform: F.(T) -> T): N {
|
override fun map(arg: N, transform: R.(T) -> T): N {
|
||||||
val new = Nd4j.create(*shape)
|
val new = Nd4j.create(*shape)
|
||||||
Nd4j.copy(arg.ndArray, new)
|
Nd4j.copy(arg.ndArray, new)
|
||||||
val newStruct = new.wrap()
|
val newStruct = new.wrap()
|
||||||
@ -31,13 +31,13 @@ interface INDArrayRing<T, F, N> :
|
|||||||
return newStruct
|
return newStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N {
|
override fun mapIndexed(arg: N, transform: R.(index: IntArray, T) -> T): N {
|
||||||
val new = Nd4j.create(*shape).wrap()
|
val new = Nd4j.create(*shape).wrap()
|
||||||
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
||||||
return new
|
return new
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun combine(a: N, b: N, transform: F.(T, T) -> T): N {
|
override fun combine(a: N, b: N, transform: R.(T, T) -> T): N {
|
||||||
val new = Nd4j.create(*shape).wrap()
|
val new = Nd4j.create(*shape).wrap()
|
||||||
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||||
return new
|
return new
|
||||||
@ -66,6 +66,7 @@ class RealINDArrayField(override val shape: IntArray, override val elementContex
|
|||||||
override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap()
|
override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap()
|
||||||
override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap()
|
override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap()
|
||||||
override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap()
|
override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap()
|
||||||
|
override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
|
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
|
||||||
@ -75,6 +76,7 @@ class FloatINDArrayField(override val shape: IntArray, override val elementConte
|
|||||||
override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap()
|
override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap()
|
||||||
override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap()
|
override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap()
|
||||||
override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap()
|
override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap()
|
||||||
|
override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
|
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
|
||||||
|
Loading…
Reference in New Issue
Block a user