[WIP] Refactor NDStructures

This commit is contained in:
Alexander Nozik 2021-01-23 21:50:48 +03:00
parent 332c04b573
commit 0baec14059
2 changed files with 3 additions and 5 deletions

View File

@ -50,7 +50,7 @@ public interface BufferNDAlgebra<T, C> : NDAlgebra<T, C> {
val aBuffer = a.ndBuffer val aBuffer = a.ndBuffer
val bBuffer = b.ndBuffer val bBuffer = b.ndBuffer
val buffer = bufferFactory(strides.linearSize) { offset -> val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(aBuffer.buffer[offset], bBuffer[offset]) elementContext.transform(aBuffer.buffer[offset], bBuffer.buffer[offset])
} }
return NDBuffer(strides, buffer) return NDBuffer(strides, buffer)
} }

View File

@ -217,9 +217,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
} }
override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i] value * strides[i]
}.sum() }.sum()
@ -332,7 +330,7 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
public class MutableNDBuffer<T>( public class MutableNDBuffer<T>(
strides: Strides, strides: Strides,
buffer: MutableBuffer<T>, buffer: MutableBuffer<T>,
) : NDBuffer<T>(strides,buffer), MutableNDStructure<T> { ) : NDBuffer<T>(strides, buffer), MutableNDStructure<T> {
init { init {
require(strides.linearSize == buffer.size) { require(strides.linearSize == buffer.size) {