ReduceOpsTensorAlgebra

This commit is contained in:
Roland Grinis 2021-03-16 07:47:02 +00:00
parent f4454a6cf6
commit 0553a28ee8
9 changed files with 66 additions and 42 deletions

View File

@ -2,7 +2,17 @@ package space.kscience.kmath.tensors
public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> : public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
TensorPartialDivisionAlgebra<T, TensorType> { TensorPartialDivisionAlgebra<T, TensorType>,
OrderedTensorAlgebra<T, TensorType>{
//https://pytorch.org/docs/stable/generated/torch.quantile.html#torch.quantile
public fun TensorType.quantile(q: T, dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.exp.html //https://pytorch.org/docs/stable/generated/torch.exp.html
public fun TensorType.exp(): TensorType public fun TensorType.exp(): TensorType
@ -109,4 +119,7 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
//https://pytorch.org/docs/stable/generated/torch.trapz.html#torch.trapz //https://pytorch.org/docs/stable/generated/torch.trapz.html#torch.trapz
public fun TensorType.trapz(xValues: TensorType, dim: Int): TensorType public fun TensorType.trapz(xValues: TensorType, dim: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.histc.html#torch.histc
public fun TensorType.histc(bins: Int, min: T, max: T): TensorType
} }

View File

@ -2,7 +2,7 @@ package space.kscience.kmath.tensors
public class DoubleAnalyticTensorAlgebra: public class DoubleAnalyticTensorAlgebra:
AnalyticTensorAlgebra<Double, DoubleTensor>, AnalyticTensorAlgebra<Double, DoubleTensor>,
DoubleTensorAlgebra() DoubleOrderedTensorAlgebra()
{ {
override fun DoubleTensor.exp(): DoubleTensor { override fun DoubleTensor.exp(): DoubleTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
@ -144,6 +144,22 @@ public class DoubleAnalyticTensorAlgebra:
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun DoubleTensor.quantile(q: Double, dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.std(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
TODO("Not yet implemented")
}
} }
public inline fun <R> DoubleAnalyticTensorAlgebra(block: DoubleAnalyticTensorAlgebra.() -> R): R = public inline fun <R> DoubleAnalyticTensorAlgebra(block: DoubleAnalyticTensorAlgebra.() -> R): R =

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
public class DoubleOrderedTensorAlgebra: public open class DoubleOrderedTensorAlgebra:
OrderedTensorAlgebra<Double, DoubleTensor>, OrderedTensorAlgebra<Double, DoubleTensor>,
DoubleTensorAlgebra() DoubleTensorAlgebra()
{ {

View File

@ -0,0 +1,16 @@
package space.kscience.kmath.tensors
public class DoubleReduceOpsTensorAlgebra:
DoubleTensorAlgebra(),
ReduceOpsTensorAlgebra<Double, DoubleTensor> {
override fun DoubleTensor.value(): Double {
check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}"
}
return this.buffer.array()[this.bufferStart]
}
}
public inline fun <R> DoubleReduceOpsTensorAlgebra(block: DoubleReduceOpsTensorAlgebra.() -> R): R =
DoubleReduceOpsTensorAlgebra().block()

View File

@ -3,13 +3,6 @@ package space.kscience.kmath.tensors
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> { public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
override fun DoubleTensor.value(): Double {
check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}"
}
return this.buffer.array()[this.bufferStart]
}
override operator fun DoubleTensor.get(i: Int): DoubleTensor { override operator fun DoubleTensor.get(i: Int): DoubleTensor {
val lastShape = this.shape.drop(1).toIntArray() val lastShape = this.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
@ -257,22 +250,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun DoubleTensor.quantile(q: Double, dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.std(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.det(): DoubleTensor { override fun DoubleTensor.det(): DoubleTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }

View File

@ -0,0 +1,7 @@
package space.kscience.kmath.tensors
public interface ReduceOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
TensorAlgebra<T, TensorType> {
public fun TensorType.value(): T
}

View File

@ -3,7 +3,6 @@ package space.kscience.kmath.tensors
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring // https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> { public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.value(): T
public fun zeros(shape: IntArray): TensorType public fun zeros(shape: IntArray): TensorType
public fun TensorType.zeroesLike(): TensorType public fun TensorType.zeroesLike(): TensorType

View File

@ -3,7 +3,6 @@ package space.kscience.kmath.tensors
// https://proofwiki.org/wiki/Definition:Division_Algebra // https://proofwiki.org/wiki/Definition:Division_Algebra
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> : public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
TensorAlgebra<T, TensorType> { TensorAlgebra<T, TensorType> {
public operator fun TensorType.div(value: T): TensorType public operator fun TensorType.div(value: T): TensorType
public operator fun TensorType.div(other: TensorType): TensorType public operator fun TensorType.div(other: TensorType): TensorType
public operator fun TensorType.divAssign(value: T) public operator fun TensorType.divAssign(value: T)
@ -11,17 +10,4 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean //https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.quantile.html#torch.quantile
public fun TensorType.quantile(q: T, dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.histc.html#torch.histc
public fun TensorType.histc(bins: Int, min: T, max: T): TensorType
} }

View File

@ -1,6 +1,7 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.structures.toDoubleArray
import kotlin.test.Test import kotlin.test.Test
@ -10,7 +11,7 @@ import kotlin.test.assertTrue
class TestDoubleTensor { class TestDoubleTensor {
@Test @Test
fun valueTest() = DoubleTensorAlgebra { fun valueTest() = DoubleReduceOpsTensorAlgebra {
val value = 12.5 val value = 12.5
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value)) val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
assertEquals(tensor.value(), value) assertEquals(tensor.value(), value)
@ -28,5 +29,14 @@ class TestDoubleTensor {
val tensor = DoubleTensor(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4)) val tensor = DoubleTensor(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
val matrix = tensor[0].as2D() val matrix = tensor[0].as2D()
assertEquals(matrix[0,1], 5.8) assertEquals(matrix[0,1], 5.8)
val vector = tensor[0][1].as1D()
assertEquals(vector[0], 58.4)
matrix[0,1] = 77.89
assertEquals(tensor[intArrayOf(0,0,1)], 77.89)
//vector[0] = 109.56
//println(tensor[intArrayOf(0,1,0)])
} }
} }