Andrew #253
@ -6,12 +6,13 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.full.html
|
//https://pytorch.org/docs/stable/generated/torch.full.html
|
||||||
public fun full(value: T, shape: IntArray): TensorType
|
public fun full(value: T, shape: IntArray): TensorType
|
||||||
|
|
||||||
|
public fun ones(shape: IntArray): TensorType
|
||||||
|
public fun zeros(shape: IntArray): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
||||||
public fun TensorType.fullLike(value: T): TensorType
|
public fun TensorType.fullLike(value: T): TensorType
|
||||||
|
|
||||||
public fun zeros(shape: IntArray): TensorType
|
public fun TensorType.zeroesLike(): TensorType
|
||||||
public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ?
|
|
||||||
public fun ones(shape: IntArray): TensorType
|
|
||||||
public fun TensorType.onesLike(): TensorType
|
public fun TensorType.onesLike(): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
||||||
|
@ -33,28 +33,6 @@ public open class BufferedTensor<T>(
|
|||||||
|
|
||||||
override fun hashCode(): Int = 0
|
override fun hashCode(): Int = 0
|
||||||
|
|
||||||
// todo rename to vector
|
|
||||||
public inline fun forEachVector(vectorAction : (MutableStructure1D<T>) -> Unit): Unit {
|
|
||||||
check(shape.size >= 1) {"todo"}
|
|
||||||
val vectorOffset = strides.strides[0]
|
|
||||||
val vectorShape = intArrayOf(shape.last())
|
|
||||||
for (offset in 0 until numel step vectorOffset) {
|
|
||||||
val vector = BufferedTensor<T>(vectorShape, buffer, offset).as1D()
|
|
||||||
vectorAction(vector)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun forEachMatrix(matrixAction : (MutableStructure2D<T>) -> Unit): Unit {
|
|
||||||
check(shape.size >= 2) {"todo"}
|
|
||||||
val matrixOffset = strides.strides[1]
|
|
||||||
val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way?
|
|
||||||
for (offset in 0 until numel step matrixOffset) {
|
|
||||||
val matrix = BufferedTensor<T>(matrixShape, buffer, offset).as2D()
|
|
||||||
matrixAction(matrix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// todo remove code copy-pasting
|
|
||||||
|
|
||||||
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
|
public fun vectorSequence(): Sequence<MutableStructure1D<T>> = sequence {
|
||||||
check(shape.size >= 1) {"todo"}
|
check(shape.size >= 1) {"todo"}
|
||||||
val vectorOffset = strides.strides[0]
|
val vectorOffset = strides.strides[0]
|
||||||
@ -75,6 +53,18 @@ public open class BufferedTensor<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public inline fun forEachVector(vectorAction : (MutableStructure1D<T>) -> Unit): Unit {
|
||||||
|
for (vector in vectorSequence()){
|
||||||
|
vectorAction(vector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun forEachMatrix(matrixAction : (MutableStructure2D<T>) -> Unit): Unit {
|
||||||
|
for (matrix in matrixSequence()){
|
||||||
|
matrixAction(matrix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,24 +20,41 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
<<<<<<< HEAD:kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt
|
||||||
|
override fun full(shape: IntArray, value: Double): DoubleTensor {
|
||||||
|
val buffer = DoubleArray(TensorStrides(shape).linearSize) { value }
|
||||||
|
return DoubleTensor(shape, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun zeros(shape: IntArray): DoubleTensor = full(shape, 0.0)
|
||||||
|
|
||||||
|
override fun ones(shape: IntArray): DoubleTensor = full(shape, 1.0)
|
||||||
|
|
||||||
|
=======
|
||||||
override fun full(value: Double, shape: IntArray): DoubleTensor {
|
override fun full(value: Double, shape: IntArray): DoubleTensor {
|
||||||
checkEmptyShape(shape)
|
checkEmptyShape(shape)
|
||||||
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
||||||
return DoubleTensor(shape, buffer)
|
return DoubleTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
>>>>>>> ups/feature/tensor-algebra:kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
|
||||||
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
||||||
val shape = this.shape
|
val shape = this.shape
|
||||||
val buffer = DoubleArray(this.strides.linearSize) { value }
|
val buffer = DoubleArray(this.strides.linearSize) { value }
|
||||||
return DoubleTensor(shape, buffer)
|
return DoubleTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
<<<<<<< HEAD:kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt
|
||||||
|
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
||||||
|
|
||||||
|
=======
|
||||||
override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
||||||
|
|
||||||
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
||||||
|
|
||||||
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
||||||
|
|
||||||
|
>>>>>>> ups/feature/tensor-algebra:kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
|
||||||
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
||||||
|
|
||||||
override fun eye(n: Int): DoubleTensor {
|
override fun eye(n: Int): DoubleTensor {
|
||||||
|
Loading…
Reference in New Issue
Block a user