Fix build when CUDA not available
This commit is contained in:
parent
a229aaa6a4
commit
32e4b68061
@ -89,7 +89,7 @@ memScoped {
|
|||||||
scope = this,
|
scope = this,
|
||||||
array = floatArrayOf(7f,8.9f,2.6f,5.6f),
|
array = floatArrayOf(7f,8.9f,2.6f,5.6f),
|
||||||
shape = intArrayOf(4))
|
shape = intArrayOf(4))
|
||||||
println(intTensor)
|
println(floatTensor)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -12,6 +12,10 @@ val home = System.getProperty("user.home")
|
|||||||
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}"
|
||||||
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
val cppBuildDir = "$thirdPartyDir/cpp-build"
|
||||||
|
|
||||||
|
val cudaHome: String? = System.getenv("CUDA_HOME")
|
||||||
|
val cudaDefault = file("/usr/local/cuda").exists()
|
||||||
|
val cudaFound = cudaHome?.isNotEmpty() ?: false or cudaDefault
|
||||||
|
|
||||||
val cmakeArchive = "cmake-3.19.2-Linux-x86_64"
|
val cmakeArchive = "cmake-3.19.2-Linux-x86_64"
|
||||||
val torchArchive = "libtorch"
|
val torchArchive = "libtorch"
|
||||||
|
|
||||||
@ -32,8 +36,11 @@ val downloadNinja by tasks.registering(Download::class) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
val downloadTorch by tasks.registering(Download::class) {
|
val downloadTorch by tasks.registering(Download::class) {
|
||||||
val zipFile = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2Bcu110.zip"
|
val abiMeta = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2B"
|
||||||
src("https://download.pytorch.org/libtorch/cu110/$zipFile")
|
val cudaUrl = "https://download.pytorch.org/libtorch/cu110/${abiMeta}cu110.zip"
|
||||||
|
val cpuUrl = "https://download.pytorch.org/libtorch/cpu/${abiMeta}cpu.zip"
|
||||||
|
val url = if (cudaFound) cudaUrl else cpuUrl
|
||||||
|
src(url)
|
||||||
dest(File(thirdPartyDir, "$torchArchive.zip"))
|
dest(File(thirdPartyDir, "$torchArchive.zip"))
|
||||||
overwrite(false)
|
overwrite(false)
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package kscience.kmath.torch
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Ignore
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
@ -12,8 +13,4 @@ internal class TestUtils {
|
|||||||
setNumThreads(numThreads)
|
setNumThreads(numThreads)
|
||||||
assertEquals(numThreads, getNumThreads())
|
assertEquals(numThreads, getNumThreads())
|
||||||
}
|
}
|
||||||
@Test
|
|
||||||
fun cudaAvailability(){
|
|
||||||
assertTrue(cudaAvailable())
|
|
||||||
}
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user