hotfix sequence
This commit is contained in:
parent
0365d41f31
commit
078686a046
@ -34,7 +34,8 @@ public open class BufferedTensor<T>(
|
||||
|
||||
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
|
||||
check(shape.size >= 1) {"todo"}
|
||||
val vectorOffset = linearStructure.strides[0]
|
||||
val n = shape.size
|
||||
val vectorOffset = shape[n - 1]
|
||||
val vectorShape = intArrayOf(shape.last())
|
||||
for (offset in 0 until numel step vectorOffset) {
|
||||
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
|
||||
@ -44,8 +45,9 @@ public open class BufferedTensor<T>(
|
||||
|
||||
public fun matrixSequence(): Sequence<MutableStructure2D<T>> = sequence {
|
||||
check(shape.size >= 2) {"todo"}
|
||||
val matrixOffset = linearStructure.strides[1]
|
||||
val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way?
|
||||
val n = shape.size
|
||||
val matrixOffset = shape[n - 1] * shape[n - 2]
|
||||
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) //todo better way?
|
||||
for (offset in 0 until numel step matrixOffset) {
|
||||
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
|
||||
yield(matrix)
|
||||
|
Loading…
Reference in New Issue
Block a user