Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116
@ -14,8 +14,9 @@ class BoxingNDField<T, F : Field<T>>(
|
|||||||
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
bufferFactory(size, initializer)
|
bufferFactory(size, initializer)
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||||
|
return elements
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
||||||
|
@ -14,8 +14,9 @@ class BoxingNDRing<T, R : Ring<T>>(
|
|||||||
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
bufferFactory(size, initializer)
|
bufferFactory(size, initializer)
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||||
|
return elements
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
|
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
|
||||||
|
@ -5,8 +5,9 @@ import scientifik.kmath.operations.*
|
|||||||
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||||
val strides: Strides
|
val strides: Strides
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
|
if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
|
||||||
|
return elements
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Loading…
Reference in New Issue
Block a user