hotfix sequence
This commit is contained in:
parent
0365d41f31
commit
078686a046
@ -34,7 +34,8 @@ public open class BufferedTensor<T>(
|
|||||||
|
|
||||||
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
|
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
|
||||||
check(shape.size >= 1) {"todo"}
|
check(shape.size >= 1) {"todo"}
|
||||||
val vectorOffset = linearStructure.strides[0]
|
val n = shape.size
|
||||||
|
val vectorOffset = shape[n - 1]
|
||||||
val vectorShape = intArrayOf(shape.last())
|
val vectorShape = intArrayOf(shape.last())
|
||||||
for (offset in 0 until numel step vectorOffset) {
|
for (offset in 0 until numel step vectorOffset) {
|
||||||
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
|
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
|
||||||
@ -44,8 +45,9 @@ public open class BufferedTensor<T>(
|
|||||||
|
|
||||||
public fun matrixSequence(): Sequence<MutableStructure2D<T>> = sequence {
|
public fun matrixSequence(): Sequence<MutableStructure2D<T>> = sequence {
|
||||||
check(shape.size >= 2) {"todo"}
|
check(shape.size >= 2) {"todo"}
|
||||||
val matrixOffset = linearStructure.strides[1]
|
val n = shape.size
|
||||||
val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way?
|
val matrixOffset = shape[n - 1] * shape[n - 2]
|
||||||
|
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) //todo better way?
|
||||||
for (offset in 0 until numel step matrixOffset) {
|
for (offset in 0 until numel step matrixOffset) {
|
||||||
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
|
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
|
||||||
yield(matrix)
|
yield(matrix)
|
||||||
|
Loading…
Reference in New Issue
Block a user