Minor refactor

This commit is contained in:
Iaroslav 2020-06-29 22:30:08 +07:00
parent f54e5679cf
commit bf071bcdc1
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
2 changed files with 10 additions and 8 deletions

View File

@ -50,7 +50,7 @@ array[intArrayOf(0, 0)] = 24.0
println(array[0, 0]) // 24.0 println(array[0, 0]) // 24.0
``` ```
Fast element-wise arithmetics for INDArray: Fast element-wise and in-place arithmetics for INDArray:
```kotlin ```kotlin
import org.nd4j.linalg.factory.* import org.nd4j.linalg.factory.*

View File

@ -9,24 +9,22 @@ import scientifik.kmath.structures.NDRing
interface INDArrayRing<T, R, N> : interface INDArrayRing<T, R, N> :
NDRing<T, R, N> where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> { NDRing<T, R, N> where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
fun INDArray.wrap(): N
override val zero: N override val zero: N
get() = Nd4j.zeros(*shape).wrap() get() = Nd4j.zeros(*shape).wrap()
override val one: N override val one: N
get() = Nd4j.ones(*shape).wrap() get() = Nd4j.ones(*shape).wrap()
fun INDArray.wrap(): N
override fun produce(initializer: R.(IntArray) -> T): N { override fun produce(initializer: R.(IntArray) -> T): N {
val struct = Nd4j.create(*shape).wrap() val struct = Nd4j.create(*shape)!!.wrap()
struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) } struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) }
return struct return struct
} }
override fun map(arg: N, transform: R.(T) -> T): N { override fun map(arg: N, transform: R.(T) -> T): N {
val new = Nd4j.create(*shape) val newStruct = arg.ndArray.dup().wrap()
Nd4j.copy(arg.ndArray, new)
val newStruct = new.wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
return newStruct return newStruct
} }
@ -52,6 +50,7 @@ interface INDArrayRing<T, R, N> :
override fun N.minus(b: Number): N = ndArray.subi(b).wrap() override fun N.minus(b: Number): N = ndArray.subi(b).wrap()
override fun N.plus(b: Number): N = ndArray.addi(b).wrap() override fun N.plus(b: Number): N = ndArray.addi(b).wrap()
override fun N.times(k: Number): N = ndArray.muli(k).wrap() override fun N.times(k: Number): N = ndArray.muli(k).wrap()
override fun Number.minus(b: N): N = b.ndArray.rsubi(this).wrap()
} }
interface INDArrayField<T, F, N> : NDField<T, F, N>, interface INDArrayField<T, F, N> : NDField<T, F, N>,
@ -61,13 +60,14 @@ interface INDArrayField<T, F, N> : NDField<T, F, N>,
} }
class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) : class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) :
INDArrayField<Double, Field<Double>, INDArrayRealStructure> { INDArrayField<Double, Field<Double>, INDArrayRealStructure> {
override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure() override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure()
override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap() override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap()
override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap() override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap()
override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap() override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap()
override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap() override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap()
override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap() override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap()
override fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rsubi(this).wrap()
} }
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) : class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
@ -78,6 +78,7 @@ class FloatINDArrayField(override val shape: IntArray, override val elementConte
override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap() override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap()
override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap() override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap()
override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap() override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap()
override fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rsubi(this).wrap()
} }
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) : class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
@ -86,4 +87,5 @@ class IntINDArrayRing(override val shape: IntArray, override val elementContext:
override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap() override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap()
override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap() override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap()
override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap() override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap()
override fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure = arg.ndArray.rsubi(this).wrap()
} }