-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathcuda.py
More file actions
38 lines (29 loc) · 1.18 KB
/
cuda.py
File metadata and controls
38 lines (29 loc) · 1.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import ctypes
from arrayfire_wrapper.defines import AFArray
from arrayfire_wrapper.lib._utility import call_from_clib
def cublas_set_math_mode(mode: int) -> None:
    """
    Set the cuBLAS math mode used by ArrayFire's CUDA backend.

    source: https://arrayfire.org/docs/group__cuda__mat.htm#gac23ea38f0bff77a0e12555f27f47aa4f

    Forwards ``mode`` unchanged to the C symbol ``cublasSetMathMode``. The symbol
    is resolved through ``call_from_clib`` with the ``afcu`` library prefix,
    i.e. presumably ``afcu_cublasSetMathMode`` from the linked docs.

    Parameters
    ----------
    mode : int
        Presumably a ``cublasMath_t`` enumerator value; confirm against the
        cuBLAS headers for the installed CUDA version.
    """
    c_symbol = "cublasSetMathMode"
    call_from_clib(c_symbol, mode, clib_prefix="afcu")
def get_native_id(index: int) -> int:
    """
    Return the native CUDA device id for ArrayFire device ``index``.

    source: https://arrayfire.org/docs/group__cuda__mat.htm#gaf38af1cbbf4be710cc8cbd95d20b24c4

    The C function writes its result through an ``int*`` out-parameter. A
    ``ctypes.c_int`` is therefore allocated here, and a pointer to it is passed
    ahead of the device index.

    Parameters
    ----------
    index : int
        ArrayFire device index.

    Returns
    -------
    int
        The value written by the C call into the out-parameter.
    """
    native_id = ctypes.c_int(0)
    native_id_ptr = ctypes.pointer(native_id)
    call_from_clib(get_native_id.__name__, native_id_ptr, index, clib_prefix="afcu")
    result: int = native_id.value
    return result
def get_stream(index: int) -> int:
    """
    Return the CUDA stream ArrayFire uses for device ``index``.

    source: https://arrayfire.org/docs/group__cuda__mat.htm#ga8323b850f80afe9878b099f647b0a7e5

    The C function writes a ``cudaStream_t``, which is an opaque CUDA pointer
    and not an ArrayFire array handle. A plain ``ctypes.c_void_p`` therefore
    receives it.

    Parameters
    ----------
    index : int
        ArrayFire device index.

    Returns
    -------
    int
        The stream handle as an integer address. ctypes reports a NULL pointer
        as ``None``; that case is returned as ``0`` so the result always matches
        the ``int`` annotation.
    """
    stream = ctypes.c_void_p()
    call_from_clib(get_stream.__name__, ctypes.pointer(stream), index, clib_prefix="afcu")
    # c_void_p.value is None when the pointer is NULL; normalize to the integer 0.
    return stream.value or 0
def set_native_id(index: int) -> None:
    """
    Make the CUDA device with native id ``index`` ArrayFire's active device.

    source: https://arrayfire.org/docs/group__cuda__mat.htm#ga966f4c6880e90ce91d9599c90c0db378

    Forwards ``index`` unchanged to the C function of the same name, resolved
    through ``call_from_clib`` with the ``afcu`` library prefix.

    Parameters
    ----------
    index : int
        Native CUDA device id.
    """
    c_symbol = set_native_id.__name__
    call_from_clib(c_symbol, index, clib_prefix="afcu")