imp full and remove copypaste in geners

This commit is contained in:
Andrei Kislitsyn 2021-03-19 17:51:30 +03:00
parent 1fa0da2810
commit 5e94610e28
3 changed files with 31 additions and 49 deletions

View File

@ -31,28 +31,6 @@ public open class BufferedTensor<T>(
override fun hashCode(): Int = 0
// todo rename to vector
public inline fun forEachVector(vectorAction : (MutableStructure1D<T>) -> Unit): Unit {
check(shape.size >= 1) {"todo"}
val vectorOffset = strides.strides[0]
val vectorShape = intArrayOf(shape.last())
for (offset in 0 until numel step vectorOffset) {
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
vectorAction(vector)
}
}
public inline fun forEachMatrix(matrixAction : (MutableStructure2D<T>) -> Unit): Unit {
check(shape.size >= 2) {"todo"}
val matrixOffset = strides.strides[1]
val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way?
for (offset in 0 until numel step matrixOffset) {
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
matrixAction(matrix)
}
}
// todo remove code copy-pasting
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
check(shape.size >= 1) {"todo"}
val vectorOffset = strides.strides[0]
@ -73,6 +51,18 @@ public open class BufferedTensor<T>(
}
}
public inline fun forEachVector(vectorAction : (MutableStructure1D<T>) -> Unit): Unit {
for (vector in vectorSequence()){
vectorAction(vector)
}
}
public inline fun forEachMatrix(matrixAction : (MutableStructure2D<T>) -> Unit): Unit {
for (matrix in matrixSequence()){
matrixAction(matrix)
}
}
}

View File

@ -10,23 +10,24 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return DoubleTensor(newShape, this.buffer.array(), newStart)
}
override fun zeros(shape: IntArray): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.zeroesLike(): DoubleTensor {
val shape = this.shape
val buffer = DoubleArray(this.strides.linearSize) { 0.0 }
override fun full(shape: IntArray, value: Double): DoubleTensor {
val buffer = DoubleArray(TensorStrides(shape).linearSize) { value }
return DoubleTensor(shape, buffer)
}
override fun ones(shape: IntArray): DoubleTensor {
TODO("Not yet implemented")
override fun zeros(shape: IntArray): DoubleTensor = full(shape, 0.0)
override fun ones(shape: IntArray): DoubleTensor = full(shape, 1.0)
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
val shape = this.shape
val buffer = DoubleArray(this.strides.linearSize) { value }
return DoubleTensor(shape, buffer)
}
override fun DoubleTensor.onesLike(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
override fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n)
@ -192,15 +193,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
TODO("Not yet implemented")
}
override fun full(shape: IntArray, value: Double): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.sum(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}

View File

@ -4,18 +4,18 @@ package space.kscience.kmath.tensors
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun zeros(shape: IntArray): TensorType
public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ?
public fun ones(shape: IntArray): TensorType
public fun TensorType.onesLike(): TensorType
//https://pytorch.org/docs/stable/generated/torch.full.html
public fun full(shape: IntArray, value: T): TensorType
public fun ones(shape: IntArray): TensorType
public fun zeros(shape: IntArray): TensorType
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
public fun TensorType.fullLike(value: T): TensorType
public fun TensorType.zeroesLike(): TensorType
public fun TensorType.onesLike(): TensorType
//https://pytorch.org/docs/stable/generated/torch.eye.html
public fun eye(n: Int): TensorType