Merge remote-tracking branch 'ups/feature/tensor-algebra' into andrew
This commit is contained in:
commit
d6a1bee93b
@ -984,6 +984,8 @@ public final class space/kscience/kmath/nd/MutableStructure1D$DefaultImpls {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/nd/MutableStructure2D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure2D {
|
public abstract interface class space/kscience/kmath/nd/MutableStructure2D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure2D {
|
||||||
|
public abstract fun getColumns ()Ljava/util/List;
|
||||||
|
public abstract fun getRows ()Ljava/util/List;
|
||||||
public abstract fun set (IILjava/lang/Object;)V
|
public abstract fun set (IILjava/lang/Object;)V
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2608,6 +2610,15 @@ public final class space/kscience/kmath/structures/VirtualBuffer : space/kscienc
|
|||||||
public fun iterator ()Ljava/util/Iterator;
|
public fun iterator ()Ljava/util/Iterator;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/structures/VirtualMutableBuffer : space/kscience/kmath/structures/MutableBuffer {
|
||||||
|
public fun <init> (ILkotlin/jvm/functions/Function1;)V
|
||||||
|
public fun copy ()Lspace/kscience/kmath/structures/MutableBuffer;
|
||||||
|
public fun get (I)Ljava/lang/Object;
|
||||||
|
public fun getSize ()I
|
||||||
|
public fun iterator ()Ljava/util/Iterator;
|
||||||
|
public fun set (ILjava/lang/Object;)V
|
||||||
|
}
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/tensors/AnalyticTensorAlgebra : space/kscience/kmath/tensors/OrderedTensorAlgebra, space/kscience/kmath/tensors/TensorPartialDivisionAlgebra {
|
public abstract interface class space/kscience/kmath/tensors/AnalyticTensorAlgebra : space/kscience/kmath/tensors/OrderedTensorAlgebra, space/kscience/kmath/tensors/TensorPartialDivisionAlgebra {
|
||||||
public abstract fun acos (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun acos (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun acosh (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun acosh (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
@ -2709,6 +2720,7 @@ public abstract interface class space/kscience/kmath/tensors/TensorAlgebra {
|
|||||||
public abstract fun cumsum (Lspace/kscience/kmath/nd/MutableStructureND;I)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun cumsum (Lspace/kscience/kmath/nd/MutableStructureND;I)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun diagonalEmbedding (Lspace/kscience/kmath/nd/MutableStructureND;III)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun diagonalEmbedding (Lspace/kscience/kmath/nd/MutableStructureND;III)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
|
public abstract fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z
|
||||||
public abstract fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public abstract fun full (Ljava/lang/Object;[I)Lspace/kscience/kmath/nd/MutableStructureND;
|
public abstract fun full (Ljava/lang/Object;[I)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
@ -2970,9 +2982,10 @@ public class space/kscience/kmath/tensors/core/DoubleTensorAlgebra : space/kscie
|
|||||||
public fun divAssign (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)V
|
public fun divAssign (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)V
|
||||||
public synthetic fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
public synthetic fun dot (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public fun dot (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
public fun dot (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
||||||
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z
|
public synthetic fun eq (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;Ljava/lang/Object;)Z
|
||||||
|
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;)Z
|
||||||
|
public fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;D)Z
|
||||||
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lkotlin/jvm/functions/Function2;)Z
|
public final fun eq (Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lkotlin/jvm/functions/Function2;)Z
|
||||||
public static synthetic fun eq$default (Lspace/kscience/kmath/tensors/core/DoubleTensorAlgebra;Lspace/kscience/kmath/tensors/core/DoubleTensor;Lspace/kscience/kmath/tensors/core/DoubleTensor;DILjava/lang/Object;)Z
|
|
||||||
public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
|
public synthetic fun eye (I)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
public fun eye (I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
|
||||||
public synthetic fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
|
public synthetic fun flatten (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
|
||||||
|
@ -14,6 +14,7 @@ import kotlin.reflect.KClass
|
|||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
*/
|
*/
|
||||||
public typealias Matrix<T> = Structure2D<T>
|
public typealias Matrix<T> = Structure2D<T>
|
||||||
|
public typealias MutableMatrix<T> = MutableStructure2D<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Alias or using [Buffer] as a point/vector in a many-dimensional space.
|
* Alias or using [Buffer] as a point/vector in a many-dimensional space.
|
||||||
|
@ -2,7 +2,9 @@ package space.kscience.kmath.nd
|
|||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.structures.VirtualBuffer
|
import space.kscience.kmath.structures.VirtualBuffer
|
||||||
|
import space.kscience.kmath.structures.VirtualMutableBuffer
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -69,6 +71,18 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
|
|||||||
* @param value the value.
|
* @param value the value.
|
||||||
*/
|
*/
|
||||||
public operator fun set(i: Int, j: Int, value: T)
|
public operator fun set(i: Int, j: Int, value: T)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The buffer of rows of this structure. It gets elements from the structure dynamically.
|
||||||
|
*/
|
||||||
|
override val rows: List<MutableBuffer<T>>
|
||||||
|
get() = List(rowNum) { i -> VirtualMutableBuffer(colNum) { j -> get(i, j) } }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The buffer of columns of this structure. It gets elements from the structure dynamically.
|
||||||
|
*/
|
||||||
|
override val columns: List<MutableBuffer<T>>
|
||||||
|
get() = List(colNum) { j -> VirtualMutableBuffer(rowNum) { i -> get(i, j) } }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -223,11 +223,7 @@ public inline class MutableListBuffer<T>(public val list: MutableList<T>) : Muta
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
<<<<<<< HEAD
|
|
||||||
* Returns an [MutableListBuffer] that wraps the original list.
|
* Returns an [MutableListBuffer] that wraps the original list.
|
||||||
=======
|
|
||||||
* Returns an [ListBuffer] that wraps the original list.
|
|
||||||
>>>>>>> dev
|
|
||||||
*/
|
*/
|
||||||
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
|
public fun <T> MutableList<T>.asMutableBuffer(): MutableListBuffer<T> = MutableListBuffer(this)
|
||||||
|
|
||||||
@ -286,6 +282,24 @@ public class VirtualBuffer<T>(override val size: Int, private val generator: (In
|
|||||||
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
|
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(generator).iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class VirtualMutableBuffer<T>(override val size: Int, private val generator: (Int) -> T) : MutableBuffer<T> {
|
||||||
|
|
||||||
|
private val bufferHolder: MutableListBuffer<T> = (0 until size).map(generator).toMutableList().asMutableBuffer()
|
||||||
|
|
||||||
|
override operator fun get(index: Int): T {
|
||||||
|
if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index")
|
||||||
|
return bufferHolder[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun iterator(): Iterator<T> = bufferHolder.iterator()
|
||||||
|
|
||||||
|
override fun set(index: Int, value: T) {
|
||||||
|
bufferHolder[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun copy(): MutableBuffer<T> = bufferHolder.copy()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this buffer to read-only buffer.
|
* Convert this buffer to read-only buffer.
|
||||||
*/
|
*/
|
||||||
|
@ -9,8 +9,8 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
|||||||
public operator fun TensorType.divAssign(other: TensorType)
|
public operator fun TensorType.divAssign(other: TensorType)
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
|
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
|
||||||
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
|
public fun TensorType.mean(dim: Int = 0, keepDim: Boolean = false): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
||||||
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
public fun TensorType.variance(dim: Int = 0, unbiased: Boolean = true, keepDim: Boolean = false): TensorType
|
||||||
}
|
}
|
||||||
|
@ -70,4 +70,14 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
assertTrue { invTensor.shape contentEquals expectedShape }
|
assertTrue { invTensor.shape contentEquals expectedShape }
|
||||||
assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) }
|
assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testScalarProduct() = DoubleLinearOpsTensorAlgebra {
|
||||||
|
val a = fromArray(intArrayOf(3), doubleArrayOf(1.8,2.5, 6.8))
|
||||||
|
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5,2.6, 6.4))
|
||||||
|
DoubleReduceOpsTensorAlgebra {
|
||||||
|
assertEquals(a.dot(b).value(), 59.92)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user