ReduceOpsTensorAlgebra
This commit is contained in:
parent
f4454a6cf6
commit
0553a28ee8
@ -2,7 +2,17 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
|
|
||||||
public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||||
TensorPartialDivisionAlgebra<T, TensorType> {
|
TensorPartialDivisionAlgebra<T, TensorType>,
|
||||||
|
OrderedTensorAlgebra<T, TensorType>{
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.quantile.html#torch.quantile
|
||||||
|
public fun TensorType.quantile(q: T, dim: Int, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
|
||||||
|
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
||||||
|
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||||
public fun TensorType.exp(): TensorType
|
public fun TensorType.exp(): TensorType
|
||||||
@ -109,4 +119,7 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.trapz.html#torch.trapz
|
//https://pytorch.org/docs/stable/generated/torch.trapz.html#torch.trapz
|
||||||
public fun TensorType.trapz(xValues: TensorType, dim: Int): TensorType
|
public fun TensorType.trapz(xValues: TensorType, dim: Int): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.histc.html#torch.histc
|
||||||
|
public fun TensorType.histc(bins: Int, min: T, max: T): TensorType
|
||||||
|
|
||||||
}
|
}
|
@ -2,7 +2,7 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
public class DoubleAnalyticTensorAlgebra:
|
public class DoubleAnalyticTensorAlgebra:
|
||||||
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
||||||
DoubleTensorAlgebra()
|
DoubleOrderedTensorAlgebra()
|
||||||
{
|
{
|
||||||
override fun DoubleTensor.exp(): DoubleTensor {
|
override fun DoubleTensor.exp(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
@ -144,6 +144,22 @@ public class DoubleAnalyticTensorAlgebra:
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.quantile(q: Double, dim: Int, keepDim: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.std(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> DoubleAnalyticTensorAlgebra(block: DoubleAnalyticTensorAlgebra.() -> R): R =
|
public inline fun <R> DoubleAnalyticTensorAlgebra(block: DoubleAnalyticTensorAlgebra.() -> R): R =
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
public class DoubleOrderedTensorAlgebra:
|
public open class DoubleOrderedTensorAlgebra:
|
||||||
OrderedTensorAlgebra<Double, DoubleTensor>,
|
OrderedTensorAlgebra<Double, DoubleTensor>,
|
||||||
DoubleTensorAlgebra()
|
DoubleTensorAlgebra()
|
||||||
{
|
{
|
||||||
|
@ -0,0 +1,16 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
public class DoubleReduceOpsTensorAlgebra:
|
||||||
|
DoubleTensorAlgebra(),
|
||||||
|
ReduceOpsTensorAlgebra<Double, DoubleTensor> {
|
||||||
|
|
||||||
|
override fun DoubleTensor.value(): Double {
|
||||||
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
|
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||||
|
}
|
||||||
|
return this.buffer.array()[this.bufferStart]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <R> DoubleReduceOpsTensorAlgebra(block: DoubleReduceOpsTensorAlgebra.() -> R): R =
|
||||||
|
DoubleReduceOpsTensorAlgebra().block()
|
@ -3,13 +3,6 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
||||||
|
|
||||||
override fun DoubleTensor.value(): Double {
|
|
||||||
check(this.shape contentEquals intArrayOf(1)) {
|
|
||||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
|
||||||
}
|
|
||||||
return this.buffer.array()[this.bufferStart]
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun DoubleTensor.get(i: Int): DoubleTensor {
|
override operator fun DoubleTensor.get(i: Int): DoubleTensor {
|
||||||
val lastShape = this.shape.drop(1).toIntArray()
|
val lastShape = this.shape.drop(1).toIntArray()
|
||||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||||
@ -257,22 +250,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.quantile(q: Double, dim: Int, keepDim: Boolean): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.std(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.det(): DoubleTensor {
|
override fun DoubleTensor.det(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
public interface ReduceOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||||
|
TensorAlgebra<T, TensorType> {
|
||||||
|
public fun TensorType.value(): T
|
||||||
|
|
||||||
|
}
|
@ -3,7 +3,6 @@ package space.kscience.kmath.tensors
|
|||||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||||
|
|
||||||
public fun TensorType.value(): T
|
|
||||||
|
|
||||||
public fun zeros(shape: IntArray): TensorType
|
public fun zeros(shape: IntArray): TensorType
|
||||||
public fun TensorType.zeroesLike(): TensorType
|
public fun TensorType.zeroesLike(): TensorType
|
||||||
|
@ -3,7 +3,6 @@ package space.kscience.kmath.tensors
|
|||||||
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
||||||
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
|
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
|
||||||
TensorAlgebra<T, TensorType> {
|
TensorAlgebra<T, TensorType> {
|
||||||
|
|
||||||
public operator fun TensorType.div(value: T): TensorType
|
public operator fun TensorType.div(value: T): TensorType
|
||||||
public operator fun TensorType.div(other: TensorType): TensorType
|
public operator fun TensorType.div(other: TensorType): TensorType
|
||||||
public operator fun TensorType.divAssign(value: T)
|
public operator fun TensorType.divAssign(value: T)
|
||||||
@ -11,17 +10,4 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
|||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
|
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
|
||||||
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
|
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.quantile.html#torch.quantile
|
|
||||||
public fun TensorType.quantile(q: T, dim: Int, keepDim: Boolean): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
|
|
||||||
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
|
||||||
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.histc.html#torch.histc
|
|
||||||
public fun TensorType.histc(bins: Int, min: T, max: T): TensorType
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.as1D
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
import space.kscience.kmath.structures.toDoubleArray
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -10,7 +11,7 @@ import kotlin.test.assertTrue
|
|||||||
class TestDoubleTensor {
|
class TestDoubleTensor {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueTest() = DoubleTensorAlgebra {
|
fun valueTest() = DoubleReduceOpsTensorAlgebra {
|
||||||
val value = 12.5
|
val value = 12.5
|
||||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
|
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
|
||||||
assertEquals(tensor.value(), value)
|
assertEquals(tensor.value(), value)
|
||||||
@ -28,5 +29,14 @@ class TestDoubleTensor {
|
|||||||
val tensor = DoubleTensor(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
val tensor = DoubleTensor(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||||
val matrix = tensor[0].as2D()
|
val matrix = tensor[0].as2D()
|
||||||
assertEquals(matrix[0,1], 5.8)
|
assertEquals(matrix[0,1], 5.8)
|
||||||
|
|
||||||
|
val vector = tensor[0][1].as1D()
|
||||||
|
assertEquals(vector[0], 58.4)
|
||||||
|
|
||||||
|
matrix[0,1] = 77.89
|
||||||
|
assertEquals(tensor[intArrayOf(0,0,1)], 77.89)
|
||||||
|
|
||||||
|
//vector[0] = 109.56
|
||||||
|
//println(tensor[intArrayOf(0,1,0)])
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user