Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116
@ -14,8 +14,9 @@ class BoxingNDField<T, F : Field<T>>(
|
||||
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||
bufferFactory(size, initializer)
|
||||
|
||||
override fun check(vararg elements: NDBuffer<T>) {
|
||||
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||
return elements
|
||||
}
|
||||
|
||||
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
||||
|
@ -14,8 +14,9 @@ class BoxingNDRing<T, R : Ring<T>>(
|
||||
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||
bufferFactory(size, initializer)
|
||||
|
||||
override fun check(vararg elements: NDBuffer<T>) {
|
||||
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||
return elements
|
||||
}
|
||||
|
||||
override val zero: BufferedNDRingElement<T, R> by lazy { produce { zero } }
|
||||
|
@ -5,8 +5,9 @@ import scientifik.kmath.operations.*
|
||||
interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||
val strides: Strides
|
||||
|
||||
override fun check(vararg elements: NDBuffer<T>) {
|
||||
override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||
if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
|
||||
return elements
|
||||
}
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user