Merge pull request 'Fix #532 by making ShapeND
a non-value class' (!522) from bug/defaultStridesCache into dev
Reviewed-on: #522 Reviewed-by: Alexander Nozik <altavir@gmail.com>
This commit is contained in:
commit
1881feb5e2
@ -35,7 +35,7 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND<Double, Float64
|
|||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
private val StructureND<Double>.buffer: Float64Buffer
|
private val StructureND<Double>.buffer: Float64Buffer
|
||||||
get() = when {
|
get() = when {
|
||||||
!shape.contentEquals(this@StreamDoubleFieldND.shape) -> throw ShapeMismatchException(
|
shape != this@StreamDoubleFieldND.shape -> throw ShapeMismatchException(
|
||||||
this@StreamDoubleFieldND.shape,
|
this@StreamDoubleFieldND.shape,
|
||||||
shape
|
shape
|
||||||
)
|
)
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
import space.kscience.kmath.nd.contentEquals
|
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||||
@ -62,7 +61,7 @@ fun main() {
|
|||||||
// figure out MSE of approximation
|
// figure out MSE of approximation
|
||||||
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
|
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
|
||||||
require(yTrue.shape.size == 1)
|
require(yTrue.shape.size == 1)
|
||||||
require(yTrue.shape contentEquals yPred.shape)
|
require(yTrue.shape == yPred.shape)
|
||||||
|
|
||||||
val diff = yTrue - yPred
|
val diff = yTrue - yPred
|
||||||
return sqrt(diff.dot(diff)).value()
|
return sqrt(diff.dot(diff)).value()
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
import space.kscience.kmath.nd.contentEquals
|
|
||||||
import space.kscience.kmath.operations.asIterable
|
import space.kscience.kmath.operations.asIterable
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.tensors.core.*
|
import space.kscience.kmath.tensors.core.*
|
||||||
@ -94,7 +93,7 @@ class Dense(
|
|||||||
|
|
||||||
// simple accuracy equal to the proportion of correct answers
|
// simple accuracy equal to the proportion of correct answers
|
||||||
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
|
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
|
||||||
check(yPred.shape contentEquals yTrue.shape)
|
check(yPred.shape == yTrue.shape)
|
||||||
val n = yPred.shape[0]
|
val n = yPred.shape[0]
|
||||||
var correctCnt = 0
|
var correctCnt = 0
|
||||||
for (i in 0 until n) {
|
for (i in 0 until n) {
|
||||||
|
@ -32,12 +32,12 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||||
asND().plus(other.asND()).as2D()
|
asND().plus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||||
asND().minus(other.asND()).as2D()
|
asND().minus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,12 +33,12 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||||
asND().plus(other.asND()).as2D()
|
asND().plus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||||
asND().minus(other.asND()).as2D()
|
asND().minus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
|
|||||||
*/
|
*/
|
||||||
@PerformancePitfall("Very slow on remote execution algebras")
|
@PerformancePitfall("Very slow on remote execution algebras")
|
||||||
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> {
|
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> {
|
||||||
require(left.shape.contentEquals(right.shape)) {
|
require(left.shape == right.shape) {
|
||||||
"Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}"
|
"Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}"
|
||||||
}
|
}
|
||||||
return structureND(left.shape) { index ->
|
return structureND(left.shape) { index ->
|
||||||
|
@ -106,10 +106,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
|
|||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other !is ColumnStrides) return false
|
if (other !is ColumnStrides) return false
|
||||||
return shape.contentEquals(other.shape)
|
return shape == other.shape
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun hashCode(): Int = shape.contentHashCode()
|
override fun hashCode(): Int = shape.hashCode()
|
||||||
|
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
@ -156,10 +156,10 @@ public class RowStrides(override val shape: ShapeND) : Strides() {
|
|||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other !is RowStrides) return false
|
if (other !is RowStrides) return false
|
||||||
return shape.contentEquals(other.shape)
|
return shape == other.shape
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun hashCode(): Int = shape.contentHashCode()
|
override fun hashCode(): Int = shape.hashCode()
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
|
|
||||||
|
@ -11,19 +11,30 @@ import kotlin.jvm.JvmInline
|
|||||||
/**
|
/**
|
||||||
* A read-only ND shape
|
* A read-only ND shape
|
||||||
*/
|
*/
|
||||||
@JvmInline
|
public class ShapeND(@PublishedApi internal val array: IntArray) {
|
||||||
public value class ShapeND(@PublishedApi internal val array: IntArray) {
|
|
||||||
public val size: Int get() = array.size
|
public val size: Int get() = array.size
|
||||||
public operator fun get(index: Int): Int = array[index]
|
public operator fun get(index: Int): Int = array[index]
|
||||||
override fun toString(): String = array.contentToString()
|
override fun toString(): String = array.contentToString()
|
||||||
|
override fun hashCode(): Int = array.contentHashCode()
|
||||||
|
override fun equals(other: Any?): Boolean = other is ShapeND && array.contentEquals(other.array)
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block)
|
public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block)
|
||||||
|
|
||||||
public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block)
|
public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block)
|
||||||
|
|
||||||
|
@Deprecated(
|
||||||
|
message = "ShapeND is made a usual class with correct `equals`. Use it instead of the `contentEquals`.",
|
||||||
|
replaceWith = ReplaceWith("this == other"),
|
||||||
|
level = DeprecationLevel.WARNING,
|
||||||
|
)
|
||||||
public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array)
|
public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array)
|
||||||
|
|
||||||
|
@Deprecated(
|
||||||
|
message = "ShapeND is made a usual class with correct `hashCode`. Use it instead of the `contentHashCode`.",
|
||||||
|
replaceWith = ReplaceWith("this.hashCode()"),
|
||||||
|
level = DeprecationLevel.WARNING,
|
||||||
|
)
|
||||||
public fun ShapeND.contentHashCode(): Int = array.contentHashCode()
|
public fun ShapeND.contentHashCode(): Int = array.contentHashCode()
|
||||||
|
|
||||||
public val ShapeND.indices: IntRange get() = array.indices
|
public val ShapeND.indices: IntRange get() = array.indices
|
||||||
|
@ -48,12 +48,12 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||||
asND().plus(other.asND()).as2D()
|
asND().plus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||||
asND().minus(other.asND()).as2D()
|
asND().minus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
require(left.shape == right.shape) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||||
val leftArray = left.asMultik().array
|
val leftArray = left.asMultik().array
|
||||||
val rightArray = right.asMultik().array
|
val rightArray = right.asMultik().array
|
||||||
val data = initMemoryView<T>(leftArray.size, dataType)
|
val data = initMemoryView<T>(leftArray.size, dataType)
|
||||||
@ -124,7 +124,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
|
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) {
|
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1)) {
|
||||||
get(intArrayOf(0))
|
get(intArrayOf(0))
|
||||||
} else null
|
} else null
|
||||||
|
|
||||||
@ -224,7 +224,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
val mt = asMultik().array
|
val mt = asMultik().array
|
||||||
return if (ShapeND(mt.shape).contentEquals(shape)) {
|
return if (ShapeND(mt.shape) == shape) {
|
||||||
mt
|
mt
|
||||||
} else {
|
} else {
|
||||||
@OptIn(UnsafeKMathAPI::class)
|
@OptIn(UnsafeKMathAPI::class)
|
||||||
|
@ -63,7 +63,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
|||||||
right: StructureND<T>,
|
right: StructureND<T>,
|
||||||
transform: C.(T, T) -> T,
|
transform: C.(T, T) -> T,
|
||||||
): Nd4jArrayStructure<T> {
|
): Nd4jArrayStructure<T> {
|
||||||
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
require(left.shape == right.shape) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
||||||
val new = Nd4j.create(*left.shape.asArray()).wrap()
|
val new = Nd4j.create(*left.shape.asArray()).wrap()
|
||||||
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
|
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
|
||||||
return new
|
return new
|
||||||
|
@ -49,7 +49,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
|||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
|
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
|
||||||
require(left.shape.contentEquals(right.shape))
|
require(left.shape == right.shape)
|
||||||
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
|
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,7 +203,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, Float64Field>
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Double>.valueOrNull(): Double? =
|
override fun StructureND<Double>.valueOrNull(): Double? =
|
||||||
if (shape contentEquals ShapeND(1)) ndArray.getDouble(0) else null
|
if (shape == ShapeND(1)) ndArray.getDouble(0) else null
|
||||||
|
|
||||||
// TODO rewrite
|
// TODO rewrite
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
|
@ -23,7 +23,6 @@ import space.kscience.kmath.UnstableKMathAPI
|
|||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.asArray
|
import space.kscience.kmath.nd.asArray
|
||||||
import space.kscience.kmath.nd.contentEquals
|
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
@ -105,7 +104,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
|||||||
protected abstract fun const(value: T): Constant<TT>
|
protected abstract fun const(value: T): Constant<TT>
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1))
|
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1))
|
||||||
get(intArrayOf(0)) else null
|
get(intArrayOf(0)) else null
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -95,7 +95,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
|
|
||||||
override fun StructureND<Double>.valueOrNull(): Double? {
|
override fun StructureND<Double>.valueOrNull(): Double? {
|
||||||
val dt = asDoubleTensor()
|
val dt = asDoubleTensor()
|
||||||
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
|
return if (dt.shape == ShapeND(1)) dt.source[0] else null
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Double>.value(): Double = valueOrNull()
|
override fun StructureND<Double>.value(): Double = valueOrNull()
|
||||||
|
@ -89,7 +89,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
|
|||||||
|
|
||||||
override fun StructureND<Int>.valueOrNull(): Int? {
|
override fun StructureND<Int>.valueOrNull(): Int? {
|
||||||
val dt = asIntTensor()
|
val dt = asIntTensor()
|
||||||
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
|
return if (dt.shape == ShapeND(1)) dt.source[0] else null
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Int>.value(): Int = valueOrNull()
|
override fun StructureND<Int>.value(): Int = valueOrNull()
|
||||||
@ -387,7 +387,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
|
|||||||
public fun stack(tensors: List<Tensor<Int>>): IntTensor {
|
public fun stack(tensors: List<Tensor<Int>>): IntTensor {
|
||||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||||
val shape = tensors[0].shape
|
val shape = tensors[0].shape
|
||||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
|
||||||
val resShape = ShapeND(tensors.size) + shape
|
val resShape = ShapeND(tensors.size) + shape
|
||||||
// val resBuffer: List<Int> = tensors.flatMap {
|
// val resBuffer: List<Int> = tensors.flatMap {
|
||||||
// it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart)
|
// it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart)
|
||||||
|
@ -7,7 +7,6 @@ package space.kscience.kmath.tensors.core.internal
|
|||||||
|
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.contentEquals
|
|
||||||
import space.kscience.kmath.nd.linearSize
|
import space.kscience.kmath.nd.linearSize
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||||
@ -32,7 +31,7 @@ internal fun checkBufferShapeConsistency(shape: ShapeND, buffer: DoubleArray) =
|
|||||||
|
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit =
|
internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit =
|
||||||
check(a.shape contentEquals b.shape) {
|
check(a.shape == b.shape) {
|
||||||
"Incompatible shapes ${a.shape} and ${b.shape} "
|
"Incompatible shapes ${a.shape} and ${b.shape} "
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ public fun DoubleTensorAlgebra.randomNormalLike(structure: WithShape, seed: Long
|
|||||||
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
|
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
|
||||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||||
val shape = tensors[0].shape
|
val shape = tensors[0].shape
|
||||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
|
||||||
val resShape = ShapeND(tensors.size) + shape
|
val resShape = ShapeND(tensors.size) + shape
|
||||||
// val resBuffer: List<Double> = tensors.flatMap {
|
// val resBuffer: List<Double> = tensors.flatMap {
|
||||||
// it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart)
|
// it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart)
|
||||||
@ -91,7 +91,7 @@ public fun DoubleTensorAlgebra.luPivot(
|
|||||||
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
checkSquareMatrix(luTensor.shape)
|
checkSquareMatrix(luTensor.shape)
|
||||||
check(
|
check(
|
||||||
luTensor.shape.first(luTensor.shape.size - 2) contentEquals pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
|
luTensor.shape.first(luTensor.shape.size - 2) == pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
|
||||||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
|
luTensor.shape.last() == pivotsTensor.shape.last() - 1
|
||||||
) { "Inappropriate shapes of input tensors" }
|
) { "Inappropriate shapes of input tensors" }
|
||||||
|
|
||||||
|
@ -6,29 +6,31 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
import space.kscience.kmath.nd.contentEquals
|
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
|
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
|
||||||
import space.kscience.kmath.tensors.core.internal.broadcastShapes
|
import space.kscience.kmath.tensors.core.internal.broadcastShapes
|
||||||
import space.kscience.kmath.tensors.core.internal.broadcastTensors
|
import space.kscience.kmath.tensors.core.internal.broadcastTensors
|
||||||
import space.kscience.kmath.tensors.core.internal.broadcastTo
|
import space.kscience.kmath.tensors.core.internal.broadcastTo
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
internal class TestBroadcasting {
|
internal class TestBroadcasting {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBroadcastShapes() = DoubleTensorAlgebra {
|
fun testBroadcastShapes() = DoubleTensorAlgebra {
|
||||||
assertTrue(
|
assertEquals(
|
||||||
broadcastShapes(
|
broadcastShapes(
|
||||||
listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1))
|
listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1))
|
||||||
) contentEquals ShapeND(1, 2, 3)
|
),
|
||||||
|
ShapeND(1, 2, 3)
|
||||||
)
|
)
|
||||||
|
|
||||||
assertTrue(
|
assertEquals(
|
||||||
broadcastShapes(
|
broadcastShapes(
|
||||||
listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7))
|
listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7))
|
||||||
) contentEquals ShapeND(5, 6, 7)
|
),
|
||||||
|
ShapeND(5, 6, 7)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +40,7 @@ internal class TestBroadcasting {
|
|||||||
val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
|
||||||
val res = broadcastTo(tensor2, tensor1.shape)
|
val res = broadcastTo(tensor2, tensor1.shape)
|
||||||
assertTrue(res.shape contentEquals ShapeND(2, 3))
|
assertTrue(res.shape == ShapeND(2, 3))
|
||||||
assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,9 +52,9 @@ internal class TestBroadcasting {
|
|||||||
|
|
||||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
|
assertEquals(res[0].shape, ShapeND(1, 2, 3))
|
||||||
assertTrue(res[1].shape contentEquals ShapeND(1, 2, 3))
|
assertEquals(res[1].shape, ShapeND(1, 2, 3))
|
||||||
assertTrue(res[2].shape contentEquals ShapeND(1, 2, 3))
|
assertEquals(res[2].shape, ShapeND(1, 2, 3))
|
||||||
|
|
||||||
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
@ -67,9 +69,9 @@ internal class TestBroadcasting {
|
|||||||
|
|
||||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
|
assertEquals(res[0].shape, ShapeND(1, 2, 3))
|
||||||
assertTrue(res[1].shape contentEquals ShapeND(1, 1, 3))
|
assertEquals(res[1].shape, ShapeND(1, 1, 3))
|
||||||
assertTrue(res[2].shape contentEquals ShapeND(1, 1, 1))
|
assertEquals(res[2].shape, ShapeND(1, 1, 1))
|
||||||
|
|
||||||
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0))
|
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
@ -84,9 +86,9 @@ internal class TestBroadcasting {
|
|||||||
|
|
||||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
assertTrue(res[0].shape contentEquals ShapeND(4, 2, 5, 3, 2, 3))
|
assertEquals(res[0].shape, ShapeND(4, 2, 5, 3, 2, 3))
|
||||||
assertTrue(res[1].shape contentEquals ShapeND(4, 2, 5, 3, 3, 3))
|
assertEquals(res[1].shape, ShapeND(4, 2, 5, 3, 3, 3))
|
||||||
assertTrue(res[2].shape contentEquals ShapeND(4, 2, 5, 3, 1, 1))
|
assertEquals(res[2].shape, ShapeND(4, 2, 5, 3, 1, 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -99,16 +101,16 @@ internal class TestBroadcasting {
|
|||||||
val tensor31 = tensor3 - tensor1
|
val tensor31 = tensor3 - tensor1
|
||||||
val tensor32 = tensor3 - tensor2
|
val tensor32 = tensor3 - tensor2
|
||||||
|
|
||||||
assertTrue(tensor21.shape contentEquals ShapeND(2, 3))
|
assertEquals(tensor21.shape, ShapeND(2, 3))
|
||||||
assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||||
|
|
||||||
assertTrue(tensor31.shape contentEquals ShapeND(1, 2, 3))
|
assertEquals(tensor31.shape, ShapeND(1, 2, 3))
|
||||||
assertTrue(
|
assertTrue(
|
||||||
tensor31.source
|
tensor31.source
|
||||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
|
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
|
||||||
)
|
)
|
||||||
|
|
||||||
assertTrue(tensor32.shape contentEquals ShapeND(1, 1, 3))
|
assertEquals(tensor32.shape, ShapeND(1, 1, 3))
|
||||||
assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
import space.kscience.kmath.nd.contentEquals
|
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.tensors.core.internal.svd1d
|
import space.kscience.kmath.tensors.core.internal.svd1d
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
@ -113,8 +112,8 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
val (q, r) = qr(tensor)
|
val (q, r) = qr(tensor)
|
||||||
|
|
||||||
assertTrue { q.shape contentEquals shape }
|
assertEquals(q.shape, shape)
|
||||||
assertTrue { r.shape contentEquals shape }
|
assertEquals(r.shape, shape)
|
||||||
|
|
||||||
assertTrue((q matmul r).eq(tensor))
|
assertTrue((q matmul r).eq(tensor))
|
||||||
|
|
||||||
@ -133,9 +132,9 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
val (p, l, u) = lu(tensor)
|
val (p, l, u) = lu(tensor)
|
||||||
|
|
||||||
assertTrue { p.shape contentEquals shape }
|
assertEquals(p.shape, shape)
|
||||||
assertTrue { l.shape contentEquals shape }
|
assertEquals(l.shape, shape)
|
||||||
assertTrue { u.shape contentEquals shape }
|
assertEquals(u.shape, shape)
|
||||||
|
|
||||||
assertTrue((p matmul tensor).eq(l matmul u))
|
assertTrue((p matmul tensor).eq(l matmul u))
|
||||||
}
|
}
|
||||||
@ -157,7 +156,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
val res = svd1d(tensor2)
|
val res = svd1d(tensor2)
|
||||||
|
|
||||||
assertTrue(res.shape contentEquals ShapeND(2))
|
assertEquals(res.shape, ShapeND(2))
|
||||||
assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 }
|
assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 }
|
||||||
assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 }
|
assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 }
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,8 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
|
||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.ShapeND
|
||||||
|
import space.kscience.kmath.nd.get
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.testutils.assertBufferEquals
|
import space.kscience.kmath.testutils.assertBufferEquals
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -43,7 +44,7 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val res = tensor.transposed(0, 0)
|
val res = tensor.transposed(0, 0)
|
||||||
|
|
||||||
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0))
|
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0))
|
||||||
assertTrue(res.shape contentEquals ShapeND(1))
|
assertEquals(res.shape, ShapeND(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -52,7 +53,7 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val res = tensor.transposed(1, 0)
|
val res = tensor.transposed(1, 0)
|
||||||
|
|
||||||
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||||
assertTrue(res.shape contentEquals ShapeND(2, 3))
|
assertEquals(res.shape, ShapeND(2, 3))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -62,9 +63,9 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val res02 = tensor.transposed(-3, 2)
|
val res02 = tensor.transposed(-3, 2)
|
||||||
val res12 = tensor.transposed()
|
val res12 = tensor.transposed()
|
||||||
|
|
||||||
assertTrue(res01.shape contentEquals ShapeND(2, 1, 3))
|
assertEquals(res01.shape, ShapeND(2, 1, 3))
|
||||||
assertTrue(res02.shape contentEquals ShapeND(3, 2, 1))
|
assertEquals(res02.shape, ShapeND(3, 2, 1))
|
||||||
assertTrue(res12.shape contentEquals ShapeND(1, 3, 2))
|
assertEquals(res12.shape, ShapeND(1, 3, 2))
|
||||||
|
|
||||||
assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
@ -114,19 +115,19 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
|
|
||||||
val res12 = tensor1.dot(tensor2)
|
val res12 = tensor1.dot(tensor2)
|
||||||
assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0))
|
assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0))
|
||||||
assertTrue(res12.shape contentEquals ShapeND(2))
|
assertEquals(res12.shape, ShapeND(2))
|
||||||
|
|
||||||
val res32 = tensor3.matmul(tensor2)
|
val res32 = tensor3.matmul(tensor2)
|
||||||
assertTrue(res32.source contentEquals doubleArrayOf(-140.0))
|
assertTrue(res32.source contentEquals doubleArrayOf(-140.0))
|
||||||
assertTrue(res32.shape contentEquals ShapeND(1, 1))
|
assertEquals(res32.shape, ShapeND(1, 1))
|
||||||
|
|
||||||
val res22 = tensor2.dot(tensor2)
|
val res22 = tensor2.dot(tensor2)
|
||||||
assertTrue(res22.source contentEquals doubleArrayOf(1400.0))
|
assertTrue(res22.source contentEquals doubleArrayOf(1400.0))
|
||||||
assertTrue(res22.shape contentEquals ShapeND(1))
|
assertEquals(res22.shape, ShapeND(1))
|
||||||
|
|
||||||
val res11 = tensor1.dot(tensor11)
|
val res11 = tensor1.dot(tensor11)
|
||||||
assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
||||||
assertTrue(res11.shape contentEquals ShapeND(2, 2))
|
assertEquals(res11.shape, ShapeND(2, 2))
|
||||||
|
|
||||||
val res45 = tensor4.matmul(tensor5)
|
val res45 = tensor4.matmul(tensor5)
|
||||||
assertTrue(
|
assertTrue(
|
||||||
@ -135,7 +136,7 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assertTrue(res45.shape contentEquals ShapeND(2, 3, 3))
|
assertEquals(res45.shape, ShapeND(2, 3, 3))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -144,35 +145,35 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor3 = zeros(ShapeND(2, 3, 4, 5))
|
val tensor3 = zeros(ShapeND(2, 3, 4, 5))
|
||||||
|
|
||||||
assertTrue(
|
assertEquals(
|
||||||
diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
|
diagonalEmbedding(tensor3, 0, 3, 4).shape,
|
||||||
ShapeND(2, 3, 4, 5, 5)
|
ShapeND(2, 3, 4, 5, 5)
|
||||||
)
|
)
|
||||||
assertTrue(
|
assertEquals(
|
||||||
diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
|
diagonalEmbedding(tensor3, 1, 3, 4).shape,
|
||||||
ShapeND(2, 3, 4, 6, 6)
|
ShapeND(2, 3, 4, 6, 6)
|
||||||
)
|
)
|
||||||
assertTrue(
|
assertEquals(
|
||||||
diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
|
diagonalEmbedding(tensor3, 2, 0, 3).shape,
|
||||||
ShapeND(7, 2, 3, 7, 4)
|
ShapeND(7, 2, 3, 7, 4)
|
||||||
)
|
)
|
||||||
|
|
||||||
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
|
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
|
||||||
assertTrue(diagonal1.shape contentEquals ShapeND(3, 3))
|
assertEquals(diagonal1.shape, ShapeND(3, 3))
|
||||||
assertTrue(
|
assertTrue(
|
||||||
diagonal1.source contentEquals
|
diagonal1.source contentEquals
|
||||||
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)
|
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)
|
||||||
)
|
)
|
||||||
|
|
||||||
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
|
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
|
||||||
assertTrue(diagonal1Offset.shape contentEquals ShapeND(4, 4))
|
assertEquals(diagonal1Offset.shape, ShapeND(4, 4))
|
||||||
assertTrue(
|
assertTrue(
|
||||||
diagonal1Offset.source contentEquals
|
diagonal1Offset.source contentEquals
|
||||||
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)
|
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)
|
||||||
)
|
)
|
||||||
|
|
||||||
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
|
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
|
||||||
assertTrue(diagonal2.shape contentEquals ShapeND(4, 2, 4))
|
assertEquals(diagonal2.shape, ShapeND(4, 2, 4))
|
||||||
assertTrue(
|
assertTrue(
|
||||||
diagonal2.source contentEquals
|
diagonal2.source contentEquals
|
||||||
doubleArrayOf(
|
doubleArrayOf(
|
||||||
@ -202,7 +203,7 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val l = tensor.getTensor(0).map { it + 1.0 }
|
val l = tensor.getTensor(0).map { it + 1.0 }
|
||||||
val r = tensor.getTensor(1).map { it - 1.0 }
|
val r = tensor.getTensor(1).map { it - 1.0 }
|
||||||
val res = l + r
|
val res = l + r
|
||||||
assertTrue { ShapeND(5, 5) contentEquals res.shape }
|
assertEquals(ShapeND(5, 5), res.shape)
|
||||||
assertEquals(2.0, res[4, 4])
|
assertEquals(2.0, res[4, 4])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ public open class ViktorFieldOpsND :
|
|||||||
right: StructureND<Double>,
|
right: StructureND<Double>,
|
||||||
transform: Float64Field.(Double, Double) -> Double,
|
transform: Float64Field.(Double, Double) -> Double,
|
||||||
): ViktorStructureND {
|
): ViktorStructureND {
|
||||||
require(left.shape.contentEquals(right.shape))
|
require(left.shape == right.shape)
|
||||||
return F64Array(*left.shape.asArray()).apply {
|
return F64Array(*left.shape.asArray()).apply {
|
||||||
ColumnStrides(left.shape).asSequence().forEach { index ->
|
ColumnStrides(left.shape).asSequence().forEach { index ->
|
||||||
set(value = Float64Field.transform(left[index], right[index]), indices = index)
|
set(value = Float64Field.transform(left[index], right[index]), indices = index)
|
||||||
|
Loading…
Reference in New Issue
Block a user