diff --git a/tutorial/cnn.sac b/tutorial/cnn.sac index c210549..e3243ae 100644 --- a/tutorial/cnn.sac +++ b/tutorial/cnn.sac @@ -155,32 +155,29 @@ float[*], float[*] BackConv2( float[*] d_output, float[*] weights, float[*] inpu #endif //------------------------------------------------------------------------------ + +inline float[*] >--conv-> (float[*] in, float[*] weights) { return Conv (in, weights); } +inline float[*] >--logistics-> (float[*] in, int _) { return Logistic (in); } +inline float[*] >--avgpool-> (float[*] in, int[.] filter) { return AveragePool (in, filter, []); } +inline float[*] >--reshape-> (float[*] in, int[.] s) { return reshape (s, in); } + int main() { in = genarray( [28,28], 0f); in[6,6] = 42f; - - print( in); - k1 = genarray( [5,5,6], 1f/25f); - c1 = Logistic( Conv( in, k1 )); - print( shape(c1)); - - s1 = AveragePool( c1, [2,2], []); - print( shape( s1)); - k2 = genarray( [5,5,6,12], 1f/150f); - c2 = Logistic( Conv( s1, k2)); - - print( shape( c2)); - - s2 = AveragePool( c2, [2,2], []); - print( shape( s2)); - fc = genarray( [4,4,1,12,10], 1f/192f); - out = reshape( [10], Conv( s2, fc)); - print( out); + _ = 0; + + out = in // Input Layer + >--conv-> k1 >--logistics-> _ // Convolution Layer C1 + >--avgpool-> [2,2] // Pooling Layer S1 + >--conv-> k2 >--logistics-> _ // Convolution Layer C2 + >--avgpool-> [2,2] // Pooling Layer S2 + >--conv-> fc >--reshape-> [10]; // Fully connected layer FC + print (out); return 0; }