Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116

Merged
CommanderTvis merged 50 commits from nd4j into dev 2020-10-29 19:58:53 +03:00
3 changed files with 6 additions and 3 deletions
Showing only changes of commit 2bc62356d6 - Show all commits

View File

@ -14,8 +14,9 @@ class BoxingNDField<T, F : Field<T>>(
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer) bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
return elements
} }
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } } override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }

View File

@ -14,8 +14,9 @@ class BoxingNDRing<T, R : Ring<T>>(
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer) bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
return elements
} }
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } } override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }

View File

@ -5,8 +5,9 @@ import scientifik.kmath.operations.*
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> { interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
val strides: Strides val strides: Strides
override fun check(vararg elements: NDBuffer<T>) { override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
if (!elements.all { it.strides == this.strides }) error("Strides mismatch") if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
return elements
} }
/** /**