-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathconvolutions.py
More file actions
42 lines (38 loc) · 1.1 KB
/
convolutions.py
File metadata and controls
42 lines (38 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import ctypes
from arrayfire_wrapper.defines import AFArray, CShape
from arrayfire_wrapper.lib._constants import ConvGradient
from arrayfire_wrapper.lib._utility import call_from_clib
def convolve2_gradient_nn(
    incoming_gradient: AFArray,
    original_signal: AFArray,
    original_filter: AFArray,
    convolved_output: AFArray,
    strides: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    grad_type: ConvGradient,
    /,
) -> AFArray:
    """Compute a gradient of a 2D convolution through the ArrayFire C library.

    This is a thin binding: it forwards its arguments to the C function whose
    name matches this Python function (resolved by ``call_from_clib``), and
    returns the array handle the C call writes into ``out``.

    source: https://arrayfire.org/docs/group__ml__convolution.htm#ga3dc8cbebcec76e5c1804ff377b4e1cfd

    Args:
        incoming_gradient: Gradient flowing back into the convolution output.
        original_signal: Input signal used in the forward convolution.
        original_filter: Filter used in the forward convolution.
        convolved_output: Output produced by the forward convolution.
        strides: Stride per dimension; its length is passed as the dim count.
        padding: Padding per dimension; its length is passed as the dim count.
        dilation: Dilation per dimension; its length is passed as the dim count.
        grad_type: Enum whose ``.value`` selects which gradient is computed
            (see the ArrayFire docs linked above for the accepted values).

    Returns:
        A new ``AFArray`` handle filled in by the C call.
    """
    # Output slot that the C function populates through a pointer.
    result = AFArray.create_null_pointer()

    # The C API takes each geometry setting as a (dimension count, array) pair.
    # NOTE(review): CShape is assumed to convert the tuple into a C-compatible
    # array exposed via ``.c_array``; confirm against arrayfire_wrapper.defines.
    stride_shape = CShape(*strides)
    padding_shape = CShape(*padding)
    dilation_shape = CShape(*dilation)
    geometry_args = (
        len(strides),
        stride_shape.c_array,
        len(padding),
        padding_shape.c_array,
        len(dilation),
        dilation_shape.c_array,
    )

    call_from_clib(
        convolve2_gradient_nn.__name__,
        ctypes.pointer(result),
        incoming_gradient,
        original_signal,
        original_filter,
        convolved_output,
        *geometry_args,
        grad_type.value,
    )
    return result