Minor corrections
This commit is contained in:
parent
f0cdb9b657
commit
8a039326d4
@ -792,7 +792,6 @@ public final class space/kscience/kmath/nd/DefaultStrides : space/kscience/kmath
|
|||||||
public fun getStrides ()[I
|
public fun getStrides ()[I
|
||||||
public fun hashCode ()I
|
public fun hashCode ()I
|
||||||
public fun index (I)[I
|
public fun index (I)[I
|
||||||
public fun offset ([I)I
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public final class space/kscience/kmath/nd/DefaultStrides$Companion {
|
public final class space/kscience/kmath/nd/DefaultStrides$Companion {
|
||||||
@ -934,7 +933,7 @@ public abstract interface class space/kscience/kmath/nd/Strides {
|
|||||||
public abstract fun getStrides ()[I
|
public abstract fun getStrides ()[I
|
||||||
public abstract fun index (I)[I
|
public abstract fun index (I)[I
|
||||||
public fun indices ()Lkotlin/sequences/Sequence;
|
public fun indices ()Lkotlin/sequences/Sequence;
|
||||||
public abstract fun offset ([I)I
|
public fun offset ([I)I
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/nd/Structure1D : space/kscience/kmath/nd/StructureND, space/kscience/kmath/structures/Buffer {
|
public abstract interface class space/kscience/kmath/nd/Structure1D : space/kscience/kmath/nd/StructureND, space/kscience/kmath/structures/Buffer {
|
||||||
|
@ -189,7 +189,10 @@ public interface Strides {
|
|||||||
/**
|
/**
|
||||||
* Get linear index from multidimensional index
|
* Get linear index from multidimensional index
|
||||||
*/
|
*/
|
||||||
public fun offset(index: IntArray): Int
|
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||||
|
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||||
|
value * strides[i]
|
||||||
|
}.sum()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get multidimensional from linear
|
* Get multidimensional from linear
|
||||||
@ -233,11 +236,6 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
}.toList().toIntArray()
|
}.toList().toIntArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
|
||||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
|
||||||
value * strides[i]
|
|
||||||
}.sum()
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray {
|
override fun index(offset: Int): IntArray {
|
||||||
val res = IntArray(shape.size)
|
val res = IntArray(shape.size)
|
||||||
var current = offset
|
var current = offset
|
||||||
|
@ -68,8 +68,6 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides
|
|||||||
override val strides: IntArray
|
override val strides: IntArray
|
||||||
get() = stridesFromShape(shape)
|
get() = stridesFromShape(shape)
|
||||||
|
|
||||||
override fun offset(index: IntArray): Int = offsetFromIndex(index, shape, strides)
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray =
|
override fun index(offset: Int): IntArray =
|
||||||
indexFromOffset(offset, strides, shape.size)
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
@ -82,7 +80,4 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides
|
|||||||
public val dim: Int
|
public val dim: Int
|
||||||
get() = shape.size
|
get() = shape.size
|
||||||
|
|
||||||
override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map {
|
|
||||||
index(it)
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -75,14 +75,13 @@ internal fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor>
|
|||||||
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||||
val n = totalShape.reduce { acc, i -> acc * i }
|
val n = totalShape.reduce { acc, i -> acc * i }
|
||||||
|
|
||||||
val res = ArrayList<DoubleTensor>(0)
|
return buildList {
|
||||||
for (tensor in tensors) {
|
for (tensor in tensors) {
|
||||||
val resTensor = DoubleTensor(totalShape, DoubleArray(n))
|
val resTensor = DoubleTensor(totalShape, DoubleArray(n))
|
||||||
multiIndexBroadCasting(tensor, resTensor, n)
|
multiIndexBroadCasting(tensor, resTensor, n)
|
||||||
res.add(resTensor)
|
add(resTensor)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
||||||
|
@ -170,7 +170,6 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor3 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 5.0))
|
val tensor3 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 5.0))
|
||||||
val tensor4 = fromArray(intArrayOf(6, 1), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
|
|
||||||
assertTrue(tensor1 eq tensor1)
|
assertTrue(tensor1 eq tensor1)
|
||||||
assertTrue(tensor1 eq tensor2)
|
assertTrue(tensor1 eq tensor2)
|
||||||
|
Loading…
Reference in New Issue
Block a user