diff --git a/main.py b/main.py index 71baea2..981444a 100644 --- a/main.py +++ b/main.py @@ -49,23 +49,18 @@ def VGG16(include_top=True, weights='imagenet', '(pre-training on ImageNet).') # Determine proper input shape if K.image_dim_ordering() == 'th': - if include_top: - input_shape = (3, 224, 224) - else: - input_shape = (3, None, None) + input_shape = (3, 224, 224) if include_top else (3, None, None) else: - if include_top: - input_shape = (224, 224, 3) - else: - input_shape = (None, None, 3) - + input_shape = (224, 224, 3) if include_top else (None, None, 3) if input_tensor is None: img_input = Input(shape=input_shape) else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor) - else: - img_input = input_tensor + img_input = ( + input_tensor + if K.is_keras_tensor(input_tensor) + else Input(tensor=input_tensor) + ) + # Block 1 x = Convolution2D(64, 3, 3, activation='relu', border_mode='same', name='block1_conv1')(img_input) x = Convolution2D(64, 3, 3, activation='relu', border_mode='same', name='block1_conv2')(x)